diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b401684148002105bb984563aa902f076e7e86a1
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,24 @@
+Copyright (c) 2024 Boston Dynamics AI Institute LLC
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+1. Redistributions of source code must retain the copyright notice included
+with the software, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the copyright notice, this
+list of conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+3. Modified versions of the software must be conspicuously marked as such.
+4. The software may only be used for non-commercial research purposes.
+For profit enterprises may use the software, subject to this limitation.
+
+THIS SOFTWARE IS PROVIDED BY THE AI INSTITUTE AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, NON-
+INFRINGEMENT, TITLE, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE AI INSTITUTE OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, DAMAGES ARISING OUT OF CLAIMS OF
+INTELLECTUAL PROPERTY RIGHTS INFRINGEMENT; PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/theia/__init__.py b/theia/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/configs/dataset/ego4d.yaml b/theia/configs/dataset/ego4d.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cd10b2691475e805bf856df9ef6850aef950791d
--- /dev/null
+++ b/theia/configs/dataset/ego4d.yaml
@@ -0,0 +1,5 @@
+defaults:
+    - image_video_default
+
+dataset_mix:
+    - "ego4d_1in150"
diff --git a/theia/configs/dataset/epic_kitchen.yaml b/theia/configs/dataset/epic_kitchen.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d93b4d5fa488ee0689d964e42969c8c55e7a0971
--- /dev/null
+++ b/theia/configs/dataset/epic_kitchen.yaml
@@ -0,0 +1,5 @@
+defaults:
+    - image_video_default
+
+dataset_mix:
+    - "epic_kitchen_1in60"
diff --git a/theia/configs/dataset/image_video_default.yaml b/theia/configs/dataset/image_video_default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f7754ba06e96322b377aed7f28830ffd971365a
--- /dev/null
+++ b/theia/configs/dataset/image_video_default.yaml
@@ -0,0 +1,7 @@
+return_metadata: False
+shuffle: True
+shuffle_buffer_size: 1024
+feature_norm: True
+dataset_root: "/storage/nfs/datasets/jshang/"
+dataset_ratio: 0.1
+load_action: False
diff --git a/theia/configs/dataset/image_video_mix.yaml b/theia/configs/dataset/image_video_mix.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e842cb9690501a1e021bfb27b4989fc3aef930dc
--- /dev/null
+++ b/theia/configs/dataset/image_video_mix.yaml
@@ -0,0 +1,8 @@
+defaults:
+    - image_video_default
+
+dataset_mix:
+    - "ego4d_1in150"
+    - "ssv2_1in32"
+    - "epic_kitchen_1in60"
+    - "imagenet"
diff --git a/theia/configs/dataset/imagenet.yaml b/theia/configs/dataset/imagenet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bc4513fcb01f324033a662c459a1c53fcf8dd82e
--- /dev/null
+++ b/theia/configs/dataset/imagenet.yaml
@@ -0,0 +1,5 @@
+defaults:
+    - image_video_default
+
+dataset_mix:
+    - "imagenet"
diff --git a/theia/configs/dataset/oxe_octo_mix.yaml b/theia/configs/dataset/oxe_octo_mix.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d77cde0cf133f6c69b0c30343ddfd5aa9c5475e1
--- /dev/null
+++ b/theia/configs/dataset/oxe_octo_mix.yaml
@@ -0,0 +1,12 @@
+_target_: dataset.oxe.oxe_data_utils.OXEDataset
+dataset_mix: "oxe_magic_soup"
+image_action_set_root: "/storage/nfs/datasets/jshang/oxe_image_action"
+feature_set_root: "/storage/nfs/datasets/jshang/oxe_vfm_features"
+image_views: null
+split: "train"
+data_portion: 0.01
+load_action: False
+bf16: True
+safe_tensors: True
+trajectory_subsample_len: 32
+return_metadata: False
diff --git a/theia/configs/dataset/ssv2.yaml b/theia/configs/dataset/ssv2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..29f00acd2f741aaf55036df9c70cd22410367211
--- /dev/null
+++ b/theia/configs/dataset/ssv2.yaml
@@ -0,0 +1,5 @@
+defaults:
+    - image_video_default
+
+dataset_mix:
+    - "ssv2_1in32"
diff --git a/theia/configs/logging/default.yaml b/theia/configs/logging/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..73271701f2fb98479bc7ab2774f8a2ca9830b1e0
--- /dev/null
+++ b/theia/configs/logging/default.yaml
@@ -0,0 +1,6 @@
+model_path: "/storage/nfs/jshang/trained_models"
+log_path: "/storage/nfs/jshang/logs"
+save_ckpt_interval: 20000
+notes: ""
+run_identifier_prefix: ""
+project: "theia"
diff --git a/theia/configs/model/backbone/deit.yaml b/theia/configs/model/backbone/deit.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7f2ff8291a73dc05d9a3542eff065314019bc07c
--- /dev/null
+++ b/theia/configs/model/backbone/deit.yaml
@@ -0,0 +1,2 @@
+backbone: facebook/deit-small-patch16-224
+pretrained: False
diff --git a/theia/configs/model/backbone/deit_nocls.yaml b/theia/configs/model/backbone/deit_nocls.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e0ddada1a029a866d6b9848a35fd746653bbfb3c
--- /dev/null
+++ b/theia/configs/model/backbone/deit_nocls.yaml
@@ -0,0 +1,2 @@
+backbone: nocls-facebook/deit-tiny-patch16-224
+pretrained: False
diff --git a/theia/configs/model/backbone/deit_reg.yaml b/theia/configs/model/backbone/deit_reg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6ee8f155401453bdb065d506cd1aaad58bf22e87
--- /dev/null
+++ b/theia/configs/model/backbone/deit_reg.yaml
@@ -0,0 +1,3 @@
+backbone: reg-facebook/deit-tiny-patch16-224
+pretrained: False
+num_reg_tokens: 7
diff --git a/theia/configs/model/translator/conv.yaml b/theia/configs/model/translator/conv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4d78b80229b387be57abd33c69f0b203c0d69fd8
--- /dev/null
+++ b/theia/configs/model/translator/conv.yaml
@@ -0,0 +1,3 @@
+type: "conv"
+kwargs:
+  translator_hidden_size: 1024
diff --git a/theia/configs/model/translator/lconv.yaml b/theia/configs/model/translator/lconv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3abaef5f4dee8a99c8aa281780b0aa7913b7c605
--- /dev/null
+++ b/theia/configs/model/translator/lconv.yaml
@@ -0,0 +1,3 @@
+type: "lconv"
+kwargs:
+  hidden_size_factor: 1.0
diff --git a/theia/configs/model/translator/mlp.yaml b/theia/configs/model/translator/mlp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a8e2b66e5bdded40dbba486be72cb7055d95d89f
--- /dev/null
+++ b/theia/configs/model/translator/mlp.yaml
@@ -0,0 +1,4 @@
+type: "mlp"
+kwargs:
+  translator_n_layer: 3
+  hidden_size: 1024
diff --git a/theia/configs/model/translator/transformer.yaml b/theia/configs/model/translator/transformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..604d06434a8d36e51899cc0a4bc9a495b063c606
--- /dev/null
+++ b/theia/configs/model/translator/transformer.yaml
@@ -0,0 +1,5 @@
+type: "transformer"
+kwargs:
+  translator_n_layers: 2
+  translator_n_heads: 8
+  translator_hidden_size: 1024
diff --git a/theia/configs/train_rvfm_imagenet.yaml b/theia/configs/train_rvfm_imagenet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..476f2f0a77e54febbd0bcff8db2d2efdfb1334a8
--- /dev/null
+++ b/theia/configs/train_rvfm_imagenet.yaml
@@ -0,0 +1,9 @@
+defaults:
+  - dataset: imagenet
+  - model/backbone: deit
+  - model/translator: lconv
+  - training: frame_level
+  - logging: default
+  - _self_
+
+seed: 0
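+
+# Usage sketch: this config composes the imagenet dataset, DeiT backbone, lconv translator,
+# frame-level training, and default logging groups listed above. Groups or values can be
+# swapped from the Hydra command line; the entry-point script name below is a placeholder
+# and is not part of this diff.
+#   python train_rvfm.py --config-name train_rvfm_imagenet \
+#       model/backbone=deit_reg model/translator=mlp training.batch_size=32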
diff --git a/theia/configs/training/frame_level.yaml b/theia/configs/training/frame_level.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34748e30e3e166d87e41aced2a2cb56974fdd934
--- /dev/null
+++ b/theia/configs/training/frame_level.yaml
@@ -0,0 +1,35 @@
+defaults:
+ - target_models: cdiv
+
+epochs: 50
+warm_up_steps_ratio: 0.1
+
+base_lr: 2e-3
+batch_size: 16
+random_target_models: -1
+num_workers: 8
+# base training settings to scale lr, rarely changed
+base_batch_size: 64
+base_world_size: 8
+
+weight_decay: 0.01
+
+
+optimizer:
+  _target_: torch.optim.AdamW
+  betas: [0.9, 0.999]
+
+lr_scheduler:
+  _target_: theia.lr_schedulers.get_constant_lrs_with_linear_warm_up
+  warm_up_lr_start_factor: 1e-2
+
+
+grad_clip: False
+grad_clip_norm_warmup: 10.0
+grad_clip_norm: 1.0
+
+freeze_translator: False
+freeze_translator_start_steps_ratio: 0.2
+translator_lr_factor: 1.0
+
+main_loss: cos_l1
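+
+# Note (a reading of the fields above, not verified against the training code): base_batch_size
+# and base_world_size look like reference values for linear LR scaling, roughly
+#   lr = base_lr * (batch_size * world_size) / (base_batch_size * base_world_size)
+# With the defaults here and 8 GPUs: 2e-3 * (16 * 8) / (64 * 8) = 5e-4.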
diff --git a/theia/configs/training/target_models/cdds.yaml b/theia/configs/training/target_models/cdds.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f3d1e92b2f22d27c4806d76af7b3baa58952a5cd
--- /dev/null
+++ b/theia/configs/training/target_models/cdds.yaml
@@ -0,0 +1,6 @@
+target_model_names:
+ - "facebook/dinov2-large"
+ - "openai/clip-vit-large-patch14"
+ - "facebook/sam-vit-huge"
+ - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/cddsv.yaml b/theia/configs/training/target_models/cddsv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f037a78518a7f2d03a1bf24068bff38ed6581839
--- /dev/null
+++ b/theia/configs/training/target_models/cddsv.yaml
@@ -0,0 +1,7 @@
+target_model_names:
+ - "google/vit-huge-patch14-224-in21k"
+ - "facebook/dinov2-large"
+ - "openai/clip-vit-large-patch14"
+ - "facebook/sam-vit-huge"
+ - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/cddv.yaml b/theia/configs/training/target_models/cddv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..13db812c5cde279865f0c08677023c93ceb709fd
--- /dev/null
+++ b/theia/configs/training/target_models/cddv.yaml
@@ -0,0 +1,6 @@
+target_model_names:
+ - "google/vit-huge-patch14-224-in21k"
+ - "facebook/dinov2-large"
+ - "openai/clip-vit-large-patch14"
+ - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/cdesv.yaml b/theia/configs/training/target_models/cdesv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5696b402f914f5cbf7a35cd10c3991ff1445185e
--- /dev/null
+++ b/theia/configs/training/target_models/cdesv.yaml
@@ -0,0 +1,6 @@
+target_model_names:
+ - "google/vit-huge-patch14-224-in21k"
+ - "openai/clip-vit-large-patch14"
+ - "facebook/sam-vit-huge"
+ - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/cdis.yaml b/theia/configs/training/target_models/cdis.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f298afcca8ef9bb6b6c58b972b5987a33090bec1
--- /dev/null
+++ b/theia/configs/training/target_models/cdis.yaml
@@ -0,0 +1,5 @@
+target_model_names:
+ - "facebook/dinov2-large"
+ - "openai/clip-vit-large-patch14"
+ - "facebook/sam-vit-huge"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/cdisv.yaml b/theia/configs/training/target_models/cdisv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f2e8f585a3428486842fa9538cbad26a498f2afd
--- /dev/null
+++ b/theia/configs/training/target_models/cdisv.yaml
@@ -0,0 +1,6 @@
+target_model_names:
+ - "google/vit-huge-patch14-224-in21k"
+ - "facebook/dinov2-large"
+ - "openai/clip-vit-large-patch14"
+ - "facebook/sam-vit-huge"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/cdiv.yaml b/theia/configs/training/target_models/cdiv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a9130d9ecf8ee16537af8e6037cd091fd839df7b
--- /dev/null
+++ b/theia/configs/training/target_models/cdiv.yaml
@@ -0,0 +1,5 @@
+target_model_names:
+ - "google/vit-huge-patch14-224-in21k"
+ - "facebook/dinov2-large"
+ - "openai/clip-vit-large-patch14"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/clip.yaml b/theia/configs/training/target_models/clip.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f2a31815bae9bfda0859bedd7854b19e6d276961
--- /dev/null
+++ b/theia/configs/training/target_models/clip.yaml
@@ -0,0 +1,3 @@
+target_model_names:
+ - "openai/clip-vit-large-patch14"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/ddsv.yaml b/theia/configs/training/target_models/ddsv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..df5197be300fc301f39b847f7e88d8df78b3ca1a
--- /dev/null
+++ b/theia/configs/training/target_models/ddsv.yaml
@@ -0,0 +1,6 @@
+target_model_names:
+ - "google/vit-huge-patch14-224-in21k"
+ - "facebook/dinov2-large"
+ - "facebook/sam-vit-huge"
+ - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/depth_anything.yaml b/theia/configs/training/target_models/depth_anything.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..57367c3cb527860d348c501e70abc009bdd2fa40
--- /dev/null
+++ b/theia/configs/training/target_models/depth_anything.yaml
@@ -0,0 +1,3 @@
+target_model_names:
+ - "LiheYoung/depth-anything-large-hf"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/dinov2.yaml b/theia/configs/training/target_models/dinov2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f8aca1f2452e022512dde25ffa416a613dbb4734
--- /dev/null
+++ b/theia/configs/training/target_models/dinov2.yaml
@@ -0,0 +1,3 @@
+target_model_names:
+ - "facebook/dinov2-large"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/sam.yaml b/theia/configs/training/target_models/sam.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9c9c4a53ca57b399fc5f03f00bfb449906a6f4e9
--- /dev/null
+++ b/theia/configs/training/target_models/sam.yaml
@@ -0,0 +1,3 @@
+target_model_names:
+ - "facebook/sam-vit-huge"
+target_model_weights: null
diff --git a/theia/configs/training/target_models/vit.yaml b/theia/configs/training/target_models/vit.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8ae383eb0d8dd87a25658e900adaaa7680c22e91
--- /dev/null
+++ b/theia/configs/training/target_models/vit.yaml
@@ -0,0 +1,3 @@
+target_model_names:
+ - "google/vit-huge-patch14-224-in21k"
+target_model_weights: null
diff --git a/theia/dataset/__init__.py b/theia/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4216d210244524d8b1abec2ad5fddd19f068faa0
--- /dev/null
+++ b/theia/dataset/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .image.image_common import ALL_IMAGE_DATASETS
+from .oxe.oxe_common import ALL_OXE_DATASETS
+from .video.video_common import ALL_VIDEO_DATASETS
diff --git a/theia/dataset/data_utils.py b/theia/dataset/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4edf36c441220172610b46039cafdbe4d3156643
--- /dev/null
+++ b/theia/dataset/data_utils.py
@@ -0,0 +1,591 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+"""Defines PyTorch datasets of dataloaders for multiple image, video, and OXE datasets.
+Should use with webdataset >= 0.2.90. See https://github.com/webdataset/webdataset/pull/347"""
+
+import glob
+import json
+import math
+import os.path as osp
+from collections import OrderedDict
+from functools import partial
+from io import BytesIO
+from typing import Any, Callable, Generator, Iterator, Literal, Optional
+
+import cv2
+import numpy as np
+import omegaconf
+import torch
+import webdataset as wds
+from datasets.combine import DatasetType
+from einops import rearrange
+from numpy.typing import NDArray
+from safetensors.torch import load as sft_load
+from torch import default_generator
+from torch.utils.data import DataLoader, Dataset, IterableDataset, default_collate
+
+from theia.foundation_models.common import MODELS
+from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS
+from theia.dataset.oxe.oxe_mixes import OXE_NAMED_MIXES
+
+PACKED_FEATURES = [model_name for model_name in MODELS if "llava" not in model_name]
+
+
+def normalize_ds_weights_by_ds_len(weights: list[float], lengths: list[int]) -> tuple[list[float], float | Literal[0]]:
+    """Normalize dataset weights by dataset lengths (frames).
+
+    Args:
+        weights (list[float]): assigned weights.
+        lengths (list[int]): lengths of datasets.
+
+    Returns:
+        tuple[list[float], float]: normalized weights and the sum of the expected dataset lengths.
+    """
+    expected_lengths = [weight * length for weight, length in zip(weights, lengths, strict=False)]
+    sum_expected_lengths = sum(expected_lengths)
+    if sum_expected_lengths == 0:
+        raise ValueError("Sum of dataset length is 0.")
+    normalized_weights = [length * 1.0 / sum_expected_lengths for length in expected_lengths]
+    return normalized_weights, sum_expected_lengths
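+
+# Worked example for the normalization above (illustrative values):
+#   weights=[1.0, 2.0], lengths=[100, 50] -> expected_lengths=[100.0, 100.0],
+#   sum_expected_lengths=200.0, normalized_weights=[0.5, 0.5].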
+
+
+def get_vo_keys(dataset_name: str, image_views: Optional[list | str | dict[str, str | list[str]]] = None) -> list[str]:
+    """Get visual observation keys of datasets (to be compatible with OXE).
+
+    Args:
+        dataset_name (str): name of the dataset.
+        image_views (Optional[list | str | dict[str, str | list[str]]], optional): view selection:
+            a list to keep all available views, "static" or "wrist" to filter by view type,
+            or None to use the default (first) view. Defaults to None.
+
+    Returns:
+        list[str]: keys to the views in the dataset.
+    """
+    default_visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"][:1]
+    visual_observation_keys = []
+    if image_views is None:
+        visual_observation_keys = default_visual_observation_keys
+    elif isinstance(image_views, list):
+        visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
+    elif isinstance(image_views, str):
+        if image_views == "static":
+            visual_observation_keys = [
+                k
+                for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
+                if "wrist" not in k and "hand" not in k
+            ]
+        elif image_views == "wrist":
+            visual_observation_keys = [
+                k for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"] if "wrist" in k or "hand" in k
+            ]
+    if len(visual_observation_keys) == 0:
+        visual_observation_keys = default_visual_observation_keys
+    return visual_observation_keys
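+
+# Example of the view selection above, using the "berkeley_autolab_ur5" entry
+# (visual_observation_keys = ["image", "hand_image"]):
+#   get_vo_keys("berkeley_autolab_ur5")           -> ["image"]       (default: first key only)
+#   get_vo_keys("berkeley_autolab_ur5", "wrist")  -> ["hand_image"]
+#   get_vo_keys("berkeley_autolab_ur5", "static") -> ["image"]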
+
+
+class RandomMix(IterableDataset):
+    """A random interleave of multiple iterable datasets."""
+
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        probs: list[float] | NDArray | None = None,
+        stopping_strategy: str = "all_exhausted",
+        seed: Optional[int | str] = 0,
+    ) -> None:
+        """Initialization of a random interleave dataset.
+
+        Args:
+            datasets (list[IterableDataset]): datasets to be interleaved.
+            probs (list[float] | NDArray, optional): probability of each dataset. Defaults to None.
+            stopping_strategy (str, optional): when to end sampling for one epoch. Defaults to `all_exhausted`.
+                `all_exhausted`: every sample in every dataset is yielded at least once.
+                `first_exhausted`: the epoch ends as soon as the first dataset runs out.
+                See also https://huggingface.co/docs/datasets/en/stream#interleave for definitions.
+            seed (Optional[int | str]): seed. Defaults to 0.
+        """
+        self.datasets = datasets
+        if probs is None:
+            self.probs = [1.0] * len(self.datasets)
+        elif isinstance(probs, np.ndarray):
+            self.probs = probs.tolist()
+        else:
+            self.probs = probs
+        self.stopping_strategy = stopping_strategy
+        self.seed = seed
+
+    def __iter__(self) -> Generator:
+        """Return an iterator over the sources."""
+        sources = [iter(d) for d in self.datasets]
+        probs = self.probs[:]
+        seed_gen = torch.Generator()
+        seed_gen.manual_seed(self.seed)
+        cum = (np.array(probs) / np.sum(probs)).cumsum()
+        while len(sources) > 0:
+            r = torch.rand(1, generator=seed_gen).item()
+            i = np.searchsorted(cum, r)
+            try:
+                yield next(sources[i])
+            except StopIteration:
+                if self.stopping_strategy == "all_exhausted":
+                    del sources[i]
+                    del probs[i]
+                    cum = (np.array(probs) / np.sum(probs)).cumsum()
+                elif self.stopping_strategy == "first_exhausted":
+                    break
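+
+# Usage sketch for RandomMix (ds_a and ds_b stand for any IterableDatasets, not defined here):
+#   mixed = RandomMix([ds_a, ds_b], probs=[0.75, 0.25], stopping_strategy="all_exhausted", seed=0)
+#   for sample in mixed:
+#       ...  # samples are drawn from ds_a and ds_b at ~3:1 odds until both sources are exhausted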
+
+
+def decode_sample(
+    key: str, data: bytes, image_transform: Optional[Callable] = None, feature_transform: Optional[Callable] = None
+) -> Any:
+    """Decode a sample from bytes with optional image and feature transforms
+
+    Args:
+        key (str): key of an attribute (a column) of the sample.
+        data (bytes): original data bytes.
+        image_transform (Optional[Callable], optional): image transform. Defaults to None.
+        feature_transform (Optional[Callable], optional): feature transform. Defaults to None.
+
+    Returns:
+        Any: decoded data.
+    """
+    if ".safetensors" in key:
+        sft = sft_load(data)
+        embedding = rearrange(sft["embedding"], "c h w -> (h w) c")
+        if feature_transform is not None:
+            embedding = feature_transform(embedding)
+        if "cls_token" in sft:
+            cls = sft["cls_token"]
+            if feature_transform is not None:
+                cls = feature_transform(cls)
+                return {"embedding": embedding, "cls": cls}
+        return {"embedding": embedding}
+    elif key == ".image":
+        image = np.load(BytesIO(data))
+        if len(image.shape) == 2:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        elif len(image.shape) == 3 and image.shape[-1] == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+        if image_transform is not None:
+            return image_transform(image)
+        return image
+    else:
+        return data
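+
+# decode_sample is intended to be bound with functools.partial and passed to webdataset's
+# .decode(), as done in the dataset builders below, e.g.
+#   wds.WebDataset(shards).decode(partial(decode_sample, image_transform=my_transform))
+# (my_transform is a placeholder for any image transform callable).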
+
+
+def get_oxe_frame_dataset(
+    dataset_root: str,
+    dataset_mix: Optional[str | dict[str, float] | list] = "oxe_magic_soup",
+    feature_models: Optional[list[str]] = None,
+    split: str = "train",
+    dataset_ratio: float = 1.0,
+    image_views: Optional[dict[str, str | list[str]]] = None,
+    image_transform: Optional[Callable[[Any], torch.Tensor]] = None,
+    seed: Optional[int | str] = 0,
+    shuffle: bool = False,
+    world_size: int = 1,
+) -> tuple[dict[str, DatasetType], float | Literal[0]]:
+    """Get OXE datasets at frame level.
+
+    Args:
+        dataset_root (str): root dir of the datasets.
+        dataset_mix (Optional[str  |  dict[str, float]  |  list], optional): how to mix the datasets.
+            Defaults to "oxe_magic_soup".
+        feature_models (Optional[list[str]], optional): models to load their features. Defaults to None.
+        split (str, optional): split "train" or "val" or "test". Defaults to "train".
+        dataset_ratio (float, optional): fraction of the (combined) dataset to use. Defaults to 1.0.
+        image_views (Optional[dict[str, str  |  list[str]]], optional): image views to select. Defaults to None.
+        image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples.
+            Defaults to None.
+        seed (Optional[int  |  str], optional): seed. Defaults to 0.
+        shuffle (bool, optional): shuffle or not. Defaults to False.
+        world_size (int, optional): world size of DDP training. Defaults to 1.
+
+    Returns:
+        tuple[dict[str, DatasetType], float]: a dict of {feature set name: dataset} and the expected
+            total number of frames.
+    """
+    # read dataset mix from any acceptable form
+    if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES:
+        dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]})
+    elif isinstance(dataset_mix, dict):
+        dataset_mix = OrderedDict(**dataset_mix)
+    elif isinstance(dataset_mix, list):
+        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
+    else:
+        raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.")
+
+    if split == "eval" or split == "val":
+        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
+
+    # note down the dataset weights
+    dataset_weights: list[float] = []
+    # get frame level length
+    dataset_lens: list[int] = []
+
+    all_feature_datasets: dict[str, DatasetType] = {}
+    for dataset in dataset_mix:
+        visual_observation_keys = get_vo_keys(dataset_name=dataset, image_views=image_views)
+
+        if feature_models is None:
+            feature_models = PACKED_FEATURES
+
+        with open(osp.join(dataset_root, dataset, "splits.json"), "r") as splitf:
+            dataset_len = json.load(splitf)[split]
+        # if the length is 0, skip
+        # this may happen for small datasets with very few shards
+        if dataset_len == 0:
+            continue
+
+        for vo_key in visual_observation_keys:
+            for model_name in feature_models:
+                if model_name not in PACKED_FEATURES:
+                    feature_set_name = model_name
+                    path_pattern = osp.join(
+                        dataset_root, dataset, vo_key + f"_{model_name.replace('/', '_')}", f"*-{split}*.tar"
+                    )
+                    rename_kw = {model_name: model_name.replace("/", "_") + ".safetensors"}  # replace v by k
+                elif "packed" in all_feature_datasets:
+                    continue
+                else:
+                    feature_set_name = "packed"
+                    path_pattern = osp.join(dataset_root, dataset, vo_key, f"*-{split}*.tar")
+                    rename_kw = {
+                        name: name.replace("/", "_") + ".safetensors" for name in PACKED_FEATURES
+                    }  # replace v by k
+                rename_kw["image"] = "image"
+
+                if feature_set_name not in all_feature_datasets:
+                    all_feature_datasets[feature_set_name] = []
+
+                shard_paths = sorted(glob.glob(path_pattern))
+                num_shards = len(shard_paths)
+                if num_shards < world_size * 8:
+                    shard_paths *= math.ceil(world_size * 8 / num_shards)
+                ds = (
+                    wds.WebDataset(
+                        shard_paths,
+                        nodesplitter=wds.split_by_node,
+                        workersplitter=wds.split_by_worker,
+                        detshuffle=True,
+                        shardshuffle=shuffle,
+                        seed=seed,
+                    )
+                    .decode(partial(decode_sample, image_transform=image_transform))
+                    .rename(keep=False, **rename_kw)
+                )
+                all_feature_datasets[feature_set_name].append(ds)
+
+            dataset_weights.append(dataset_mix[dataset])
+            dataset_lens.append(math.ceil(dataset_len * dataset_ratio))
+
+    normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens)
+
+    combined_feature_datasets: dict[str, Dataset] = {}
+    for feature_set_name, fds in all_feature_datasets.items():
+        ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted")
+        combined_feature_datasets[feature_set_name] = ds
+
+    return combined_feature_datasets, sum_expected_lengths
+
+
+def get_oxe_frame_dataloader(
+    datasets: dict[str, DatasetType], batch_size: Optional[int] = None, shuffle_buffer_size: int = 1_000, **kwargs: Any
+) -> dict[str, DataLoader]:
+    """Get dataloaders of OXE datasets. Corresponding to `get_oxe_frame_dataset()`.
+
+    Args:
+        datasets (dict[str, DatasetType]): OXE datasets from `get_oxe_frame_dataset()`.
+        batch_size (Optional[int], optional): batch size. Defaults to None.
+        shuffle_buffer_size (int, optional): buffer size for shuffling while streaming. Defaults to 1_000.
+
+    Returns:
+        dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}.
+    """
+    loaders = {
+        k: (
+            wds.WebLoader(datasets[k], batch_size=None, **kwargs)
+            .shuffle(shuffle_buffer_size)  # shuffle after mix
+            .batched(batch_size, collation_fn=default_collate)
+        )
+        for k in datasets
+    }
+    return loaders
+
+
+def get_oxe_frame_iterator(
+    data_loaders: dict[str, DataLoader],
+) -> Iterator[dict[str, Any]]:
+    """Get iterator from dataloders. Corresponding to `get_oxe_frame_dataloader()`.
+
+    Args:
+        data_loaders (dict[str, DataLoader]): dataloaders from `get_oxe_frame_dataloader()`.
+
+    Yields:
+        Iterator[dict[str, Any]]: data sample.
+    """
+    packed_loader = data_loaders.get("packed", None)
+    # place packed_loader at the first
+    if packed_loader is not None:
+        loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]]
+    else:
+        loaders = list(data_loaders.values())
+
+    # merge dicts
+    for data in zip(*loaders, strict=False):
+        # yield data
+        for i in range(1, len(loaders)):
+            for k in data[i]:
+                if k not in data[0]:
+                    data[0][k] = data[i][k]
+        yield data[0]
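+
+# End-to-end sketch of the OXE frame pipeline above (paths and arguments are illustrative):
+#   datasets, num_frames = get_oxe_frame_dataset(dataset_root="/path/to/oxe", dataset_mix="oxe_magic_soup")
+#   loaders = get_oxe_frame_dataloader(datasets, batch_size=16, shuffle_buffer_size=1024, num_workers=8)
+#   for batch in get_oxe_frame_iterator(loaders):
+#       ...  # batch maps column names ("image", per-model features) to collated tensors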
+
+
+def normalize_feature(
+    x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Normalize the feature given mean and std.
+
+    Args:
+        x (torch.Tensor): input features
+        mean (Optional[torch.Tensor], optional): mean values. Defaults to None.
+        std (Optional[torch.Tensor], optional): std values. Defaults to None.
+
+    Returns:
+        torch.Tensor: feature after normalization
+    """
+    return x if mean is None or std is None else (x - mean) / std
+
+
+def load_feature_stats(
+    dataset_root: str, feature_models: list[str]
+) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+    """Load feature statictics (mean and variance).
+
+    Args:
+        dataset_root (str): root dir of the dataset (or where to hold the statistics).
+        feature_models (list[str]): names of the models/features.
+
+    Returns:
+        tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variances. Keys are model names.
+    """
+    feature_means: dict[str, torch.Tensor] = {}
+    feature_vars: dict[str, torch.Tensor] = {}
+    for model in feature_models:
+        model_name = model.replace("/", "_")
+        feature_means[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_mean_{model_name}.npy"))).to(
+            torch.bfloat16
+        )
+        feature_vars[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_var_{model_name}.npy"))).to(
+            torch.bfloat16
+        )
+    return feature_means, feature_vars
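+
+# Expected file layout for the stats above: for the model "facebook/dinov2-large" this reads
+#   <dataset_root>/imagenet_mean_facebook_dinov2-large.npy
+#   <dataset_root>/imagenet_var_facebook_dinov2-large.npy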
+
+
+def pad_shard_paths(shard_paths: list[str], num_shards: int, num_parts: int) -> list[str]:
+    """Pad shard paths to be divided by number of partitions (ranks*nodes).
+
+    Args:
+        shard_paths (list[str]): paths of dataset shards.
+        num_shards (int): number of shards.
+        num_parts (int): number of partitions.
+
+    Returns:
+        list[str]: padded shard paths.
+    """
+    final_shard_paths = shard_paths
+    if num_shards % num_parts != 0:
+        if num_shards < num_parts - num_shards:
+            for _ in range(math.floor((num_parts - num_shards) / num_shards)):
+                final_shard_paths += shard_paths[:]
+            final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)]
+        else:
+            final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)]
+    return final_shard_paths
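+
+# Worked example for the padding above: with shard_paths=["a.tar", "b.tar", "c.tar"] and
+# num_parts=8, the list is repeated once (since 3 < 8 - 3) and then topped up from the front,
+# yielding [a, b, c, a, b, c, a, b] -- 8 shard paths, one per partition.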
+
+
+def get_image_video_dataset(
+    dataset_root: str,
+    feature_models: list[str],
+    dataset_mix: Optional[str | dict[str, float] | list] = None,
+    split: str = "train",
+    dataset_ratio: float = 1.0,
+    image_transform: Optional[Callable[[Any], torch.Tensor]] = None,
+    feature_norm: bool = False,
+    seed: Optional[int | str] = 0,
+    shuffle: bool = False,
+    world_size: int = 1,
+    **kwargs: Any,
+) -> tuple[dict[str, DatasetType], float | Literal[0]]:
+    """Get image and video datasets at frame level.
+
+    Args:
+        dataset_root (str): root dir of the datasets.
+        feature_models (list[str]): models to load their features.
+        dataset_mix (Optional[str  |  dict[str, float]  |  list], optional): how to mix the datasets.
+        split (str, optional): split "train" or "val" or "test". Defaults to "train".
+        dataset_ratio (float, optional): fraction of the (combined) dataset to use. Defaults to 1.0.
+        image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples.
+            Defaults to None.
+        feature_norm (bool, optional): whether to normalize the features. Defaults to False.
+        seed (Optional[int  |  str], optional): seed. Defaults to 0.
+        shuffle (bool, optional): shuffle or not. Defaults to False.
+        world_size (int, optional): world size of DDP training. Defaults to 1.
+        kwargs (Any): arguments to pass-through.
+
+    Returns:
+        tuple[dict[str, DatasetType], float]: a dict of {feature set name: dataset} and the expected
+            total number of frames.
+    """
+    # read dataset mix from any acceptable form
+    if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES:
+        dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]})
+    elif isinstance(dataset_mix, dict):
+        dataset_mix = OrderedDict(**dataset_mix)
+    elif isinstance(dataset_mix, list) or isinstance(dataset_mix, omegaconf.listconfig.ListConfig):
+        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
+    else:
+        raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.")
+
+    if split == "eval" or split == "val":
+        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
+
+    # note down the dataset weights
+    dataset_weights: list[float] = []
+    # get frame level length
+    dataset_lens: list[int] = []
+
+    all_feature_datasets: dict[str, DatasetType] = {}
+
+    if feature_norm:
+        feature_means, feature_vars = load_feature_stats(dataset_root, feature_models)
+
+    for d in dataset_mix:
+
+        with open(osp.join(dataset_root, d, "splits.json"), "r") as splitf:
+            dataset_len = json.load(splitf)[split]
+
+        # if the length is 0, skip
+        # this may happen for small datasets with very few shards
+        if dataset_len == 0:
+            continue
+
+        path_pattern = osp.join(dataset_root, d, "images", f"*-{split}.tar")
+        if "image" not in all_feature_datasets:
+            all_feature_datasets["image"] = []
+        shard_paths = sorted(glob.glob(path_pattern))
+        num_shards = len(shard_paths)
+        num_parts = world_size
+        final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts)
+        ds = wds.WebDataset(
+            final_shard_paths,
+            nodesplitter=wds.split_by_node,
+            workersplitter=wds.split_by_worker,
+            detshuffle=True,
+            shardshuffle=shuffle,
+            seed=seed,
+        ).decode(partial(decode_sample, image_transform=image_transform))
+        all_feature_datasets["image"].append(ds)
+
+        for model_name in feature_models:
+            path_pattern = osp.join(dataset_root, d, f"{model_name.replace('/', '_')}", f"*-{split}.tar")
+            rename_kw = {model_name: model_name.replace("/", "_").lower() + ".safetensors"}  # replace v by k
+
+            if model_name not in all_feature_datasets:
+                all_feature_datasets[model_name] = []
+
+            shard_paths = sorted(glob.glob(path_pattern))
+            num_shards = len(shard_paths)
+            num_parts = world_size
+            final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts)
+            if feature_norm:
+                feature_transform = partial(
+                    normalize_feature, mean=feature_means[model_name], std=feature_vars[model_name]
+                )
+            else:
+                feature_transform = None
+            ds = (
+                wds.WebDataset(
+                    final_shard_paths,
+                    nodesplitter=wds.split_by_node,
+                    workersplitter=wds.split_by_worker,
+                    detshuffle=True,
+                    shardshuffle=shuffle,
+                    seed=seed,
+                )
+                .decode(partial(decode_sample, image_transform=image_transform, feature_transform=feature_transform))
+                .rename(keep=False, **rename_kw)
+            )
+            all_feature_datasets[model_name].append(ds)
+
+        dataset_weights.append(dataset_mix[d])
+        dataset_lens.append(math.ceil(dataset_len * dataset_ratio))
+
+    normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens)
+
+    combined_feature_datasets: dict[str, Dataset] = {}
+    for feature_set_name, fds in all_feature_datasets.items():
+        ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted", seed=seed)
+        combined_feature_datasets[feature_set_name] = ds
+
+    return combined_feature_datasets, sum_expected_lengths
+
+
+def get_frame_dataloader(
+    datasets: dict[str, DatasetType],
+    batch_size: Optional[int] = None,
+    shuffle: bool = False,
+    shuffle_buffer_size: int = 1_000,
+    seed: Optional[int] = 0,
+    **kwargs: Any,
+) -> dict[str, DataLoader]:
+    """Get dataloaders of image and video datasets. Corresponding to `get_image_video_dataset()`.
+
+    Args:
+        datasets (dict[str, DatasetType]): image and video datasets from `get_image_video_dataset()`.
+        batch_size (Optional[int], optional): batch size. Defaults to None.
+        shuffle (bool, optional): whether to shuffle with a streaming buffer. Defaults to False.
+        shuffle_buffer_size (int, optional): buffer size for shuffling while streaming. Defaults to 1_000.
+        seed (Optional[int], optional): seed for shuffling. Defaults to 0.
+
+    Returns:
+        dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}.
+    """
+    loaders = {}
+    for k in datasets:
+        loader = wds.WebLoader(datasets[k], batch_size=None, generator=default_generator, **kwargs)
+        if shuffle:
+            loader = loader.shuffle(shuffle_buffer_size, seed=seed)  # shuffle after mix
+        loader = loader.batched(batch_size, collation_fn=default_collate)
+        loaders[k] = loader
+    return loaders
+
+
+def get_frame_iterator(
+    data_loaders: dict[str, DataLoader],
+) -> Iterator[dict[str, Any]]:
+    """Get iterator from image and video dataset dataloders. Corresponding to `get_frame_dataloader()`.
+
+    Args:
+        data_loaders (dict[str, DataLoader]): dataloaders from `get_frame_dataloader()`.
+
+    Yields:
+        Iterator[dict[str, Any]]: data sample.
+    """
+    packed_loader = data_loaders.get("packed", None)
+    # place packed_loader at the first
+    if packed_loader is not None:
+        loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]]
+    else:
+        loaders = list(data_loaders.values())
+
+    # Merge dicts. This accommodates the old dataset organization, in which each shard contains
+    # one or more columns and images are duplicated across columns. In the new (current)
+    # organization, columns are completely separated, so column keys all differ except for a few
+    # "built-in" keys added by webdataset that are unrelated to the data or training.
+    # While both organizations coexist during the transition, this merge ignores the extra
+    # "image" fields in the loaded datasets.
+    for data in zip(*loaders, strict=False):
+        # yield data
+        for i in range(1, len(loaders)):
+            for k in data[i]:
+                if k not in data[0]:
+                    data[0][k] = data[i][k]
+        yield data[0]
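+
+
+# End-to-end sketch of the image/video pipeline (paths, model list, and options are illustrative):
+#   datasets, num_frames = get_image_video_dataset(
+#       dataset_root="/path/to/data",
+#       feature_models=["facebook/dinov2-large"],
+#       dataset_mix=["imagenet"],
+#       split="train",
+#   )
+#   loaders = get_frame_dataloader(datasets, batch_size=16, shuffle=True, num_workers=8)
+#   for batch in get_frame_iterator(loaders):
+#       ...  # keys: "image" plus one entry per requested target model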
diff --git a/theia/dataset/image/__init__.py b/theia/dataset/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81d96b9788b4242dec0396c89f7cd1037cffcb05
--- /dev/null
+++ b/theia/dataset/image/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .image_common import ALL_IMAGE_DATASETS
diff --git a/theia/dataset/image/image_common.py b/theia/dataset/image/image_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd657b3f8c00528109448433f0963b1a85efd89c
--- /dev/null
+++ b/theia/dataset/image/image_common.py
@@ -0,0 +1,5 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from collections import OrderedDict
+
+ALL_IMAGE_DATASETS = OrderedDict({"imagenet": {"steps": 1_281_167}})
diff --git a/theia/dataset/oxe/__init__.py b/theia/dataset/oxe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/dataset/oxe/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/dataset/oxe/oxe_common.py b/theia/dataset/oxe/oxe_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..36c7ffed4eaf81c2a58d2d717df0a86434b4ef8e
--- /dev/null
+++ b/theia/dataset/oxe/oxe_common.py
@@ -0,0 +1,430 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from collections import OrderedDict
+from typing import Optional
+
+"""
+ALL_OXE_DATASETS below records metadata for every subset of the OXE dataset.
+The datasets are listed in alphabetical order.
+
+versions (list[str]): available and usable versions, sorted from oldest to newest.
+                      The last one is usually used.
+episodes (int): total episodes in the dataset.
+steps    (int): total steps in the dataset.
+visual_observation_keys (list[str]): keys to specify image observations.
+"""
+ALL_OXE_DATASETS: OrderedDict = OrderedDict(
+    {
+        "agent_aware_affordances": {
+            "versions": ["1.0.0"],
+            "episodes": 118,
+            "steps": 151628,
+            "visual_observation_keys": ["image"],
+        },
+        "asu_table_top_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 110,
+            "steps": 26113,
+            "visual_observation_keys": ["image"],
+        },
+        "austin_buds_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 50,
+            "steps": 34112,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "austin_sailor_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 240,
+            "steps": 353094,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "austin_sirius_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 559,
+            "steps": 279939,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "bc_z": {
+            "versions": [
+                "0.1.0",  # "1.0.0", "old1.0.1", and "1.0.1" are not usable
+            ],
+            "episodes": 39350,
+            "steps": 5471693,
+            "visual_observation_keys": ["image"],
+        },
+        "berkeley_autolab_ur5": {
+            "versions": ["0.1.0"],
+            "episodes": 896,
+            "steps": 87783,
+            "visual_observation_keys": ["image", "hand_image"],
+        },
+        "berkeley_cable_routing": {
+            "versions": ["0.1.0"],
+            "episodes": 1482,
+            "steps": 38240,
+            "visual_observation_keys": ["image", "top_image", "wrist225_image", "wrist45_image"],
+        },
+        "berkeley_fanuc_manipulation": {
+            "versions": ["0.1.0"],
+            "episodes": 415,
+            "steps": 62613,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "berkeley_gnm_cory_hall": {
+            "versions": ["0.1.0"],
+            "episodes": 7331,
+            "steps": 156012,
+            "visual_observation_keys": ["image"],
+        },
+        "berkeley_gnm_recon": {
+            "versions": ["0.1.0"],
+            "episodes": 11834,
+            "steps": 610907,
+            "visual_observation_keys": ["image"],
+        },
+        "berkeley_gnm_sac_son": {
+            "versions": ["0.1.0"],
+            "episodes": 2955,
+            "steps": 241059,
+            "visual_observation_keys": ["image"],
+        },
+        "berkeley_mvp_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 480,
+            "steps": 45308,
+            "visual_observation_keys": ["hand_image"],
+        },
+        "berkeley_rpt_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 908,
+            "steps": 392578,
+            "visual_observation_keys": ["hand_image"],
+        },
+        "bridge": {"versions": ["0.1.0"], "episodes": 25460, "steps": 864292, "visual_observation_keys": ["image"]},
+        "cmu_franka_exploration_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 199,
+            "steps": 1990,
+            "visual_observation_keys": ["image"],
+        },
+        "cmu_play_fusion": {
+            "versions": ["0.1.0"],
+            "episodes": 576,
+            "steps": 235922,
+            "visual_observation_keys": ["image"],
+        },
+        "cmu_playing_with_food": {  # this dataset seems to be corrupted
+            "versions": ["1.0.0"],
+            "episodes": 4200,
+            "steps": 83240,
+            "visual_observation_keys": ["image"],
+        },
+        "cmu_stretch": {"versions": ["0.1.0"], "episodes": 135, "steps": 25016, "visual_observation_keys": ["image"]},
+        "columbia_cairlab_pusht_real": {
+            "versions": ["0.1.0"],
+            "episodes": 122,
+            "steps": 24924,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "dlr_edan_shared_control_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 104,
+            "steps": 8928,
+            "visual_observation_keys": ["image"],
+        },
+        "dlr_sara_grid_clamp_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 107,
+            "steps": 7622,
+            "visual_observation_keys": ["image"],
+        },
+        "dlr_sara_pour_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 100,
+            "steps": 12971,
+            "visual_observation_keys": ["image"],
+        },
+        "eth_agent_affordances": {
+            "versions": ["0.1.0"],
+            "episodes": 118,
+            "steps": 151628,
+            "visual_observation_keys": ["image"],
+        },
+        "fanuc_manipulation_v2": {
+            "versions": ["1.0.0"],
+            "episodes": 415,
+            "steps": 62613,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "fractal20220817_data": {
+            "versions": ["0.1.0"],
+            "episodes": 87212,
+            "steps": 3786400,
+            "visual_observation_keys": ["image"],
+        },
+        "furniture_bench_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 5100,
+            "steps": 3948057,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 631,
+            "steps": 146241,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "imperial_wrist_dataset": {
+            "versions": ["1.0.0"],
+            "episodes": 170,
+            "steps": 7148,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "imperialcollege_sawyer_wrist_cam": {
+            "versions": ["0.1.0"],
+            "episodes": 170,
+            "steps": 7148,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "jaco_play": {
+            "versions": ["0.1.0"],
+            "episodes": 976,
+            "steps": 70127,
+            "visual_observation_keys": ["image", "image_wrist"],
+        },
+        "kaist_nonprehensile_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 201,
+            "steps": 32429,
+            "visual_observation_keys": ["image"],
+        },
+        "kuka": {"versions": ["0.1.0"], "episodes": 580392, "steps": 8583978, "visual_observation_keys": ["image"]},
+        "language_table": {
+            "versions": ["0.0.1", "0.1.0"],
+            "episodes": 442226,
+            "steps": 7045476,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoabsolute_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 15866385,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoblock_4block_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 8298,
+            "steps": 326768,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoblock_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 12970620,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoblock_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 8000,
+            "steps": 351688,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktoblockrelative_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 13016749,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_blocktorelative_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 8655815,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_separate_oracle_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 200000,
+            "steps": 3196661,
+            "visual_observation_keys": ["rgb"],
+        },
+        "language_table_sim": {
+            "versions": ["0.0.1"],
+            "episodes": 181020,
+            "steps": 4665423,
+            "visual_observation_keys": ["rgb"],
+        },
+        "maniskill_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 30213,
+            "steps": 4537402,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "mutex_dataset": {
+            "versions": ["1.0.0"],
+            "episodes": 1500,
+            "steps": 361883,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "nyu_door_opening_surprising_effectiveness": {
+            "versions": ["0.1.0"],
+            "episodes": 435,
+            "steps": 18196,
+            "visual_observation_keys": ["image"],
+        },
+        "nyu_franka_play_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 365,
+            "steps": 34448,
+            "visual_observation_keys": ["image", "image_additional_view"],
+        },
+        "nyu_rot_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 14,
+            "steps": 440,
+            "visual_observation_keys": ["image"],
+        },
+        "qut_dexterous_manpulation": {
+            "versions": ["0.1.0"],
+            "episodes": 200,
+            "steps": 176278,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "robo_net": {
+            "versions": ["0.1.0", "1.0.0"],
+            "episodes": 82775,
+            "steps": 2483250,
+            "visual_observation_keys": ["image", "image1", "image2"],
+        },
+        "robot_vqa": {
+            "versions": ["0.1.0"],
+            "episodes": 3331523,
+            "steps": 3331523,
+            "visual_observation_keys": ["images"],
+        },
+        "roboturk": {
+            "versions": ["0.1.0"],
+            "episodes": 1796,
+            "steps": 168423,
+            "visual_observation_keys": ["front_rgb"],
+        },
+        "stanford_hydra_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 570,
+            "steps": 358234,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 3000,
+            "steps": 149985,
+            "visual_observation_keys": ["image"],
+        },
+        "stanford_mask_vit_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 9109,
+            "steps": 282379,
+            "visual_observation_keys": ["image"],
+        },
+        "stanford_robocook_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 2460,
+            "steps": 112980,
+            "visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"],
+        },
+        "taco_play": {
+            "versions": ["0.1.0"],
+            "episodes": 3242,
+            "steps": 213972,
+            "visual_observation_keys": ["rgb_static", "rgb_gripper"],
+        },
+        "tokyo_u_lsmo_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 50,
+            "steps": 11925,
+            "visual_observation_keys": ["image"],
+        },
+        "toto": {"versions": ["0.1.0"], "episodes": 902, "steps": 294139, "visual_observation_keys": ["image"]},
+        "ucsd_kitchen_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 150,
+            "steps": 3970,
+            "visual_observation_keys": ["image"],
+        },
+        "ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 1355,
+            "steps": 67750,
+            "visual_observation_keys": ["image"],
+        },
+        "uiuc_d3field": {  # this dataset seems to be corrupted
+            "versions": ["0.1.0", "1.1.2"],
+            "episodes": 196,
+            "steps": 13384,
+            "visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"],
+        },
+        "usc_cloth_sim_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 800,
+            "steps": 80000,
+            "visual_observation_keys": ["image"],
+        },
+        "utaustin_mutex": {
+            "versions": ["0.1.0"],
+            "episodes": 1500,
+            "steps": 361883,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 64,
+            "steps": 9140,
+            "visual_observation_keys": ["image"],
+        },
+        "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 192,
+            "steps": 26346,
+            "visual_observation_keys": ["image"],
+        },
+        "utokyo_saytap_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 20,
+            "steps": 22937,
+            "visual_observation_keys": ["image", "wrist_image"],
+        },
+        "utokyo_xarm_bimanual_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 64,
+            "steps": 1388,
+            "visual_observation_keys": ["image"],
+        },
+        "utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
+            "versions": ["0.1.0"],
+            "episodes": 92,
+            "steps": 6789,
+            "visual_observation_keys": ["image", "hand_image", "image2"],
+        },
+        "viola": {
+            "versions": ["0.1.0"],
+            "episodes": 135,
+            "steps": 68913,
+            "visual_observation_keys": ["agentview_rgb", "eye_in_hand_rgb"],
+        },
+    }
+)
+
+
+def oxe_dsname2path(dataset_name: str, version: Optional[str] = None) -> str:
+    """Map a dataset name to the remote Google Cloud Storage path of the dataset.
+
+    Args:
+        dataset_name (str): dataset name.
+        version (Optional[str]): version string.
+
+    Returns:
+        str: Google Cloud Storage path to the dataset.
+    """
+    if version is None:
+        version = ALL_OXE_DATASETS[dataset_name]["versions"][-1]
+    return f"gs://gresearch/robotics/{dataset_name}/{version}"
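+
+
+# Example: oxe_dsname2path("taco_play") resolves to "gs://gresearch/robotics/taco_play/0.1.0"
+# (the latest listed version is used when `version` is not given).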
diff --git a/theia/dataset/oxe/oxe_mixes.py b/theia/dataset/oxe/oxe_mixes.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bb793ce649b8d8f7f689c21363a464189910371
--- /dev/null
+++ b/theia/dataset/oxe/oxe_mixes.py
@@ -0,0 +1,139 @@
+# File modified. Modifications Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+"""MIT License Copyright (c) 2023 Robotic AI & Learning Lab Berkeley
+
+From Octo https://github.com/octo-models/octo/blob/main/octo/data/oxe/oxe_dataset_mixes.py
+"""
+
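+# Each mix below is a list of (dataset_name, sampling_weight) pairs; the weight controls how heavily
+# each dataset is sampled in the mix.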
+BRIDGE_MIX = [
+    ("bridge_dataset", 1.0),
+]
+
+RT_X_MIX = [
+    ("fractal20220817_data", 0.54087122203),
+    ("kuka", 0.8341046294),
+    ("bridge_dataset", 1.0),
+    ("taco_play", 2.0),
+    ("jaco_play", 2.0),
+    ("berkeley_cable_routing", 3.0),
+    ("roboturk", 1.0),
+    ("nyu_door_opening_surprising_effectiveness", 5.0),
+    ("viola", 2.0),
+    ("berkeley_autolab_ur5", 1.0),
+    ("toto", 1.0),
+]
+
+
+OXE_FRANKA_MIX = [
+    ("taco_play", 1.0),
+    ("berkeley_cable_routing", 1.0),
+    ("viola", 1.0),
+    ("toto", 1.0),
+    ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
+    ("austin_buds_dataset_converted_externally_to_rlds", 3.0),
+    ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
+    ("maniskill_dataset_converted_externally_to_rlds", 0.1),
+    ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
+    ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0),
+    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
+    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
+    ("berkeley_rpt_converted_externally_to_rlds", 1.0),
+    ("kaist_nonprehensile_converted_externally_to_rlds", 3.0),
+    ("stanford_robocook_converted_externally_to_rlds", 1.0),
+    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
+    ("utaustin_mutex", 1.0),
+    # ("cmu_playing_with_food", 1.0),
+    ("cmu_play_fusion", 1.0),
+]
+
+OXE_MAGIC_SOUP = [
+    ("fractal20220817_data", 0.54087122203),
+    ("kuka", 0.8341046294),
+    ("bridge", 1.0),
+    ("taco_play", 2.0),
+    ("jaco_play", 1.0),
+    ("berkeley_cable_routing", 1.0),
+    ("roboturk", 2.0),
+    ("nyu_door_opening_surprising_effectiveness", 1.0),
+    ("viola", 2.0),
+    ("berkeley_autolab_ur5", 2.0),
+    ("toto", 1.0),
+    ("language_table", 0.1),
+    ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0),
+    ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
+    ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
+    ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
+    ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0),
+    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
+    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
+    ("bc_z", 0.2),
+    ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
+    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
+    # ("uiuc_d3field", 1.0),  --> somehow raw data is broken
+    ("utaustin_mutex", 1.0),
+    ("berkeley_fanuc_manipulation", 2.0),
+    ("cmu_stretch", 1.0),
+]
+
+
+OXE_FULL_MIX = [
+    ("fractal20220817_data", 1.0),
+    ("kuka", 1.0),
+    ("bridge_dataset", 1.0),
+    ("taco_play", 1.0),
+    ("jaco_play", 1.0),
+    ("berkeley_cable_routing", 1.0),
+    ("roboturk", 1.0),
+    ("nyu_door_opening_surprising_effectiveness", 1.0),
+    ("viola", 1.0),
+    ("berkeley_autolab_ur5", 1.0),
+    ("toto", 1.0),
+    ("language_table", 1.0),
+    ("columbia_cairlab_pusht_real", 1.0),
+    ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0),
+    ("nyu_rot_dataset_converted_externally_to_rlds", 1.0),
+    ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
+    ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
+    ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0),
+    ("maniskill_dataset_converted_externally_to_rlds", 1.0),
+    ("furniture_bench_dataset_converted_externally_to_rlds", 1.0),
+    ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0),
+    ("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0),
+    ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0),
+    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
+    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
+    ("bc_z", 1.0),
+    ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0),
+    ("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0),
+    ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0),
+    ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0),
+    ("robo_net", 1.0),
+    ("berkeley_mvp_converted_externally_to_rlds", 1.0),
+    ("berkeley_rpt_converted_externally_to_rlds", 1.0),
+    ("kaist_nonprehensile_converted_externally_to_rlds", 1.0),
+    ("stanford_mask_vit_converted_externally_to_rlds", 1.0),
+    ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0),
+    ("dlr_sara_pour_converted_externally_to_rlds", 1.0),
+    ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0),
+    ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
+    ("asu_table_top_converted_externally_to_rlds", 1.0),
+    ("stanford_robocook_converted_externally_to_rlds", 1.0),
+    ("imperialcollege_sawyer_wrist_cam", 1.0),
+    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
+    ("uiuc_d3field", 1.0),
+    ("utaustin_mutex", 1.0),
+    ("berkeley_fanuc_manipulation", 1.0),
+    ("cmu_playing_with_food", 1.0),
+    ("cmu_play_fusion", 1.0),
+    ("cmu_stretch", 1.0),
+    ("berkeley_gnm_recon", 1.0),
+    ("berkeley_gnm_cory_hall", 1.0),
+    ("berkeley_gnm_sac_son", 1.0),
+]
+
+OXE_NAMED_MIXES = {
+    "bridge": BRIDGE_MIX,
+    "rtx": RT_X_MIX,
+    "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX,
+    "oxe_magic_soup": OXE_MAGIC_SOUP,
+}
diff --git a/theia/dataset/oxe/oxe_transforms.py b/theia/dataset/oxe/oxe_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..43cca3118bc0bd419388faee57cd6dcdd89f62ec
--- /dev/null
+++ b/theia/dataset/oxe/oxe_transforms.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import torch
+from numpy.typing import NDArray
+from torchvision.transforms.v2 import Compose, Normalize, ToDtype, ToImage
+
+
+def totensor(arr: NDArray) -> torch.Tensor:
+    """Convert ndarray to tensor."""
+    return torch.from_numpy(arr)
+
+
+oxe_image_transform = Compose(
+    [ToImage(), ToDtype(torch.float32, scale=True), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
+)  # ImageNet statistics normalization
diff --git a/theia/dataset/video/__init__.py b/theia/dataset/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..04f7ecdb72ccb4f56e2b19f30f7c49d2245d8195
--- /dev/null
+++ b/theia/dataset/video/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .video_common import ALL_VIDEO_DATASETS
diff --git a/theia/dataset/video/video_common.py b/theia/dataset/video/video_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b950ad31dbaa6a68ce0c7b8d0ca778bb8dcbd85
--- /dev/null
+++ b/theia/dataset/video/video_common.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from collections import OrderedDict
+
+ALL_VIDEO_DATASETS = OrderedDict(
+    {
+        "ego4d_1in150": {"steps": 2_800_871},
+        "epic_kitchen_1in60": {"steps": 333_117},
+        "ssv2_1in32": {"steps": 312_772},
+    }
+)
diff --git a/theia/decoding/__init__.py b/theia/decoding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..af4707e751ee70d8a72adb50df0ef4e559b320fc
--- /dev/null
+++ b/theia/decoding/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .decode import decode_everything, load_feature_stats
+from .depth_anything import prepare_depth_decoder
+from .sam import prepare_mask_generator
diff --git a/theia/decoding/decode.py b/theia/decoding/decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..99a387fd136b65f63425269f22d94fa0eaec41d1
--- /dev/null
+++ b/theia/decoding/decode.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import os
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from numpy.typing import NDArray
+from PIL import Image
+from sklearn.decomposition import PCA
+from transformers import SamModel, SamProcessor
+from transformers.pipelines import MaskGenerationPipeline
+
+from theia.decoding.depth_anything import decode_depth_anything
+from theia.decoding.dinov2 import decode_dinov2
+from theia.decoding.sam import decode_sam
+from theia.preprocessing.feature_extraction_core import (
+    get_feature_outputs,
+    get_model,
+)
+
+
+def denormalize_feature(
+    x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Denormalize the features using mean and std.
+
+    Args:
+        x (torch.Tensor): features to be denormalized.
+        mean (Optional[torch.Tensor], optional): mean value of the features. Defaults to None.
+        std (Optional[torch.Tensor], optional): std value of the features. Defaults to None.
+
+    Returns:
+        torch.Tensor: denormalized features.
+    """
+    if mean is None and std is None:
+        return x
+    elif mean is None and std is not None:
+        return x * std
+    elif mean is not None and std is None:
+        return x + mean
+    return x * std + mean
+
+
+def load_feature_stats(
+    feature_models: list[str], stat_file_root: str
+) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+    """Load the statistics (mean and variance) of the features, per model.
+
+    Args:
+        feature_models (list[str]): names of the models. Note: there are `/` in the name.
+        stat_file_root (str): directory that holds feature stat files.
+
+    Returns:
+        tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variances.
+    """
+    feature_means: dict[str, torch.Tensor] = {}
+    feature_vars: dict[str, torch.Tensor] = {}
+    for model in feature_models:
+        model_name = model.replace("/", "_")
+        feature_means[model] = torch.from_numpy(
+            np.load(os.path.join(stat_file_root, f"imagenet_mean_{model_name}.npy"))
+        )
+        feature_vars[model] = torch.from_numpy(np.load(os.path.join(stat_file_root, f"imagenet_var_{model_name}.npy")))
+    return feature_means, feature_vars
+
+
+def decode_everything(
+    theia_model: nn.Module,
+    feature_means: dict[str, torch.Tensor],
+    feature_vars: dict[str, torch.Tensor],
+    images: list[Image.Image],
+    mask_generator: MaskGenerationPipeline,
+    sam_model: SamModel,
+    depth_anything_decoder: nn.Module,
+    pred_iou_thresh: float = 0.9,
+    stability_score_thresh: float = 0.9,
+    gt: bool = False,
+    pca: Optional[PCA] = None,
+    device: int | str | torch.device = 0,
+) -> tuple[list[NDArray], Optional[list[NDArray]]]:
+    """Decode features from given `theia_model` into different outputs corresponding to upstream models including
+        DINOv2, SAM, and Depth-Anything.
+
+    Args:
+        theia_model (nn.Module): theia model.
+        feature_means (dict[str, torch.Tensor]): means of the features for denormalization.
+        feature_vars (dict[str, torch.Tensor]): variance of the features for denormalization.
+        images (list[Image.Image]): input images.
+        mask_generator (MaskGenerationPipeline): mask generation pipeline.
+        sam_model (SamModel): sam model.
+        depth_anything_decoder (nn.Module): depth anything decoder.
+        pred_iou_thresh (float, optional): iou threshold for mask generation.
+            See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9.
+        stability_score_thresh (float, optional): stability score threshold for mask generation.
+            See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9.
+        gt (bool): whether to attach ground truth result in the visualization. Defaults to False.
+        pca (Optional[PCA]): PCA for DINOv2 decoding. If provided, this particular PCA is used. Defaults to None.
+        device (int | str | torch.device, optional): device for decoding. Defaults to 0.
+
+    Returns:
+        tuple[list[NDArray], Optional[list[NDArray]]]: decoding results from given model,
+            and ground truth (if `gt=True`).
+    """
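+    # Run the Theia model image-by-image and collect the per-upstream-model features on CPU.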
+    features: dict[str, torch.Tensor] = {}
+    with torch.no_grad():
+        for im in images:
+            feature = theia_model([im])
+            if len(features) == 0:
+                features = {k: [] for k in feature}
+            for k in feature:
+                features[k].append(feature[k].detach().cpu())
+    for k in features:
+        features[k] = torch.cat(features[k], dim=0)
+    for m in features:
+        features[m] = denormalize_feature(features[m], feature_means[m], feature_vars[m])
+
+    dino_model_name = "facebook/dinov2-large"
+    sam_model_name = "facebook/sam-vit-huge"
+    depth_anything_model_name = "LiheYoung/depth-anything-large-hf"
+
+    # reuse the caller-provided PCA if given; otherwise one is fit inside decode_dinov2
+    # gt
+    gt_decode_results = None
+    if gt:
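+        # Recompute features with the original upstream models to build ground-truth decodings for comparison.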
+        def legit_model_name(model_name: str) -> str:
+            return model_name.replace("/", "_")
+
+        dino_model, dino_processor = get_model(dino_model_name, device=device)
+        dino_gt_feature = []
+        for im in images:
+            dino_gt_feature.append(
+                get_feature_outputs(
+                    legit_model_name(dino_model_name), dino_model, dino_processor, [im], dtype=torch.float
+                )[legit_model_name(dino_model_name)]["embedding"]
+                .detach()
+                .cpu()
+            )
+        dino_gt_feature = torch.cat(dino_gt_feature, dim=0)
+        dino_gt_feature = rearrange(dino_gt_feature, "b c h w -> b (h w) c")
+        dino_gt_dec, pca = decode_dinov2(dino_gt_feature, pca=pca)
+        sam_processor = SamProcessor.from_pretrained(sam_model_name)
+        sam_gt_feature = []
+        for im in images:
+            sam_inputs = sam_processor(images=[im], return_tensors="pt").to(device)
+            with torch.no_grad():
+                sam_gt_feature.append(sam_model.get_image_embeddings(sam_inputs["pixel_values"]).detach().cpu())
+        sam_gt_feature = torch.cat(sam_gt_feature, dim=0)
+        sam_gt_feature = rearrange(sam_gt_feature, "b c h w -> b (h w) c")
+        sam_gt_dec = decode_sam(
+            sam_gt_feature, images, mask_generator, pred_iou_thresh=0.9, stability_score_thresh=0.9, device=device
+        )
+        depth_anything_model, depth_anything_processor = get_model(depth_anything_model_name, device=device)
+        depth_anything_gt_feature = []
+        for im in images:
+            depth_anything_gt_feature.append(
+                get_feature_outputs(
+                    legit_model_name(depth_anything_model_name),
+                    depth_anything_model,
+                    depth_anything_processor,
+                    [im],
+                    dtype=torch.float,
+                )[legit_model_name(depth_anything_model_name)]["embedding"]
+                .detach()
+                .cpu()
+            )
+        depth_anything_gt_feature = torch.cat(depth_anything_gt_feature, dim=0)
+        depth_anything_gt_feature = rearrange(depth_anything_gt_feature, "b c h w -> b (h w) c")
+        depth_gt_dec = decode_depth_anything(depth_anything_gt_feature, depth_anything_decoder, device=device)
+
+        gt_decode_results = [
+            np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_gt_dec[i], sam_gt_dec[i], depth_gt_dec[i]])
+            for i in range(len(images))
+        ]
+
+    dino_dec, _ = decode_dinov2(features[dino_model_name], pca=pca)
+
+    try:
+        sam_dec = decode_sam(
+            features[sam_model_name],
+            images,
+            mask_generator,
+            pred_iou_thresh=pred_iou_thresh,
+            stability_score_thresh=stability_score_thresh,
+            device=device,
+        )
+    except IndexError:
+        sam_dec = np.zeros_like(dino_dec)
+    depth_dec = decode_depth_anything(features[depth_anything_model_name], depth_anything_decoder, device=device)
+
+    theia_decode_results = [
+        np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_dec[i], sam_dec[i], depth_dec[i]])
+        for i in range(len(images))
+    ]
+
+    return theia_decode_results, gt_decode_results
diff --git a/theia/decoding/depth_anything.py b/theia/decoding/depth_anything.py
new file mode 100644
index 0000000000000000000000000000000000000000..26836d53966d4f676afe02d62ec30f78d148cce9
--- /dev/null
+++ b/theia/decoding/depth_anything.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from numpy.typing import NDArray
+from torch.nn.functional import interpolate
+
+from theia.foundation_models.vision_models.depth_anything import DepthAnythingForDepthEstimation
+
+
+def prepare_depth_decoder(model_name: str, device: int | str | torch.device = 0) -> tuple[nn.Module, int]:
+    """Prepare a depth decoder using DepthAnythingForDepthEstimation.
+
+    Args:
+        model_name (str): name of the depth anything model.
+        device (int | str | torch.device, optional): device to put the model on. Defaults to 0.
+
+    Returns:
+        tuple[nn.Module, int]: the decoder, and the patch size for depth anything model.
+    """
+    decoder_head = DepthAnythingForDepthEstimation.from_pretrained(model_name)
+    patch_size = decoder_head.config.patch_size
+    decoder_head = decoder_head.head
+    decoder_head = decoder_head.to(device)
+    return decoder_head, patch_size
+
+
+def decode_depth_anything(features: torch.Tensor, decoder: nn.Module, device: int | str | torch.device = 0) -> NDArray:
+    """Decode features into predicted depth using the Depth-Anything decoder head.
+
+    Args:
+        features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
+        decoder (nn.Module): Depth-Anything decoder head.
+        device (int | str | torch.device, optional): device to perform the decoding. Defaults to 0.
+
+    Returns:
+        NDArray: decoded depth in image format, represented by an NDArray in size [batch_size, height, width, channels]
+            with value between [0, 1]. The depth values are min-max normalized to [0, 1] to generate images.
+    """
+    with torch.no_grad():
+        P = int(features.size(1) ** 0.5)
+        features = rearrange(features, "b (h w) c -> b c h w", h=P, w=P)
+        features = interpolate(features, (224, 224))
+        predicted_depths = []
+        for feature in features:
+            feature = feature.unsqueeze(0).to(device)
+
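+            # The stored features are assumed to be the head's intermediate activations, so only the
+            # remaining layers (activation1 -> conv3 -> activation2) are applied to obtain depth.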
+            predicted_depth = decoder.activation1(feature)
+            predicted_depth = decoder.conv3(predicted_depth)
+            predicted_depth = decoder.activation2(predicted_depth)
+            predicted_depth = predicted_depth.squeeze(dim=1)  # shape (batch_size, height, width)
+            for i in range(len(predicted_depth)):
+                min_depth, max_depth = predicted_depth[i].min(), predicted_depth[i].max()
+                predicted_depth[i] = (predicted_depth[i] - min_depth) / (max_depth - min_depth)
+            predicted_depths.append(predicted_depth.detach().cpu())
+        predicted_depths = torch.cat(predicted_depths, dim=0)
+    return predicted_depths.unsqueeze(-1).repeat((1, 1, 1, 3)).numpy()  # type: ignore [attr-defined]
diff --git a/theia/decoding/dinov2.py b/theia/decoding/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..29c4801c7846dc7f7767993472ca8c59033a94c6
--- /dev/null
+++ b/theia/decoding/dinov2.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Optional
+
+import cv2
+import numpy as np
+from numpy.typing import NDArray
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import minmax_scale
+
+
+def decode_dinov2(
+    features: NDArray, threshold: int | float = -100, interpolation: bool = False, pca: Optional[PCA] = None
+) -> tuple[NDArray, PCA]:
+    """
+    Decode the input `features` in DINOv2 style using PCA.
+
+    Args:
+        features (NDArray): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
+        threshold (int | float): threshold of foreground-background split in PCA visualization.
+            Defaults to -100 (all patches are included).
+        interpolation (bool): if True, resize the patch-level PCA map to the image size with cv2 interpolation;
+            otherwise use nearest-neighbor block upsampling. Defaults to False.
+        pca (Optional[PCA]): if provided, use the provided PCA. This is to keep visualizations stable across samples.
+
+    Returns:
+        tuple[NDArray, PCA]: the rendered image of this visualization, in NDArray in size
+            [batch_size, height, width, channels] with value ranges [0, 1], and the PCA used in this visualization.
+    """
+    features = features.numpy()
+    batch_size, spatial_size, latent_dim = features.shape
+    h = w = int(spatial_size**0.5)
+
+    features = features.reshape(-1, latent_dim)
+
+    if pca is None:
+        pca = PCA(n_components=3)
+        pca.fit(features)
+
+    pca_features = pca.transform(features)
+
+    # segment using the first component
+    bg_mask = pca_features[:, 0] < threshold
+    fg_mask = ~bg_mask
+
+    # PCA for only foreground patches
+    # pca.fit(features[fg_mask])
+    pca_features_fg = pca.transform(features[fg_mask])
+    for i in range(3):
+        pca_features_fg[:, i] = minmax_scale(pca_features_fg[:, i])
+
+    pca_features_rgb = pca_features.copy()
+    pca_features_rgb[bg_mask] = 0
+    pca_features_rgb[fg_mask] = pca_features_fg
+
+    pca_features_rgb = pca_features_rgb.reshape(batch_size, h, w, 3)
+    if not interpolation:
+        H = W = 224
+        scale = H // h
+        interpolated_pca_features = np.zeros((batch_size, H, W, 3), dtype=pca_features_rgb.dtype)
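+        # Nearest-neighbor upsampling: each patch value fills a scale x scale block of the output image.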
+        for i in range(len(pca_features_rgb)):
+            for j in range(h):
+                for k in range(w):
+                    interpolated_pca_features[i, scale * j : scale * (j + 1), scale * k : scale * (k + 1)] = (
+                        pca_features_rgb[i, j, k]
+                    )
+        pca_features_rgb = interpolated_pca_features
+    else:
+        pca_features_rgb = np.stack([cv2.resize(p, (224, 224)) for p in pca_features_rgb])
+    return pca_features_rgb, pca
diff --git a/theia/decoding/sam.py b/theia/decoding/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..476d9c4af6c829a42d1621886beb7a8a807976d1
--- /dev/null
+++ b/theia/decoding/sam.py
@@ -0,0 +1,191 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Any, Generator, Optional
+
+import numpy as np
+import torch
+from einops import rearrange
+from numpy.typing import NDArray
+from PIL import Image
+from transformers import SamModel, SamProcessor
+from transformers.image_utils import load_image
+from transformers.pipelines import MaskGenerationPipeline
+
+
+class MaskGenerationPipelineWithEmbeddings(MaskGenerationPipeline):
+    """
+    The wrapper class for huggingface transformers.pipelines.MaskGenerationPipeline
+        that can decode from intermediate SAM embeddings.
+    """
+
+    def _sanitize_parameters(self, **kwargs: Any) -> tuple[dict[str, Any], ...]:
+        preprocess_kwargs = {}
+        postprocess_kwargs = {}
+        forward_params = {}
+        # preprocess args
+        if "embeddings" in kwargs:  # inject embeddings here
+            preprocess_kwargs["embeddings"] = kwargs["embeddings"]
+        if "points_per_batch" in kwargs:
+            preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
+        if "points_per_crop" in kwargs:
+            preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
+        if "crops_n_layers" in kwargs:
+            preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
+        if "crop_overlap_ratio" in kwargs:
+            preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
+        if "crop_n_points_downscale_factor" in kwargs:
+            preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
+        if "timeout" in kwargs:
+            preprocess_kwargs["timeout"] = kwargs["timeout"]
+        # forward args
+        if "pred_iou_thresh" in kwargs:
+            forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
+        if "stability_score_offset" in kwargs:
+            forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
+        if "mask_threshold" in kwargs:
+            forward_params["mask_threshold"] = kwargs["mask_threshold"]
+        if "stability_score_thresh" in kwargs:
+            forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
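+        # postprocess args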
+        if "crops_nms_thresh" in kwargs:
+            postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
+        if "output_rle_mask" in kwargs:
+            postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
+        if "output_bboxes_mask" in kwargs:
+            postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
+        return preprocess_kwargs, forward_params, postprocess_kwargs
+
+    def preprocess(
+        self,
+        image: Image.Image,
+        points_per_batch: int = 64,
+        crops_n_layers: int = 0,
+        crop_overlap_ratio: float = 512 / 1500,
+        points_per_crop: int = 32,
+        crop_n_points_downscale_factor: int = 1,
+        timeout: Optional[float] = None,
+        embeddings: Optional[torch.Tensor] = None,
+    ) -> Generator[Any, Any, Any]:
+        image = load_image(image, timeout=timeout)
+        target_size = self.image_processor.size["longest_edge"]
+        crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
+            image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
+        )
+        model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
+
+        with self.device_placement():
+            if self.framework == "pt":
+                inference_context = self.get_inference_context()
+                with inference_context():
+                    model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
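+                    # If embeddings are supplied (e.g., predicted by Theia), use them in place of SAM's own
+                    # vision-encoder output.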
+                    if embeddings is None:
+                        image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
+                    else:
+                        model_inputs.pop("pixel_values")
+                        image_embeddings = embeddings
+                    model_inputs["image_embeddings"] = image_embeddings
+
+        n_points = grid_points.shape[1]
+        points_per_batch = points_per_batch if points_per_batch is not None else n_points
+
+        if points_per_batch <= 0:
+            raise ValueError(
+                "Cannot have points_per_batch<=0. Must be >=1 to return batched outputs. "
+                "To return all points at once, set points_per_batch to None"
+            )
+
+        for i in range(0, n_points, points_per_batch):
+            batched_points = grid_points[:, i : i + points_per_batch, :, :]
+            labels = input_labels[:, i : i + points_per_batch]
+            is_last = i == n_points - points_per_batch
+            yield {
+                "input_points": batched_points,
+                "input_labels": labels,
+                "input_boxes": crop_boxes,
+                "is_last": is_last,
+                **model_inputs,
+            }
+
+
+def draw_mask(mask: NDArray, random_color: bool = False) -> NDArray:
+    """Draw the mask on an image.
+
+    Args:
+        mask (NDArray): mask in shape [height, width].
+        random_color (bool): if using a random color. Defaults to False.
+
+    Returns:
+        NDArray: NDArray format of the image.
+    """
+    if random_color:
+        color = np.concatenate([np.random.random(3)], axis=0)
+    else:
+        color = np.array([30 / 255, 144 / 255, 255 / 255])
+    h, w = mask.shape[-2:]
+    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+    return mask_image
+
+
+def decode_sam(
+    features: torch.Tensor,
+    images: list[Image.Image],
+    mask_generator: Any,
+    points_per_batch: int = 64,
+    pred_iou_thresh: float = 0.5,
+    stability_score_thresh: float = 0.6,
+    random_color: bool = True,
+    device: int | str | torch.device = 0,
+) -> NDArray:
+    """Decode features using SAM (auto-prompting) mask generation pipeline.
+
+    Args:
+        features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
+        images (list[Image.Image]): images corresponding to these features.
+        mask_generator (Any): mask generation pipeline.
+        points_per_batch (int): points per batch for auto-prompting. Defaults to 64.
+            See transformers.pipelines.MaskGenerationPipeline for more details. Same below.
+        pred_iou_thresh (float): iou threshold. Defaults to 0.5.
+        stability_score_thresh (float): stability threshold. Defaults to 0.6.
+        random_color (bool): if using a random color. Defaults to True.
+        device (int | str | torch.device): device to perform the decoding. Defaults to 0.
+
+    Returns:
+        NDArray: decoded masks rendered in image format, represented by an NDArray in size
+            [batch_size, height, width, channels] with value between [0, 1].
+    """
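+    # Reshape tokens to SAM's spatial embedding layout and pass them to the pipeline as precomputed embeddings.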
+    masks_rgbs = []
+    num_patches = int(features.size(1) ** 0.5)
+    features = rearrange(features, "b (h w) c -> b c h w", h=num_patches, w=num_patches)
+    with torch.no_grad():
+        for im, feature in zip(images, features, strict=False):
+            predicted_outputs = mask_generator(
+                im,
+                points_per_batch=points_per_batch,
+                embeddings=feature.unsqueeze(0).to(device),
+                pred_iou_thresh=pred_iou_thresh,
+                stability_score_thresh=stability_score_thresh,
+            )
+            predicted_masks = predicted_outputs["masks"]
+            masks_rgb = np.zeros((224, 224, 3), dtype=np.float32)
+            for mask in predicted_masks:
+                masks_rgb += draw_mask(mask, random_color=random_color)
+            # masks_rgb = cv2.cvtColor(masks_rgb, cv2.COLOR_RGBA2RGB)
+            masks_rgbs.append(masks_rgb)
+    return np.stack(masks_rgbs)
+
+
+def prepare_mask_generator(device: int | str | torch.device = 0) -> tuple[MaskGenerationPipeline, SamModel]:
+    """Prepare a mask generation pipeline and the underlying SAM model on device `device`.
+
+    Args:
+        device (int | str | torch.device): device to perform mask generation. Defaults to 0.
+
+    Returns:
+        tuple[MaskGenerationPipeline, SamModel]: mask generator and the SAM model it wraps.
+    """
+    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+    sam_model.eval()
+    mask_generator = MaskGenerationPipelineWithEmbeddings(
+        task="mask_generation", model=sam_model, image_processor=processor.image_processor, device=device
+    )
+    return mask_generator, sam_model
diff --git a/theia/example/decode_to_vfms.ipynb b/theia/example/decode_to_vfms.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..5f5aae323740508df6e30289fc9979c1960fd50c
--- /dev/null
+++ b/theia/example/decode_to_vfms.ipynb
@@ -0,0 +1,69 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import cv2\n",
+    "import torch\n",
+    "from PIL import Image\n",
+    "import numpy as np\n",
+    "from transformers import AutoModel\n",
+    "from torchvision.io import read_video, write_video\n",
+    "from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything\n",
+    "\n",
+    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
+    "theia_model = AutoModel.from_pretrained(\"theaiinstitute/theia-base-patch16-224-cdiv\", trust_remote_code=True)\n",
+    "theia_model = theia_model.to(device)\n",
+    "target_model_names = [\n",
+    "    \"google/vit-huge-patch14-224-in21k\",\n",
+    "    \"facebook/dinov2-large\",\n",
+    "    \"openai/clip-vit-large-patch14\",\n",
+    "    \"facebook/sam-vit-huge\",\n",
+    "    \"LiheYoung/depth-anything-large-hf\",\n",
+    "]\n",
+    "feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root=\"../../../feature_stats\")\n",
+    "\n",
+    "mask_generator, sam_model = prepare_mask_generator(device)\n",
+    "depth_anything_model_name = \"LiheYoung/depth-anything-large-hf\"\n",
+    "depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device)\n",
+    "\n",
+    "example_video_path = \"../../../media/example_video_to_visualize.mp4\"\n",
+    "video, _, _ = read_video(example_video_path, pts_unit=\"sec\", output_format=\"THWC\")\n",
+    "video = video.numpy()\n",
+    "images = [Image.fromarray(cv2.resize(im, (224, 224))) for im in video]\n",
+    "\n",
+    "theia_decode_results, gt_decode_results = decode_everything(\n",
+    "    theia_model=theia_model,\n",
+    "    feature_means=feature_means,\n",
+    "    feature_vars=feature_vars,\n",
+    "    images=images,\n",
+    "    mask_generator=mask_generator,\n",
+    "    sam_model=sam_model,\n",
+    "    depth_anything_decoder=depth_anything_decoder,\n",
+    "    pred_iou_thresh=0.5,\n",
+    "    stability_score_thresh=0.7,\n",
+    "    gt=True,\n",
+    "    device=device,\n",
+    ")\n",
+    "\n",
+    "vis_video = np.stack(\n",
+    "    [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)]\n",
+    ")\n",
+    "vis_video = torch.from_numpy(vis_video * 255.0).to(torch.uint8)\n",
+    "vis_save_path = \"./visualized.mp4\"\n",
+    "write_video(vis_save_path, vis_video, fps=10)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/theia/foundation_models/__init__.py b/theia/foundation_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..adcb4c57dd9d8d87bc7c4e0fbc4a9da7a0e5cc5d
--- /dev/null
+++ b/theia/foundation_models/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .vision_language_models.clip import get_clip_feature, get_clip_model
+from .vision_language_models.llava import get_llava_vision_model, get_llava_visual_feature
+from .vision_models.deit import get_deit_feature, get_deit_model
+from .vision_models.depth_anything import get_depth_anything_feature, get_depth_anything_model
+from .vision_models.dinov2 import get_dinov2_feature, get_dinov2_model
+from .vision_models.sam import get_sam_feature, get_sam_model
+from .vision_models.vit import get_vit_feature, get_vit_model
diff --git a/theia/foundation_models/common.py b/theia/foundation_models/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..0358b35ebdae4129df80a09dc9e46d991e3d61e3
--- /dev/null
+++ b/theia/foundation_models/common.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import math
+
+import torch
+
+MODELS = [
+    "facebook/dinov2-large",
+    "facebook/sam-vit-huge",
+    "google/vit-huge-patch14-224-in21k",
+    "llava-hf/llava-1.5-7b-hf",
+    "openai/clip-vit-large-patch14",
+    "LiheYoung/depth-anything-large-hf",
+]
+
+# handy model feature size constants
+# in the format of (latent_dim, width, height)
+MODEL_FEATURE_SIZES = {
+    "facebook/dinov2-large": (1024, 16, 16),
+    "facebook/sam-vit-huge": (256, 64, 64),
+    "google/vit-huge-patch14-224-in21k": (1280, 16, 16),
+    "llava-hf/llava-1.5-7b-hf": (1024, 24, 24),
+    "openai/clip-vit-large-patch14": (1024, 16, 16),
+    "LiheYoung/depth-anything-large-hf": (32, 64, 64),
+}
+
+
+def get_model_feature_size(
+    model_name: str, keep_spatial: bool = False, return_torch_size: bool = False
+) -> tuple[int, ...] | torch.Size:
+    """
+    Get the size of queried model feature.
+
+    Args:
+        model_name (str): name of the model.
+        keep_spatial (bool): whether to preserve spatial dim. Defaults to False.
+        return_torch_size (bool): return torch.Size instead of python tuple. Defaults to False.
+
+    Returns:
+        tuple[int, ...] | torch.Size: the size of the feature.
+    """
+    size: tuple[int, ...] = MODEL_FEATURE_SIZES[model_name]
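+    # e.g., "facebook/dinov2-large" -> (1024, 16, 16); flattened to (1024, 256) when keep_spatial is False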
+
+    if not keep_spatial:
+        size = (size[0], math.prod(size[1:]))
+
+    if return_torch_size:
+        size = torch.Size(size)
+
+    return size
+
+
+def get_max_model_spatial_size(
+    keep_spatial: bool = True,
+    return_torch_size: bool = False,
+    return_model_name: bool = False,
+) -> tuple[int, ...] | tuple[tuple[int, ...], str]:
+    """Get the maximal spatial dimensions among available models.
+
+    Args:
+        keep_spatial (bool): whether to preserve spatial dim. Defaults to True.
+        return_torch_size (bool): return torch.Size instead of python tuple. Defaults to False.
+        return_model_name (bool): whether to also return the name of the model with the maximal size. Defaults to False.
+
+    Returns:
+        tuple[int, ...] | tuple[tuple[int, ...], str]: the maximal size and optional model name.
+    """
+    max_flatten_size = -1
+    max_size: tuple[int, ...] = ()
+    max_size_model_name: str = ""
+    for model, size in MODEL_FEATURE_SIZES.items():
+        flatten_size = math.prod(size[1:])
+        if flatten_size > max_flatten_size:
+            max_flatten_size = flatten_size
+            max_size = size[1:]
+            max_size_model_name = model
+
+    if not keep_spatial:
+        max_size = (max_flatten_size,)
+
+    if return_torch_size:
+        max_size = torch.Size(max_size)
+
+    if return_model_name:
+        return max_size, max_size_model_name
+    else:
+        return max_size
diff --git a/theia/foundation_models/vision_language_models/__init__.py b/theia/foundation_models/vision_language_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/foundation_models/vision_language_models/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/foundation_models/vision_language_models/clip.py b/theia/foundation_models/vision_language_models/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb54c1522e220637ced52204b5ec5ce6a2054466
--- /dev/null
+++ b/theia/foundation_models/vision_language_models/clip.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, CLIPVisionModel
+
+
+def get_clip_feature(
+    model: CLIPVisionModel, processor: AutoProcessor, images: list[np.ndarray], requires_grad: bool = False
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Get features from the visual encoder of CLIP.
+
+    Args:
+        model (CLIPVisionModel): CLIP model.
+        processor (AutoProcessor): CLIP input processor.
+        images (list[np.ndarray]): images to be encoded, in RGB, uint8.
+        requires_grad (bool): maintains gradient. Defaults to False.
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: features from clip (
+            cls_token:      last layer embedding from cls token         # (1, 1, 1024) if vit-large,
+            visual_tokens:  last layer embeddings from image            # (1, 1024, 16, 16) BCHW if vit-large,
+            pooled_cls_token: last layer embedding from cls + layernorm # (1, 1, 1024) if vit-large
+        )
+    """
+    inputs = processor(images=images, return_tensors="pt").to(model.device)
+    if requires_grad:
+        outputs = model(**inputs)
+    else:
+        with torch.no_grad():
+            outputs = model(**inputs)
+    cls_token = outputs.last_hidden_state[:, :1]  # (1, 1, 1024) if vit-large
+    visual_tokens = outputs.last_hidden_state[:, 1:]  # (1, 256, 1024) if vit-large
+    pooled_cls_token = outputs.pooler_output.unsqueeze(1)  # (1, 1, 1024) if vit-large
+    batch_size, num_patches, num_channels = visual_tokens.size()
+    visual_tokens = visual_tokens.transpose(1, 2)
+    visual_tokens = visual_tokens.reshape(
+        batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches))
+    )  # (1, 1024, 16, 16) BCHW for vit-large
+    return cls_token, visual_tokens, pooled_cls_token
+
+
+def get_clip_model(
+    model_name: str = "openai/clip-vit-large-patch14", device: str | torch.device = "cuda"
+) -> tuple[CLIPVisionModel, AutoProcessor]:
+    """Get CLIP model and its input processor.
+
+    Args:
+        model_name (str, optional): name of CLIP model. Defaults to "openai/clip-vit-large-patch14".
+        device (str | torch.device, optional): device to put the model on. Defaults to "cuda".
+
+    Returns:
+        tuple[CLIPVisionModel, AutoProcessor]: CLIP model and the corresponding input processor.
+    """
+    processor = AutoProcessor.from_pretrained(model_name)
+    model = CLIPVisionModel.from_pretrained(model_name).to(device)
+    return model, processor
+
+
+def print_feature_size(model_name: str = "openai/clip-vit-large-patch14") -> None:
+    """Print the sizes of features from CLIP.
+
+    Args:
+        model_name (str, optional): the name of CLIP model. Defaults to "openai/clip-vit-large-patch14".
+    """
+    import requests
+    from PIL import Image
+
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    image = [np.array(Image.open(requests.get(url, stream=True).raw))]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model, processor = get_clip_model(model_name, device=device)
+    cls_token, visual_tokens, pooled_cls_token = get_clip_feature(model, processor, image)
+
+    print(model_name, cls_token.size(), visual_tokens.size(), pooled_cls_token.size())
+
+
+if __name__ == "__main__":
+    print_feature_size()
diff --git a/theia/foundation_models/vision_language_models/llava.py b/theia/foundation_models/vision_language_models/llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..87ea2f7aea12ef7a9c03cd73fc4c8e8110ea14ef
--- /dev/null
+++ b/theia/foundation_models/vision_language_models/llava.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
+
+
+@dataclass
+class LlavaVisualFeatureOutput(LlavaCausalLMOutputWithPast):
+    """Visual feature output for LLaVA.
+
+    Args:
+        visual_embeddings (Optional[torch.FloatTensor]): feature from visual encoder.
+    """
+
+    visual_embeddings: Optional[torch.FloatTensor] = None
+
+
+class LlavaVisualFeature(LlavaForConditionalGeneration):
+    """LLaVA model with only visual feature returned. Borrowed from transformers."""
+
+    # TODO: reduce VRAM use of language model part, because only vocabulary is used, not the whole model
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        vision_feature_layer: Optional[int] = None,
+        vision_feature_select_strategy: Optional[str] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> tuple | LlavaVisualFeatureOutput:
+        """LLaVA visual encoder forward pass, from transformers package.
+
+        Returns:
+            tuple | LlavaVisualFeatureOutput: feature from visual encoder.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+        vision_feature_select_strategy = (
+            vision_feature_select_strategy
+            if vision_feature_select_strategy is not None
+            else self.config.vision_feature_select_strategy
+        )
+
+        image_features = None
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+            if pixel_values is not None and input_ids.shape[1] != 1:
+                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+                # Not memory efficient: output_hidden_states=True keeps all intermediate hidden states.
+                selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+                if vision_feature_select_strategy == "default":
+                    image_features = selected_image_feature[:, 1:]
+                elif vision_feature_select_strategy == "full":
+                    image_features = selected_image_feature
+                else:
+                    raise ValueError(
+                        f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+                    )
+        return LlavaVisualFeatureOutput(visual_embeddings=image_features)
+
+
+def get_llava_visual_feature(
+    model: LlavaVisualFeature, processor: AutoProcessor, images: list[np.ndarray], requires_grad: bool = False
+) -> torch.FloatTensor:
+    """Get the feature from the visual encoder of LLaVA.
+
+    Args:
+        model (LlavaVisualFeature): LLaVA model
+        processor (AutoProcessor): LLaVA input processor
+        images (list[np.ndarray]): images to be encoded, in RGB, uint8.
+        requires_grad (bool): maintains gradient. Defaults to False.
+
+    Returns:
+        torch.FloatTensor: LLaVA feature. (1, 1024, 24, 24) if using llava-7b
+    """
+    inputs = processor(text=["placeholder"], images=images, return_tensors="pt").to(model.device)
+    if requires_grad:
+        outputs = model(**inputs)
+    else:
+        with torch.no_grad():
+            outputs = model(**inputs)
+    batch_size, num_patches, num_channels = outputs.visual_embeddings.size()
+    visual_tokens = outputs.visual_embeddings.transpose(1, 2).reshape(
+        batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches))
+    )
+    return visual_tokens  # (1, 1024, 24, 24) if llava-7b
+
+
+def get_llava_vision_model(
+    model_name: str = "llava-hf/llava-1.5-7b-hf", device: str | torch.device = "cuda"
+) -> tuple[LlavaVisualFeature, AutoProcessor]:
+    """Get LLaVA model and its input processor.
+
+    Args:
+        model_name (str, optional): name of LLaVA model. Defaults to "llava-hf/llava-1.5-7b-hf".
+        device (str | torch.device, optional): device to put the model on. Defaults to "cuda".
+
+    Returns:
+        tuple[LlavaVisualFeature, AutoProcessor]: LLaVA model and the corresponding input processor.
+    """
+    model = LlavaVisualFeature.from_pretrained(model_name).to(device)
+    processor = AutoProcessor.from_pretrained(model_name)
+    return model, processor
+
+
+def print_feature_size(model_name: str = "llava-hf/llava-1.5-7b-hf") -> None:
+    """Print the size of the feature from LLaVA.
+
+    Args:
+        model_name (str, optional): the name of LLaVA model. Defaults to "llava-hf/llava-1.5-7b-hf".
+    """
+    from datasets import load_dataset
+
+    dataset = load_dataset("huggingface/cats-image")
+    image = dataset["test"]["image"][0]
+    image = [np.array(image)]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model, processor = get_llava_vision_model(model_name=model_name, device=device)
+    feature = get_llava_visual_feature(model, processor, image)
+    print(model_name, feature.size())
+    # (1, 1024, 24, 24) if llava-7b
+
+
+if __name__ == "__main__":
+    print_feature_size()
diff --git a/theia/foundation_models/vision_models/__init__.py b/theia/foundation_models/vision_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/foundation_models/vision_models/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/foundation_models/vision_models/deit.py b/theia/foundation_models/vision_models/deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf4c618cc4ba308521051d28cd782a87984358a6
--- /dev/null
+++ b/theia/foundation_models/vision_models/deit.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import numpy as np
+import torch
+from transformers import AutoImageProcessor, AutoModel
+
+
+def get_deit_feature(
+    model: AutoModel, processor: AutoImageProcessor, images: list[np.ndarray], requires_grad: bool = False
+) -> torch.Tensor:
+    """Get feature from DeiT model.
+
+    Args:
+        model (AutoModel): DeiT model.
+        processor (AutoImageProcessor): DeiT input processor.
+        images (list[np.ndarray]): images to be encoded.
+        requires_grad (bool): maintains gradient. Defaults to False.
+
+    Returns:
+        torch.Tensor: feature from the last layer, (1, 768, 14, 14) BCHW for deit-base.
+    """
+    inputs = processor(images, return_tensors="pt").to(model.device)
+    if requires_grad:
+        outputs = model(**inputs)
+    else:
+        with torch.no_grad():
+            outputs = model(**inputs)
+    last_hidden_state = outputs.last_hidden_state[:, 1:]
+    batch_size, num_patches, num_channels = last_hidden_state.size()
+    last_hidden_state = last_hidden_state.transpose(1, 2)
+    last_hidden_state = last_hidden_state.reshape(
+        batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches))
+    )
+    return last_hidden_state  # (1, 768, 14, 14) BCHW for deit-base
+
+
+def get_deit_model(
+    model_name: str = "facebook/deit-tiny-patch16-224", device: str | torch.device = "cuda"
+) -> tuple[AutoModel, AutoImageProcessor]:
+    """Get DeiT model and its corresponding input processor.
+
+    Args:
+        model_name (str, optional): the name of DeiT model. Defaults to "facebook/deit-tiny-patch16-224".
+        device (str | torch.device, optional): device to put model on. Defaults to "cuda".
+
+    Returns:
+        tuple[AutoModel, AutoImageProcessor]: DeiT model and its input processor.
+    """
+    processor = AutoImageProcessor.from_pretrained(model_name)
+    model = AutoModel.from_pretrained(model_name).to(device)
+    return model, processor
+
+
+def print_feature_size(model_name: str = "facebook/deit-tiny-patch16-224") -> None:
+    """Print the size of the feature from DeiT.
+
+    Args:
+        model_name (str, optional): the name of the DeiT model. Defaults to "facebook/deit-tiny-patch16-224".
+    """
+    from datasets import load_dataset
+
+    dataset = load_dataset("huggingface/cats-image")
+    image = dataset["test"]["image"][0]
+    image = [np.array(image)]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model, processor = get_deit_model(model_name=model_name, device=device)
+    feature = get_deit_feature(model, processor, image)
+    print(feature.size())
+    # (1, 768, 14, 14) BCHW for deit-base
diff --git a/theia/foundation_models/vision_models/depth_anything.py b/theia/foundation_models/vision_models/depth_anything.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ad54cdede8c3bd4b15cddfb773139bc47e4396b
--- /dev/null
+++ b/theia/foundation_models/vision_models/depth_anything.py
@@ -0,0 +1,681 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+# File modified.
+
+# -----------------------------------------------------------------------
+# Copyright 2024 TikTok and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Depth Anything model."""
+import copy
+from typing import Any, Optional
+
+import numpy.typing as npt
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from transformers import AutoImageProcessor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import DepthEstimatorOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.auto import AutoBackbone
+from transformers.models.auto.configuration_auto import CONFIG_MAPPING
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class DepthAnythingConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DepthAnythingModel`].
+    It is used to instantiate a DepthAnything model according to the specified arguments,
+    defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the DepthAnything
+    [LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        backbone_config (`dict[str, Any] | PretrainedConfig`, *optional*):
+            The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
+            leverage the [`AutoBackbone`] API.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights
+            from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        patch_size (`int`, *optional*, defaults to 14):
+            The size of the patches to extract from the backbone features.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        reassemble_hidden_size (`int`, *optional*, defaults to 384):
+            The number of input channels of the reassemble layers.
+        reassemble_factors (`tuple[int | float, ...]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
+            The up/downsampling factors of the reassemble layers.
+        neck_hidden_sizes (`tuple[int]`, *optional*, defaults to `[48, 96, 192, 384]`):
+            The hidden sizes to project to for the feature maps of the backbone.
+        fusion_hidden_size (`int`, *optional*, defaults to 64):
+            The number of channels before fusion.
+        head_in_index (`int`, *optional*, defaults to -1):
+            The index of the features to use in the depth estimation head.
+        head_hidden_size (`int`, *optional*, defaults to 32):
+            The number of output channels in the second convolution of the depth estimation head.
+    """
+
+    model_type = "depth_anything"
+
+    def __init__(
+        self,
+        backbone_config: dict[str, Any] | PretrainedConfig | None = None,
+        backbone: Optional[str] = None,
+        use_pretrained_backbone: bool = False,
+        patch_size: int = 14,
+        initializer_range: float = 0.02,
+        reassemble_hidden_size: int = 384,
+        reassemble_factors: tuple[int | float, ...] = (4, 2, 1, 0.5),
+        neck_hidden_sizes: tuple[int, ...] = (48, 96, 192, 384),
+        fusion_hidden_size: int = 64,
+        head_in_index: int = -1,
+        head_hidden_size: int = 32,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
+            logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.")
+            backbone_config = CONFIG_MAPPING["dinov2"](
+                image_size=518,
+                hidden_size=384,
+                num_attention_heads=6,
+                out_indices=[9, 10, 11, 12],
+                apply_layernorm=True,
+                reshape_hidden_states=False,
+            )
+        elif isinstance(backbone_config, dict):
+            backbone_model_type = backbone_config.get("model_type")
+            config_class = CONFIG_MAPPING[backbone_model_type]
+            backbone_config = config_class.from_dict(backbone_config)
+
+        self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
+        self.reassemble_hidden_size = reassemble_hidden_size
+        self.patch_size = patch_size
+        self.initializer_range = initializer_range
+        self.reassemble_factors = reassemble_factors
+        self.neck_hidden_sizes = neck_hidden_sizes
+        self.fusion_hidden_size = fusion_hidden_size
+        self.head_in_index = head_in_index
+        self.head_hidden_size = head_hidden_size
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = copy.deepcopy(self.__dict__)
+
+        if output["backbone_config"] is not None:
+            output["backbone_config"] = self.backbone_config.to_dict()
+
+        output["model_type"] = self.__class__.model_type
+        return output
+
+
+class DepthAnythingReassembleLayer(nn.Module):
+    def __init__(self, config: DepthAnythingConfig, channels: int, factor: int | float):
+        super().__init__()
+        self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1)
+
+        # up/down sampling depending on factor
+        if factor > 1:
+            self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
+        elif factor == 1:
+            self.resize = nn.Identity()
+        elif factor < 1:
+            # so should downsample
+            self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
+
+    # Copied from transformers.models.dpt.modeling_dpt.DPTReassembleLayer.forward
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.projection(hidden_state)
+        hidden_state = self.resize(hidden_state)
+
+        return hidden_state
+
+
+class DepthAnythingReassembleStage(nn.Module):
+    """
+    This class reassembles the hidden states of the backbone into image-like feature representations at various
+    resolutions.
+
+    This happens in 3 stages:
+    1. Take the patch embeddings and reshape them to image-like feature representations.
+    2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
+    3. Resize the spatial dimensions (height, width).
+
+    Args:
+        config (`[DepthAnythingConfig]`):
+            Model configuration class defining the model architecture.
+    """
+
+    def __init__(self, config: DepthAnythingConfig):
+        super().__init__()
+
+        self.config = config
+        self.layers = nn.ModuleList()
+        for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors, strict=False):
+            self.layers.append(DepthAnythingReassembleLayer(config, channels=channels, factor=factor))
+
+    def forward(
+        self, hidden_states: list[torch.Tensor], patch_height: Optional[int] = None, patch_width: Optional[int] = None
+    ) -> list[torch.Tensor]:
+        """
+        Args:
+            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
+                List of hidden states from the backbone.
+        """
+        out = []
+
+        for i, hidden_state in enumerate(hidden_states):
+            # reshape to (batch_size, num_channels, height, width)
+            hidden_state = hidden_state[:, 1:]
+            batch_size, _, num_channels = hidden_state.shape
+            hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
+            hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+            hidden_state = self.layers[i](hidden_state)
+            out.append(hidden_state)
+
+        return out
+
+
+class DepthAnythingPreActResidualLayer(nn.Module):
+    """
+    ResidualConvUnit, pre-activate residual unit.
+
+    Args:
+        config (`[DepthAnythingConfig]`):
+            Model configuration class defining the model architecture.
+    """
+
+    def __init__(self, config: DepthAnythingConfig):
+        super().__init__()
+
+        self.activation1 = nn.ReLU()
+        self.convolution1 = nn.Conv2d(
+            config.fusion_hidden_size,
+            config.fusion_hidden_size,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
+
+        self.activation2 = nn.ReLU()
+        self.convolution2 = nn.Conv2d(
+            config.fusion_hidden_size,
+            config.fusion_hidden_size,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+        )
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        residual = hidden_state
+        hidden_state = self.activation1(hidden_state)
+        hidden_state = self.convolution1(hidden_state)
+        hidden_state = self.activation2(hidden_state)
+        hidden_state = self.convolution2(hidden_state)
+
+        return hidden_state + residual
+
+
+class DepthAnythingFeatureFusionLayer(nn.Module):
+    """Feature fusion layer, merges feature maps from different stages.
+
+    Args:
+        config (`[DepthAnythingConfig]`):
+            Model configuration class defining the model architecture.
+    """
+
+    def __init__(self, config: DepthAnythingConfig):
+        super().__init__()
+
+        self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
+
+        self.residual_layer1 = DepthAnythingPreActResidualLayer(config)
+        self.residual_layer2 = DepthAnythingPreActResidualLayer(config)
+
+    def forward(
+        self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None, size: Optional[int] = None
+    ) -> torch.Tensor:
+        if residual is not None:
+            if hidden_state.shape != residual.shape:
+                residual = nn.functional.interpolate(
+                    residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
+                )
+            hidden_state = hidden_state + self.residual_layer1(residual)
+
+        hidden_state = self.residual_layer2(hidden_state)
+
+        modifier = {"scale_factor": 2} if size is None else {"size": size}
+
+        hidden_state = nn.functional.interpolate(
+            hidden_state,
+            **modifier,
+            mode="bilinear",
+            align_corners=True,
+        )
+        hidden_state = self.projection(hidden_state)
+
+        return hidden_state
+
+
+class DepthAnythingFeatureFusionStage(nn.Module):
+    # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage.__init__ with DPT->DepthAnything
+    def __init__(self, config: DepthAnythingConfig):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        for _ in range(len(config.neck_hidden_sizes)):
+            self.layers.append(DepthAnythingFeatureFusionLayer(config))
+
+    def forward(self, hidden_states: list[torch.Tensor], size: Optional[int] = None) -> list[torch.Tensor]:
+        # reversing the hidden_states, we start from the last
+        hidden_states = hidden_states[::-1]
+
+        fused_hidden_states = []
+        # first layer only uses the last hidden_state
+        size = hidden_states[1].shape[2:]
+        fused_hidden_state = self.layers[0](hidden_states[0], size=size)
+        fused_hidden_states.append(fused_hidden_state)
+
+        # looping from the last layer to the second
+        for idx, (hidden_state, layer) in enumerate(zip(hidden_states[1:], self.layers[1:], strict=False)):
+            size = hidden_states[1:][idx + 1].shape[2:] if idx != (len(hidden_states[1:]) - 1) else None
+
+            fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)
+
+            fused_hidden_states.append(fused_hidden_state)
+
+        return fused_hidden_states
+
+
+# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything
+class DepthAnythingPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = DepthAnythingConfig
+    base_model_prefix = "depth_anything"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module: nn.Module) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class DepthAnythingNeck(nn.Module):
+    """
+    DepthAnythingNeck. A neck is a module that is normally used between the backbone and the head.
+    It takes a list of tensors as input and produces another list of tensors as output.
+    For DepthAnything, it includes 2 stages:
+
+    * DepthAnythingReassembleStage
+    * DepthAnythingFeatureFusionStage.
+
+    Args:
+        config (dict): config dict.
+    """
+
+    def __init__(self, config: DepthAnythingConfig):
+        super().__init__()
+        self.config = config
+
+        self.reassemble_stage = DepthAnythingReassembleStage(config)
+
+        self.convs = nn.ModuleList()
+        for channel in config.neck_hidden_sizes:
+            self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
+
+        # fusion
+        self.fusion_stage = DepthAnythingFeatureFusionStage(config)
+
+    def forward(
+        self, hidden_states: list[torch.Tensor], patch_height: Optional[int] = None, patch_width: Optional[int] = None
+    ) -> list[torch.Tensor]:
+        """
+        Args:
+            hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)`
+            or `(batch_size, hidden_size, height, width)`): List of hidden states from the backbone.
+        """
+        if not isinstance(hidden_states, (tuple, list)):
+            raise ValueError("hidden_states should be a tuple or list of tensors")
+
+        if len(hidden_states) != len(self.config.neck_hidden_sizes):
+            raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
+
+        # postprocess hidden states
+        hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
+
+        features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
+
+        # fusion blocks
+        output = self.fusion_stage(features)
+
+        return output
+
+
+class DepthAnythingDepthEstimationHead(nn.Module):
+    """
+    Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
+    the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's
+    supplementary material).
+    """
+
+    def __init__(self, config: DepthAnythingConfig):
+        super().__init__()
+
+        self.head_in_index = config.head_in_index
+        self.patch_size = config.patch_size
+
+        features = config.fusion_hidden_size
+        self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1)
+        self.activation1 = nn.ReLU()
+        self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0)
+        self.activation2 = nn.ReLU()
+
+    def forward(self, hidden_states: list[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
+        hidden_states = hidden_states[self.head_in_index]
+
+        predicted_depth = self.conv1(hidden_states)
+        predicted_depth = nn.functional.interpolate(
+            predicted_depth,
+            (int(patch_height * self.patch_size), int(patch_width * self.patch_size)),
+            mode="bilinear",
+            align_corners=True,
+        )
+        predicted_depth = self.conv2(predicted_depth)
+        predicted_depth = self.activation1(predicted_depth)
+        predicted_depth = self.conv3(predicted_depth)
+        predicted_depth = self.activation2(predicted_depth)
+        predicted_depth = predicted_depth.squeeze(dim=1)  # shape (batch_size, height, width)
+
+        return predicted_depth
+
+
+class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
+    def __init__(self, config: DepthAnythingConfig):
+        super().__init__(config)
+
+        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.neck = DepthAnythingNeck(config)
+        self.head = DepthAnythingDepthEstimationHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> tuple[torch.Tensor, ...] | DepthEstimatorOutput:
+        r"""
+        Forward pass for Depth Anything.
+
+        Args:
+            pixel_values (torch.FloatTensor): input images.
+            labels (Optional[torch.LongTensor]): labels for the loss. Defaults to None.
+            output_attentions (Optional[bool]): whether to return attentions. Defaults to None.
+            output_hidden_states (Optional[bool]): whether to return hidden_states. Defaults to None.
+            return_dict (Optional[bool]): whether to return dict. Defaults to None.
+
+        Returns:
+            tuple[torch.Tensor, ...] | DepthEstimatorOutput: forward output.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        outputs = self.backbone.forward_with_filtered_kwargs(
+            pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
+        )
+        hidden_states = outputs.feature_maps
+
+        _, _, height, width = pixel_values.shape
+        patch_size = self.config.patch_size
+        patch_height = height // patch_size
+        patch_width = width // patch_size
+
+        hidden_states = self.neck(hidden_states, patch_height, patch_width)
+
+        predicted_depth = self.head(hidden_states, patch_height, patch_width)
+
+        loss = None
+        if labels is not None:
+            raise NotImplementedError("Training is not implemented yet")
+
+        if not return_dict:
+            if output_hidden_states:
+                output = (predicted_depth,) + outputs[1:]
+            else:
+                output = (predicted_depth,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output  # noqa
+
+        return DepthEstimatorOutput(
+            loss=loss,
+            predicted_depth=predicted_depth,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
+
+
+class DepthAnythingNeckFeature(DepthAnythingForDepthEstimation):
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.Tensor:
+        """Forward pass for Depth Anything with only neck feature returned.
+
+        Args:
+            pixel_values (torch.FloatTensor): input images.
+            labels (Optional[torch.LongTensor]): labels for the loss. Defaults to None.
+            output_attentions (Optional[bool]): whether to return attentions. Defaults to None.
+            output_hidden_states (Optional[bool]): whether to return hidden_states. Defaults to None.
+            return_dict (Optional[bool]): whether to return dict. Defaults to None.
+
+        Returns:
+            torch.Tensor: neck feature.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        outputs = self.backbone.forward_with_filtered_kwargs(
+            pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
+        )
+        hidden_states = outputs.feature_maps
+
+        _, _, height, width = pixel_values.shape
+        patch_size = self.config.patch_size
+        patch_height = height // patch_size
+        patch_width = width // patch_size
+
+        hidden_states = self.neck(hidden_states, patch_height, patch_width)
+
+        return hidden_states
+
+
+class DepthAnythingHeadFeature(DepthAnythingForDepthEstimation):
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.Tensor:
+        """Forward pass for Depth Anything with only last layer (head) feature returned.
+
+        Args:
+            pixel_values (torch.FloatTensor): input images.
+            labels (Optional[torch.LongTensor]): labels for the loss. Defaults to None.
+            output_attentions (Optional[bool]): whether to return attentions. Defaults to None.
+            output_hidden_states (Optional[bool]): whether to return hidden_states. Defaults to None.
+            return_dict (Optional[bool]): whether to return dict. Defaults to None.
+
+        Returns:
+            torch.Tensor: last layer (head) feature
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        outputs = self.backbone.forward_with_filtered_kwargs(
+            pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
+        )
+        hidden_states = outputs.feature_maps
+
+        _, _, height, width = pixel_values.shape
+        patch_size = self.config.patch_size
+        patch_height = height // patch_size
+        patch_width = width // patch_size
+
+        hidden_states = self.neck(hidden_states, patch_height, patch_width)
+
+        hidden_states = hidden_states[-1]
+
+        head_feature = self.head.conv1(hidden_states)
+        head_feature = nn.functional.interpolate(
+            head_feature,
+            (int(patch_height * patch_size), int(patch_width * patch_size)),
+            mode="bilinear",
+            align_corners=True,
+        )
+        head_feature = self.head.conv2(head_feature)
+
+        return head_feature
+
+
+def get_depth_anything_feature(
+    model: DepthAnythingForDepthEstimation,
+    processor: AutoImageProcessor,
+    images: list[npt.NDArray],
+    requires_grad: Optional[bool] = False,
+) -> torch.Tensor | list[torch.Tensor]:
+    """Get feature (after neck) from depth anything model.
+
+    Args:
+        model (DepthAnythingForDepthEstimation): Depth Anything model (neck or head feature variant).
+        processor (AutoImageProcessor): Depth Anything processor.
+        images (list[npt.NDArray]): images to extract feature.
+        requires_grad (Optional[bool], optional): whether to keep gradient. Defaults to False.
+
+    Returns:
+        torch.Tensor | list[torch.Tensor]: feature(s) from the Depth Anything model.
+    """
+    inputs = processor(images, return_tensors="pt").to(model.device)
+    if requires_grad:
+        outputs = model(**inputs)
+    else:
+        with torch.no_grad():
+            outputs = model(**inputs)
+            # if neck
+            # [torch.Size([1, D, 37, 49]), torch.Size([1, D, 74, 98]),
+            # torch.Size([1, D, 148, 196]), torch.Size([1, D, 296, 392])]
+            # D = 64, 128, 256 for small, base, large
+            # if head
+            # torch.Size([1, 32, 518, 686])
+    return outputs
+
+
+def get_depth_anything_model(
+    model_name: Optional[str] = "LiheYoung/depth-anything-large-hf",
+    device: Optional[str | torch.device] = "cuda",
+    selected_feature: Optional[str] = "neck",
+) -> tuple[DepthAnythingForDepthEstimation, AutoImageProcessor]:
+    """Get depth anything model.
+
+    Args:
+        model_name (Optional[str]): name of the model. Defaults to "LiheYoung/depth-anything-large-hf".
+        device (Optional[str | torch.device]): device to put model on. Defaults to "cuda".
+        selected_feature (Optional[str]): which feature to return, "neck" or "head". Defaults to "neck".
+
+    Returns:
+        tuple[DepthAnythingForDepthEstimation, AutoImageProcessor]: Depth Anything model and the processor.
+    """
+    processor = AutoImageProcessor.from_pretrained(model_name)
+    if selected_feature == "neck":
+        model = DepthAnythingNeckFeature.from_pretrained(model_name).to(device)
+    elif selected_feature == "head":
+        model = DepthAnythingHeadFeature.from_pretrained(model_name).to(device)
+    else:
+        raise ValueError(f"{selected_feature} is not supported for Depth Anything")
+    return model, processor
+
+
+def print_feature_size(
+    model_name: Optional[str] = "LiheYoung/depth-anything-large-hf", selected_feature: Optional[str] = "neck"
+) -> None:
+    """Print the size of the feature from Depth Anything.
+
+    Args:
+        model_name (Optional[str]): the name of Depth Anything model.
+            Defaults to "LiheYoung/depth-anything-large-hf".
+        selected_feature (Optional[str]): which feature to print, "neck" or "head". Defaults to "neck".
+    """
+    import requests
+    from PIL import Image
+
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    image = [Image.open(requests.get(url, stream=True).raw)]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model, processor = get_depth_anything_model(
+        model_name=model_name, device=device, selected_feature=selected_feature
+    )
+
+    with torch.no_grad():
+        embedding = get_depth_anything_feature(model, processor, image)
+
+    print([x.size() for x in embedding] if isinstance(embedding, list) else embedding.size())
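+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (assumes Hub access; the default large checkpoint is a sizable download);
+    # prints the neck feature sizes and, separately, the head feature size.
+    print_feature_size(selected_feature="neck")
+    print_feature_size(selected_feature="head")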
diff --git a/theia/foundation_models/vision_models/dinov2.py b/theia/foundation_models/vision_models/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa8b0a76a08b39bf8327bd5da5da3b729b9e9f2b
--- /dev/null
+++ b/theia/foundation_models/vision_models/dinov2.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import numpy as np
+import torch
+from transformers import AutoImageProcessor, Dinov2Model
+
+
+def get_dinov2_feature(
+    model: Dinov2Model, processor: AutoImageProcessor, images: list[np.ndarray], requires_grad: bool = False
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Get DINOv2 features.
+
+    Args:
+        model (Dinov2Model): DINOv2 model.
+        processor (AutoImageProcessor): DINOv2 input processor.
+        images (list[np.ndarray]): images to be encoded, in RGB, uint8.
+        requires_grad (bool): maintains gradient. Defaults to False.
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: (
+            cls_token:      last layer embedding from cls token # (1, 1, 1024) if dinov2-large,
+            visual_tokens:  last layer embeddings from image # (1, 1024, 16, 16) BCHW if dinov2-large,
+            pooled_cls_token: last layer embedding from cls + layernorm # (1, 1, 1024) if dinov2-large
+        )
+    """
+    inputs = processor(images, return_tensors="pt").to(model.device)
+    if requires_grad:
+        outputs = model(**inputs)
+    else:
+        with torch.no_grad():
+            outputs = model(**inputs)
+    cls_token = outputs.last_hidden_state[:, :1]  # (1, 1, 1024) if vit-large
+    visual_tokens = outputs.last_hidden_state[:, 1:]  # (1, 256, 1024) if vit-large
+    pooled_cls_token = outputs.pooler_output.unsqueeze(1)  # (1, 1, 1024) if vit-large
+    batch_size, num_patches, num_channels = visual_tokens.size()
+    visual_tokens = visual_tokens.transpose(1, 2)
+    visual_tokens = visual_tokens.reshape(
+        batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches))
+    )  # (1, 1024, 16, 16) BCHW for dinov2-large
+    return cls_token, visual_tokens, pooled_cls_token
+
+
+def get_dinov2_model(
+    model_name: str = "facebook/dinov2-large", device: str | torch.device = "cuda"
+) -> tuple[Dinov2Model, AutoImageProcessor]:
+    """Get DINOv2 model and its input processor.
+
+    Args:
+        model_name (str, optional): name of DINOv2 model. Defaults to "facebook/dinov2-large".
+        device (str | torch.device, optional): device to put the model on. Defaults to "cuda".
+
+    Returns:
+        tuple[Dinov2Model, AutoImageProcessor]: DINOv2 model and the corresponding input processor
+    """
+    processor = AutoImageProcessor.from_pretrained(model_name)
+    model = Dinov2Model.from_pretrained(model_name).to(device)
+    return model, processor
+
+
+def print_feature_size(model_name: str = "facebook/dinov2-large") -> None:
+    """Print the sizes of features from DINOv2.
+
+    Args:
+        model_name (str, optional): the name of DINOv2. Defaults to "facebook/dinov2-large".
+    """
+    from datasets import load_dataset
+
+    dataset = load_dataset("huggingface/cats-image")
+    image = dataset["test"]["image"][0]
+    image = [np.array(image)]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model, processor = get_dinov2_model(model_name=model_name, device=device)
+    cls_token, visual_tokens, pooled_cls_token = get_dinov2_feature(model, processor, image)
+    print(cls_token.size(), visual_tokens.size(), pooled_cls_token.size())
+    # (1, 1, 1024), (1, 1024, 16, 16), (1, 1, 1024) for dinov2-large
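+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (assumes Hub access for the checkpoint and the cats-image dataset);
+    # prints the DINOv2 cls-token, patch-feature, and pooled-token sizes.
+    print_feature_size()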
diff --git a/theia/foundation_models/vision_models/sam.py b/theia/foundation_models/vision_models/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0a3df3ea0800c6c9cd264a7346d137ab449091
--- /dev/null
+++ b/theia/foundation_models/vision_models/sam.py
@@ -0,0 +1,393 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import SamConfig, SamModel, SamProcessor
+from transformers.models.sam.modeling_sam import SamMaskDecoder, SamMaskDecoderConfig
+from transformers.utils import ModelOutput
+
+
+class SamMaskDecoderWithFeature(SamMaskDecoder):
+    """Mask decoder with upscaled feature exposed. Borrowed from transformers."""
+
+    def __init__(self, config: SamMaskDecoderConfig):
+        super().__init__(config)
+
+    # borrowed from huggingface transformers
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_positional_embeddings: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        output_attentions: Optional[bool] = None,
+        attention_similarity: Optional[torch.Tensor] = None,
+        target_embedding: Optional[torch.Tensor] = None,
+    ) -> Any:
+        """Predict masks given image and prompt embeddings."""
+        batch_size, num_channels, height, width = image_embeddings.shape
+        point_batch_size = sparse_prompt_embeddings.shape[1]
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+        if sparse_prompt_embeddings.sum().item() != 0:
+            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+        else:
+            tokens = output_tokens
+        point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+        # Expand per-image data in batch direction to be per-point
+        image_embeddings = image_embeddings + dense_prompt_embeddings
+        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+
+        # Run the transformer, image_positional_embedding are consumed
+        point_embedding, image_embeddings, attentions = self.transformer(
+            point_embeddings=point_embeddings,
+            image_embeddings=image_embeddings,
+            image_positional_embeddings=image_positional_embeddings,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
+            output_attentions=output_attentions,
+        )
+        iou_token_out = point_embedding[:, :, 0, :]
+        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        image_embeddings = image_embeddings.transpose(2, 3).reshape(
+            batch_size * point_batch_size, num_channels, height, width
+        )
+
+        upscaled_embedding = self.upscale_conv1(image_embeddings)
+        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
+
+        hyper_in_list = []
+        for i in range(self.num_mask_tokens):
+            current_mlp = self.output_hypernetworks_mlps[i]
+            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+        hyper_in = torch.stack(hyper_in_list, dim=2)
+
+        _, num_channels, height, width = upscaled_embedding.shape
+        upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
+        masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            mask_slice = slice(1, None)
+        else:
+            mask_slice = slice(0, 1)
+        masks = masks[:, :, mask_slice, :, :]
+        iou_pred = iou_pred[:, :, mask_slice]
+
+        outputs: tuple[Any, ...] = (masks, iou_pred)
+
+        if output_attentions:
+            outputs = (*outputs, attentions)
+        else:
+            outputs = (*outputs, None)
+
+        outputs = (*outputs, upscaled_embedding.reshape(batch_size * point_batch_size, num_channels, height, width))
+        return outputs
+
+
+@dataclass
+class SamImageSegmentationWithFeatureOutput(ModelOutput):
+    """Sam segmentation output plus features."""
+
+    iou_scores: torch.FloatTensor = None
+    pred_masks: torch.FloatTensor = None
+    vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    vision_attentions: Optional[tuple[torch.FloatTensor]] = None
+    mask_decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+    image_embeddings: Optional[tuple[torch.FloatTensor]] = None
+    upscaled_image_embeddings: Optional[tuple[torch.FloatTensor]] = None
+
+
+class SamModelWithFeature(SamModel):
+    """SAM model with feature exposed. Borrowed from transformers."""
+
+    def __init__(self, config: SamConfig):
+        super().__init__(config)
+        self.mask_decoder = SamMaskDecoderWithFeature(config.mask_decoder_config)
+        self.post_init()
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        input_points: Optional[torch.FloatTensor] = None,
+        input_labels: Optional[torch.LongTensor] = None,
+        input_boxes: Optional[torch.FloatTensor] = None,
+        input_masks: Optional[torch.LongTensor] = None,
+        image_embeddings: Optional[torch.FloatTensor] = None,
+        multimask_output: bool = True,
+        attention_similarity: Optional[torch.FloatTensor] = None,
+        target_embedding: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Optional[dict[str, Any]],
+    ) -> tuple | SamImageSegmentationWithFeatureOutput:
+        """Sam forward pass with feature returned"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None and image_embeddings is None:
+            raise ValueError("Either pixel_values or image_embeddings must be provided.")
+
+        if pixel_values is not None and image_embeddings is not None:
+            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
+
+        if input_points is not None and len(input_points.shape) != 4:
+            raise ValueError(
+                "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`,"
+                " `nb_points_per_image`, `2`.",
+                " got {}.".format(input_points.shape),
+            )
+        if input_boxes is not None and len(input_boxes.shape) != 3:
+            raise ValueError(
+                "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
+                " got {}.".format(input_boxes.shape),
+            )
+        if input_points is not None and input_boxes is not None:
+            point_batch_size = input_points.shape[1]
+            box_batch_size = input_boxes.shape[1]
+            if point_batch_size != box_batch_size:
+                raise ValueError(
+                    "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
+                        point_batch_size, box_batch_size
+                    )
+                )
+
+        image_positional_embeddings = self.get_image_wide_positional_embeddings()
+        # repeat with batch size
+        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]  # type: ignore [union-attr]
+        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+        vision_attentions = None
+        vision_hidden_states = None
+
+        if pixel_values is not None:
+            vision_outputs = self.vision_encoder(
+                pixel_values,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            image_embeddings = vision_outputs[0]
+
+            if output_hidden_states:
+                vision_hidden_states = vision_outputs[1]
+            if output_attentions:
+                vision_attentions = vision_outputs[-1]
+
+        if input_points is not None and input_labels is None:
+            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:  # type: ignore [union-attr]
+            raise ValueError(
+                "The batch size of the image embeddings and the input points must be the same. ",
+                "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),  # type: ignore [union-attr]
+                " if you want to pass multiple points for the same image, make sure that you passed ",
+                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
+                " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
+            )
+
+        sparse_embeddings, dense_embeddings = self.prompt_encoder(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            input_masks=input_masks,
+        )
+
+        low_res_masks, iou_predictions, mask_decoder_attentions, upscaled_image_embeddings = self.mask_decoder(
+            image_embeddings=image_embeddings,
+            image_positional_embeddings=image_positional_embeddings,
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
+            output_attentions=output_attentions,
+        )
+
+        if not return_dict:
+            output: tuple[Any, ...] = (iou_predictions, low_res_masks)
+            if output_hidden_states:
+                output = (*output, vision_hidden_states)
+            if output_attentions:
+                output = (*output, vision_attentions, mask_decoder_attentions)
+
+            output = (*output,)
+            return output
+
+        return SamImageSegmentationWithFeatureOutput(
+            iou_scores=iou_predictions,
+            pred_masks=low_res_masks,
+            vision_hidden_states=vision_hidden_states,
+            vision_attentions=vision_attentions,
+            mask_decoder_attentions=mask_decoder_attentions,
+            image_embeddings=image_embeddings,
+            upscaled_image_embeddings=upscaled_image_embeddings,
+        )
+
+
+class SamModelVisionFeature(SamModel):
+    """Sam with only feature from the vision backbone. Borrowed from transformers."""
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        input_points: Optional[torch.FloatTensor] = None,
+        input_labels: Optional[torch.LongTensor] = None,
+        input_boxes: Optional[torch.FloatTensor] = None,
+        input_masks: Optional[torch.LongTensor] = None,
+        image_embeddings: Optional[torch.FloatTensor] = None,
+        multimask_output: bool = True,
+        attention_similarity: Optional[torch.FloatTensor] = None,
+        target_embedding: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Optional[dict[str, Any]],
+    ) -> SamImageSegmentationWithFeatureOutput:
+        """Sam forward pass that only goes through vision backbone and returns visual feature."""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None and image_embeddings is None:
+            raise ValueError("Either pixel_values or image_embeddings must be provided.")
+
+        if pixel_values is not None and image_embeddings is not None:
+            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
+
+        if input_points is not None and len(input_points.shape) != 4:
+            raise ValueError(
+                "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`,"
+                " `nb_points_per_image`, `2`.",
+                " got {}.".format(input_points.shape),
+            )
+        if input_boxes is not None and len(input_boxes.shape) != 3:
+            raise ValueError(
+                "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
+                " got {}.".format(input_boxes.shape),
+            )
+        if input_points is not None and input_boxes is not None:
+            point_batch_size = input_points.shape[1]
+            box_batch_size = input_boxes.shape[1]
+            if point_batch_size != box_batch_size:
+                raise ValueError(
+                    "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
+                        point_batch_size, box_batch_size
+                    )
+                )
+
+        image_positional_embeddings = self.get_image_wide_positional_embeddings()
+        # repeat with batch size
+        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]  # type: ignore [union-attr]
+        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+        vision_attentions = None
+        vision_hidden_states = None
+
+        if pixel_values is not None:
+            vision_outputs = self.vision_encoder(
+                pixel_values,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            image_embeddings = vision_outputs[0]
+
+            if output_hidden_states:
+                vision_hidden_states = vision_outputs[1]
+            if output_attentions:
+                vision_attentions = vision_outputs[-1]
+
+        return SamImageSegmentationWithFeatureOutput(
+            vision_hidden_states=vision_hidden_states,
+            vision_attentions=vision_attentions,
+            image_embeddings=image_embeddings,
+        )
+
+
+def get_sam_feature(
+    model: SamModel, processor: SamProcessor, images: list[np.ndarray], requires_grad: bool = False
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Get features from SAM.
+
+    Args:
+        model (SamModel): SAM model.
+        processor (SamProcessor): SAM input processor.
+        images (list[np.ndarray]): images to be encoded, in RGB, uint8.
+        requires_grad (bool): maintains gradient. Defaults to False.
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor]: (
+            image_embeddings: feature from SAM visual encoder # (1, 256, 64, 64) BCHW for vit-huge
+            upscaled_image_embeddings: features from mask decoder # (1, 32, 256, 256)
+        )
+    """
+    inputs = processor(images, return_tensors="pt").to(model.device)
+    if requires_grad:
+        outputs = model(**inputs)
+    else:
+        with torch.no_grad():
+            outputs = model(**inputs)
+    return (outputs.image_embeddings, outputs.upscaled_image_embeddings)
+
+
+def get_sam_model(
+    model_name: str = "facebook/sam-vit-huge", device: str | torch.device = "cuda", with_upscaled: bool = False
+) -> tuple[SamModelWithFeature, SamProcessor]:
+    """Get sam model and its input processor.
+
+    Args:
+        model_name (str, optional): name of SAM model. Defaults to "facebook/sam-vit-huge".
+        device (str | torch.device, optional): device to put the model on. Defaults to "cuda".
+        with_upscaled (bool, optional): if return upscaled features. Defaults to False.
+
+    Returns:
+        tuple[SamModelWithFeature, SamProcessor]: SAM and its corresponding input processor
+    """
+    if with_upscaled:
+        model = SamModelWithFeature.from_pretrained(model_name).to(device)
+    else:
+        model = SamModelVisionFeature.from_pretrained(model_name).to(device)
+    processor = SamProcessor.from_pretrained(model_name)
+    return model, processor
+
+
+def print_feature_size(model_name: str = "facebook/sam-vit-huge") -> None:
+    """Print the size of features from SAM.
+
+    Args:
+        model_name (str, optional): the name of SAM model. Defaults to "facebook/sam-vit-huge".
+    """
+    import requests
+
+    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
+    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+    image_array = [np.array(raw_image)]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model, processor = get_sam_model(model_name=model_name, device=device)
+    image_embeddings, upscaled_embeddings = get_sam_feature(model, processor, image_array)
+
+    print(image_embeddings.size(), upscaled_embeddings.size() if upscaled_embeddings is not None else None)
+    # (1, 256, 64, 64) and (1, 32, 256, 256) for vit-huge
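+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (assumes Hub access; the default vit-huge checkpoint is large);
+    # prints the SAM encoder feature size (and the upscaled decoder feature when enabled).
+    print_feature_size()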
diff --git a/theia/foundation_models/vision_models/vit.py b/theia/foundation_models/vision_models/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..d240e95e677a261f429de90f9ef400d935848d32
--- /dev/null
+++ b/theia/foundation_models/vision_models/vit.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import numpy as np
+import torch
+from transformers import AutoImageProcessor, ViTModel
+
+
+def get_vit_feature(
+    model: ViTModel, processor: AutoImageProcessor, images: list[np.ndarray], requires_grad: bool = False
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Get feature from ViT model.
+
+    Args:
+        model (ViTModel): ViT model.
+        processor (AutoImageProcessor): ViT input processor.
+        images (list[np.ndarray]): images to be encoded.
+        requires_grad (bool): maintains gradient. Defaults to False.
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor]: (
+            cls_token:         cls token embedding from the last layer  # (1, 1280) for vit-huge
+            last_hidden_state: patch embeddings from the last layer  # (1, 1280, 16, 16) BCHW for vit-huge
+        )
+    """
+    inputs = processor(images, return_tensors="pt").to(model.device)
+    if requires_grad:
+        outputs = model(**inputs)
+    else:
+        with torch.no_grad():
+            outputs = model(**inputs)
+    cls_token, last_hidden_state = outputs.last_hidden_state[:, 0], outputs.last_hidden_state[:, 1:]
+    batch_size, num_patches, num_channels = last_hidden_state.size()
+    last_hidden_state = last_hidden_state.transpose(1, 2)
+    last_hidden_state = last_hidden_state.reshape(
+        batch_size, num_channels, int(np.sqrt(num_patches)), int(np.sqrt(num_patches))
+    )
+    return cls_token, last_hidden_state  # (1, 1280, 16, 16) BCHW for vit-huge
+
+
+def get_vit_model(
+    model_name: str = "google/vit-huge-patch14-224-in21k", device: str | torch.device = "cuda"
+) -> tuple[ViTModel, AutoImageProcessor]:
+    """Get ViT model and its corresponding input processor.
+
+    Args:
+        model_name (str, optional): the name of vit model. Defaults to "google/vit-huge-patch14-224-in21k".
+        device (str | torch.device, optional): device to put model on. Defaults to "cuda".
+
+    Returns:
+        tuple[ViTModel, AutoImageProcessor]: ViT model and its corresponding input processor.
+    """
+    processor = AutoImageProcessor.from_pretrained(model_name)
+    model = ViTModel.from_pretrained(model_name).to(device)
+    return model, processor
+
+
+def print_feature_size(model_name: str = "google/vit-huge-patch14-224-in21k") -> None:
+    """Print the size of the feature from ViT.
+
+    Args:
+        model_name (str, optional): the name of ViT model. Defaults to "google/vit-huge-patch14-224-in21k".
+    """
+    from datasets import load_dataset
+
+    dataset = load_dataset("huggingface/cats-image")
+    image = dataset["test"]["image"][0]
+    image = [np.array(image)]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model, processor = get_vit_model(model_name=model_name, device=device)
+    cls_token, feature = get_vit_feature(model, processor, image)
+    print(cls_token.size(), feature.size())
+    # cls (1, 1280)
+    # feature (1, 1280, 16, 16) BCHW for vit-huge
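+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (assumes Hub access for the checkpoint and the cats-image dataset);
+    # prints the ViT cls-token and patch-feature sizes.
+    print_feature_size()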
diff --git a/theia/lr_schedulers/__init__.py b/theia/lr_schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c85cbe8c221af7fa03172959d23da67d622af727
--- /dev/null
+++ b/theia/lr_schedulers/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .lr_schedulers import get_cos_lrs_with_linear_warm_up, get_constant_lrs_with_linear_warm_up
diff --git a/theia/lr_schedulers/lr_schedulers.py b/theia/lr_schedulers/lr_schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f3431ffa3a5061db4af2bdac9bf7a66ee2e054a
--- /dev/null
+++ b/theia/lr_schedulers/lr_schedulers.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Any
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import ConstantLR, CosineAnnealingWarmRestarts, LinearLR, SequentialLR
+
+
+def get_cos_lrs_with_linear_warm_up(
+    optimizer: Optimizer,
+    warm_up_steps: int = 2000,
+    warm_up_lr_start_factor: float = 1e-2,
+    warm_up_lr_end_factor: float = 1.0,
+    cos_lrs_T_0: int = 5000,
+) -> SequentialLR:
+    """Get a cos annealing warm restarts lr scheduler with linear warm up at the beginning.
+
+    Args:
+        optimizer (Optimizer): original optimizer to be scheduled.
+        warm_up_steps (int): number of warm up steps. Defaults to 2000.
+        warm_up_lr_start_factor (float): start factor of the linear scheduler. Defaults to 1e-2.
+        warm_up_lr_end_factor (float): end factor of the linear scheduler. Defaults to 1.0.
+        cos_lrs_T_0 (int): T_0 param of cos lrs. Number of steps per cycle. Defaults to 5000.
+
+    Returns:
+        SequentialLR: a sequential lrs that combines linear and cos to implement warm up.
+    """
+    linear_lrs = LinearLR(
+        optimizer=optimizer,
+        start_factor=warm_up_lr_start_factor,
+        end_factor=warm_up_lr_end_factor,
+        total_iters=warm_up_steps,
+    )
+
+    cos_lrs = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=cos_lrs_T_0, T_mult=1)
+
+    lrs = SequentialLR(optimizer=optimizer, schedulers=[linear_lrs, cos_lrs], milestones=[warm_up_steps])
+
+    return lrs
+
+
+def get_constant_lrs_with_linear_warm_up(
+    optimizer: Optimizer,
+    warm_up_steps: int = 2000,
+    warm_up_lr_start_factor: float = 1e-2,
+    warm_up_lr_end_factor: float = 1.0,
+    **kwargs: Any,
+) -> SequentialLR:
+    """Get a constant lr scheduler with linear warm up at the beginning.
+
+    Args:
+        optimizer (Optimizer): original optimizer to be scheduled.
+        warm_up_steps (int): number of warm up steps. Defaults to 2000.
+        warm_up_lr_start_factor (float): start factor of the linear scheduler. Defaults to 1e-2.
+        warm_up_lr_end_factor (float): end factor of the linear scheduler. Defaults to 1.0.
+
+    Returns:
+        SequentialLR: a sequential lrs that combines linear and constant lrs to implement warm up.
+    """
+    linear_lrs = LinearLR(
+        optimizer=optimizer,
+        start_factor=warm_up_lr_start_factor,
+        end_factor=warm_up_lr_end_factor,
+        total_iters=warm_up_steps,
+    )
+
+    constant_lrs = ConstantLR(optimizer=optimizer, factor=1.0)
+
+    lrs = SequentialLR(optimizer=optimizer, schedulers=[linear_lrs, constant_lrs], milestones=[warm_up_steps])
+
+    return lrs
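+
+
+if __name__ == "__main__":
+    # Minimal sketch of the warm-up behaviour on a throwaway parameter (illustrative values,
+    # not tuned): the lr ramps linearly from lr * warm_up_lr_start_factor to lr over
+    # `warm_up_steps`, then follows cosine annealing with warm restarts.
+    import torch
+
+    param = torch.nn.Parameter(torch.zeros(1))
+    optimizer = torch.optim.AdamW([param], lr=1e-3)
+    lrs = get_cos_lrs_with_linear_warm_up(optimizer, warm_up_steps=10, cos_lrs_T_0=50)
+    for step in range(30):
+        optimizer.step()
+        lrs.step()
+        if step % 10 == 0:
+            print(step, lrs.get_last_lr())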
diff --git a/theia/models/__init__.py b/theia/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/models/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/models/activations.py b/theia/models/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..641776b99afd5b991ec47cf9785d527e8e85c2ae
--- /dev/null
+++ b/theia/models/activations.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import torch.nn as nn
+
+
+def get_activation_fn(activation: str) -> nn.Module:
+    """Return specified activation function.
+
+    Args:
+        activation (str): the name of the activation function.
+
+    Returns:
+        nn.Module: the activation function in nn.Module.
+    """
+    if activation == "relu":
+        return nn.ReLU()
+    elif activation == "gelu":
+        return nn.GELU()
+    elif activation == "tanh":
+        return nn.Tanh()
+    elif activation == "leaky_relu":
+        return nn.LeakyReLU()
+    else:
+        raise ValueError(f"{activation} is not defined in theia/models/activations.py:get_activation_fn()")
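+
+
+if __name__ == "__main__":
+    # Minimal sketch: each supported name maps to the corresponding torch.nn activation module.
+    for name in ("relu", "gelu", "tanh", "leaky_relu"):
+        print(name, "->", get_activation_fn(name))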
diff --git a/theia/models/adapter_heads.py b/theia/models/adapter_heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8a12e2ec936bcb6125319acfec85799aeaf0ef0
--- /dev/null
+++ b/theia/models/adapter_heads.py
@@ -0,0 +1,359 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from einops.layers.torch import Rearrange
+from torch.nn.functional import interpolate
+
+
+class Interpolation(nn.Module):
+    """Interpolation nn.Module wrap for nn.functional.interpolate.
+
+    Attributes:
+        target_size (tuple[int, int] | torch.Size): target spatial size of this interpolation.
+    """
+
+    def __init__(self, target_size: tuple[int, int] | torch.Size) -> None:
+        super().__init__()
+        self.target_size = target_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Very simple forward pass to call interpolate()."""
+        return interpolate(x, self.target_size)
+
+
+class LinearAdapterHead(nn.Module):
+    """Adapter head contains a single linear layer."""
+    def __init__(
+        self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size
+    ):
+        """Initialization function for LinearAdapterHead.
+        Args:
+            source_size (tuple[int, ...] | torch.Size): the size of the source feature.
+            target_size (tuple[int, ...] | torch.Size): the size of the target feature.
+            num_layer (int): number of MLP layers (One linear layer if num_layer = 1).
+        """
+        super().__init__()
+
+        self.source_size = source_size
+        self.target_size = target_size
+
+        source_channel_size = self.source_size[0]
+        target_channel_size = self.target_size[0]
+
+        self.adapter = nn.Sequential(
+            nn.Linear(source_channel_size, target_channel_size),
+        )
+
+    def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
+        """Forward pass for the adapter. """
+        assert backbone_no_cls == False
+        # x: [B, (1+H*W), C]
+        # LinearAdapterHead is used only when there is cls token in the backbone.
+        x = x[:, 0]
+        x = self.adapter(x)
+        return x  # [B, (H*W), C]
+
+
+class MLPAdapterHead(nn.Module):
+    """MLP Adapter module.
+
+    Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
+    Will first do interpolation to match the spatial size [H_t, W_t],
+    followed by MLP to project to the target channel dimension [C_t].
+
+    Attributes:
+        source_size (tuple[int, ...] | torch.Size): the size of the source feature. [C, H, W]
+        target_size (tuple[int, ...] | torch.Size): the size of the target feature. [C, H, W]
+        adapter     (nn.Module):                    the adapter module.
+        interpolation (nn.Module):                  interpolation to adjust sizes before MLP.
+    """
+
+    def __init__(
+        self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size, num_layer: int
+    ):
+        """Initialization function for MLPAdapter.
+
+        Args:
+            source_size (tuple[int, ...] | torch.Size): the size of the source feature.
+            target_size (tuple[int, ...] | torch.Size): the size of the target feature.
+            num_layer (int): number of MLP layers (One linear layer if num_layer = 1).
+        """
+        super().__init__()
+        assert num_layer >= 1, f"`num_layer` in {self._get_name()} should >= 1. Got {num_layer}"
+
+        self.source_size = source_size
+        self.target_size = target_size
+
+        source_channel_size = self.source_size[0]
+        target_channel_size = self.target_size[0]
+
+        self.interpolation = nn.Sequential(
+            nn.Identity(),
+        )
+        if self.source_size[1] != self.target_size[1]:
+            self.interpolation = nn.Sequential(
+                Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
+                Interpolation(self.target_size[1:]),
+                Rearrange("b c h w-> b (h w) c"),
+            )
+
+        if num_layer == 1:
+            self.adapter = nn.Sequential(
+                nn.Linear(source_channel_size, target_channel_size),
+            )
+        elif num_layer >= 2:
+            hidden_dim = source_channel_size * 2
+            self.adapter = nn.Sequential(
+                nn.Linear(source_channel_size, hidden_dim),
+                *list(
+                    chain.from_iterable([[nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)] for _ in range(num_layer - 2)])
+                ),
+                nn.ReLU(),
+                nn.Linear(hidden_dim, target_channel_size),
+            )
+
+    def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
+        """Forward pass for the adapter. First interpolation then MLP."""
+        # x: [B, (1)+H*W, C]
+        if not backbone_no_cls:
+            x = x[:, 1:]
+        # x: [B, (H*W), C]
+        x = self.interpolation(x)
+        x = self.adapter(x)
+        return x  # [B, (H*W), C]
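+
+
+# Shape sketch (hypothetical sizes): map a 384-dim, 14x14 backbone feature with
+# a CLS token to a 768-dim, 16x16 target feature.
+#
+#     >>> head = MLPAdapterHead(source_size=(384, 14, 14), target_size=(768, 16, 16), num_layer=2)
+#     >>> x = torch.randn(2, 1 + 14 * 14, 384)  # [B, 1+H*W, C]
+#     >>> head(x).shape                         # CLS dropped, 14x14 -> 16x16, 384 -> 768
+#     torch.Size([2, 256, 768])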
+
+
+class ConvAdapterHead(nn.Module):
+    """Convolutional Adapter module.
+
+    Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
+    Uses CNN to map channel and spatial sizes jointly.
+    Note: only works for (16, 16), (any, any) with any <= 14, and (64, 64) spatial sizes for now.
+
+    Attributes:
+        source_size (tuple[int, ...] | torch.Size): the size of the source feature.
+        target_size (tuple[int, ...] | torch.Size): the size of the target feature.
+        adapter     (nn.Module):                    the adapter module.
+        pad         (nn.Module):                    padding / reshaping applied before the conv adapter.
+    """
+
+    def __init__(
+        self,
+        source_size: tuple[int, ...] | torch.Size,
+        target_size: tuple[int, ...] | torch.Size,
+    ):
+        """Initialization function for ConvAdapter.
+
+        Args:
+            source_size (tuple[int, ...] | torch.Size): the size of the source feature.
+            target_size (tuple[int, ...] | torch.Size): the size of the target feature.
+        """
+        super().__init__()
+        self.source_size = source_size
+        self.target_size = target_size
+
+        hidden_dim = self.source_size[0] * 2
+        source_channel_size = self.source_size[0]
+        target_channel_size = self.target_size[0]
+
+        if self.source_size[1] < 12:
+            raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.")
+        elif self.source_size[1] < 16:  # pad (any, any), any <= 14 to (16, 16)
+            self.pad = nn.Sequential(
+                Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
+                nn.ConvTranspose2d(
+                    source_channel_size,
+                    source_channel_size,
+                    kernel_size=3,
+                    stride=1,
+                    output_padding=14 - self.source_size[1],
+                ),
+            )
+            self.source_size = (self.source_size[0], 16, 16)
+        elif self.source_size[1] == 16 or self.source_size[1] == 64:  # do nothing for (16, 16) and (64, 64)
+            self.pad = nn.Sequential(
+                Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
+            )
+        else:
+            raise NotImplementedError("feature spatial size (>=16x16) other than 16x16 and 64x64 is not supported.")
+
+        if self.source_size[1] < self.target_size[1]:  # (16, 16) / (14, 14) to (64, 64)
+            self.adapter = nn.Sequential(
+                nn.LayerNorm(self.source_size),
+                nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1),  # 31
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 31, 31]),
+                nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1),  # 64
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 64, 64]),
+                nn.ConvTranspose2d(hidden_dim, target_channel_size, kernel_size=3, stride=1, padding=1),  # 64
+                Rearrange("b c h w-> b (h w) c"),
+            )
+        elif self.source_size[1] == self.target_size[1]:  # (16, 16) to (16, 16)
+            self.adapter = nn.Sequential(
+                nn.LayerNorm(self.source_size),
+                nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1),  # 16
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
+                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),  # 16
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
+                nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1),  # 16
+                Rearrange("b c h w-> b (h w) c"),
+            )
+        else:  # (64, 64) to (16, 16)
+            self.adapter = nn.Sequential(
+                nn.LayerNorm(self.source_size),
+                nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1),  # 32
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 32, 32]),
+                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1),  # 16
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 16, 16]),
+                nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1),  # 16
+                Rearrange("b c h w-> b (h w) c"),
+            )
+
+    def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
+        """Forward pass for ConvAdapter"""
+        # x: [B, (1)+H*W, C]
+        if not backbone_no_cls:
+            x = x[:, 1:]
+        # x: [B, H*W, C]
+        x = self.pad(x)
+        x = self.adapter(x)
+        return x  # [B, (H*W), C]
+
+
+class LightConvAdapterHead(nn.Module):
+    """Light Convolutional Adapter module.
+
+    Transforms features from source size in [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t].
+    Uses CNN to map channel and spatial sizes jointly.
+    Note: only works for source sizes (H_s, W_s): (16, 16) and (any, any) with 12 <= any <= 14,
+        and target sizes (H_t, W_t): (16, 16) and (64, 64) for now.
+
+    Attributes:
+        source_size (tuple[int, ...] | torch.Size): the size of the source feature,
+            channel first (C, H, W).
+        target_size (tuple[int, ...] | torch.Size): the size of the target feature,
+            channel first (C, H, W).
+        adapter     (nn.Module):                    the adapter module.
+        pad         (nn.Module):                    padding / reshaping applied before the conv adapter.
+    """
+
+    def __init__(
+        self,
+        source_size: tuple[int, ...] | torch.Size,
+        target_size: tuple[int, ...] | torch.Size,
+        hidden_size_factor: int | float = 1.0,
+    ):
+        """Initialization function for ConvAdapter.
+
+        Args:
+            source_size (tuple[int, ...] | torch.Size): the size of the source feature.
+            target_size (tuple[int, ...] | torch.Size): the size of the target feature.
+            hidden_size_factor (int | float): the size of hidden dim of feature translator
+                as a factor of input feature hidden dim.
+        """
+        super().__init__()
+        if source_size[1] != source_size[2] or target_size[1] != target_size[2]:
+            raise NotImplementedError(
+                "Currently does not support non-square feature maps like source size"
+                "{source_size} and target size {target_size}."
+            )
+        self.source_size = source_size
+        self.target_size = target_size
+        self.hidden_size_factor = hidden_size_factor
+
+        hidden_dim = int(self.source_size[0] * hidden_size_factor)
+        source_channel_size = self.source_size[0]
+        target_channel_size = self.target_size[0]
+
+        if self.source_size[1] < 12:
+            raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.")
+        elif self.source_size[1] < 16 and self.target_size[1] >= 16:  # pad (any, any), any <= 14 to (16, 16)
+            self.pad = nn.Sequential(
+                Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
+                nn.ConvTranspose2d(
+                    source_channel_size,
+                    source_channel_size,
+                    kernel_size=3,
+                    stride=1,
+                    output_padding=14 - self.source_size[1],
+                ),
+            )
+            self.source_size = (self.source_size[0], 16, 16)
+        elif self.source_size[1] in (16, 64) or (self.source_size[1] == 14 and self.target_size[1] == 14):
+            # no padding for (16, 16), (64, 64), and (14, 14) <-> (14, 14)
+            self.pad = nn.Sequential(
+                Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
+            )
+        elif self.target_size[1] < 14:
+            self.pad = nn.Sequential(
+                Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]),
+            )
+        else:
+            raise NotImplementedError("feature spatial size larger than 16x16 (other than 64x64) is not supported.")
+
+        if self.source_size[1] == 16 and self.target_size[1] == 64:  # (16, 16) to (64, 64)
+            self.adapter = nn.Sequential(
+                nn.LayerNorm(self.source_size),
+                nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1),  # 31
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 31, 31]),
+                nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1),  # 64
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 64, 64]),
+                Rearrange("b c h w-> b (h w) c"),
+                nn.Linear(hidden_dim, target_channel_size),
+            )
+        elif self.source_size[1] == self.target_size[1]:  # (16, 16) to (16, 16)
+            self.adapter = nn.Sequential(
+                nn.LayerNorm(self.source_size),
+                nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1),  # 16
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
+                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),  # 16
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, *self.source_size[1:]]),
+                Rearrange("b c h w-> b (h w) c"),
+                nn.Linear(hidden_dim, target_channel_size),
+            )
+        elif self.source_size[1] == 64 and self.target_size[1] == 16:  # (64, 64) to (16, 16)
+            self.adapter = nn.Sequential(
+                nn.LayerNorm(self.source_size),
+                nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1),  # 32
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 32, 32]),
+                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1),  # 16
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 16, 16]),
+                Rearrange("b c h w-> b (h w) c"),
+                nn.Linear(hidden_dim, target_channel_size),
+            )
+        elif self.target_size[1] == 7:
+            self.adapter = nn.Sequential(
+                nn.LayerNorm(self.source_size),
+                nn.Conv2d(source_channel_size, hidden_dim, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
+                nn.ReLU(),
+                nn.LayerNorm([hidden_dim, 7, 7]),
+                Rearrange("b c h w-> b (h w) c"),
+                nn.Linear(hidden_dim, target_channel_size),
+            )
+        else:
+            raise NotImplementedError(f"{self.source_size} to {self.target_size} is not supported.")
+
+    def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
+        """Forward pass for ConvAdapter"""
+        # x: [B, (1)+H*W, C]
+        if not backbone_no_cls:
+            x = x[:, 1:]
+        x = self.pad(x)
+        x = self.adapter(x)
+        return x  # [B, H*W, C]
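+
+
+# Shape sketch for LightConvAdapterHead (hypothetical sizes): upsample a
+# 384-dim, 16x16 feature (CLS token included) to a 256-dim, 64x64 target.
+#
+#     >>> head = LightConvAdapterHead(source_size=(384, 16, 16), target_size=(256, 64, 64))
+#     >>> x = torch.randn(2, 1 + 16 * 16, 384)  # [B, 1+H*W, C]
+#     >>> head(x).shape
+#     torch.Size([2, 4096, 256])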
diff --git a/theia/models/backbones.py b/theia/models/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4ea20e7cd42f249122b275d3fd1def5bb70b448
--- /dev/null
+++ b/theia/models/backbones.py
@@ -0,0 +1,526 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import math
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModel, AutoProcessor
+from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTModel
+
+
+# Modified from huggingface transformers ViTEmbeddings
+# Original Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+class ViTEmbeddingsNoCLS(ViTEmbeddings):
+    """ViT Embedding Module without CLS token."""
+
+    def __init__(self, config: AutoConfig, use_mask_token: bool = False):
+        """Initialization.
+
+        Args:
+            config (AutoConfig): config for ViT.
+            use_mask_token (bool, optional): whether to use mask token. Defaults to False.
+        """
+        super(ViTEmbeddingsNoCLS, self).__init__(config, use_mask_token=use_mask_token)
+        self.cls_token = None
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be used on
+        higher-resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1]
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return patch_pos_embed
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        batch_size, num_channels, height, width = pixel_values.shape
+        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+        if bool_masked_pos is not None:
+            seq_length = embeddings.shape[1]
+            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        # add positional encoding to each token
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embeddings[:, 1:]
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+# modified from huggingface transformers ViTModel
+class ViTModelNoCLS(ViTModel):
+    """ViT Model without CLS token."""
+
+    def __init__(self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
+        super(ViTModelNoCLS, self).__init__(config, add_pooling_layer, use_mask_token)
+        self.embeddings = ViTEmbeddingsNoCLS(config, use_mask_token=use_mask_token)
+        self.no_cls = True
+
+    def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTEmbeddings):
+            module.position_embeddings.data = nn.init.trunc_normal_(
+                module.position_embeddings.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.position_embeddings.dtype)
+
+
+# modified from huggingface transformers ViTEmbeddings
+class ViTEmbeddingsReg(ViTEmbeddings):
+    """
+    ViT Embedding Module with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1
+    """
+
+    def __init__(self, config: AutoConfig, use_mask_token: bool = False, num_reg_tokens: int = 7):
+        super(ViTEmbeddingsReg, self).__init__(config, use_mask_token=use_mask_token)
+        self.reg_token = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))
+        self.num_reg_tokens = num_reg_tokens
+        self.reg_pos_embed = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))
+
+        self.reg_pos_embed.data = nn.init.trunc_normal_(
+            self.reg_pos_embed.data.to(torch.float32),
+            mean=0.0,
+            std=self.config.initializer_range,
+        ).to(self.reg_pos_embed.dtype)
+
+        self.reg_token.data = nn.init.trunc_normal_(
+            self.reg_token.data.to(torch.float32),
+            mean=0.0,
+            std=self.config.initializer_range,
+        ).to(self.reg_token.dtype)
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be used on
+        higher-resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 1 - self.num_reg_tokens
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        reg_pos_embed = self.reg_pos_embed
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, reg_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        batch_size, num_channels, height, width = pixel_values.shape
+        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+        if bool_masked_pos is not None:
+            seq_length = embeddings.shape[1]
+            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        # add the [CLS] and register tokens to the embedded patch tokens
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        reg_tokens = self.reg_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings, reg_tokens), dim=1)
+
+        # add positional encoding to each token
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + torch.cat([self.position_embeddings, self.reg_pos_embed], dim=1)
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+# modified from huggingface transformers ViTModel
+class ViTModelReg(ViTModel):
+    """ViT Model with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1"""
+
+    def __init__(
+        self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, num_reg_tokens: int = 7
+    ):
+        super(ViTModelReg, self).__init__(config, add_pooling_layer, use_mask_token)
+        self.embeddings = ViTEmbeddingsReg(config, use_mask_token=use_mask_token, num_reg_tokens=num_reg_tokens)
+        self.num_reg_tokens = num_reg_tokens
+
+    def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTEmbeddings):
+            module.position_embeddings.data = nn.init.trunc_normal_(
+                module.position_embeddings.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.position_embeddings.dtype)
+            module.cls_token.data = nn.init.trunc_normal_(
+                module.cls_token.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.cls_token.dtype)
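+
+
+# Token-layout sketch (hypothetical; assumes the default ViTConfig, i.e. ViT-Base
+# with 16x16 patches at 224x224 resolution, giving 196 patch tokens):
+#
+#     >>> from transformers import ViTConfig
+#     >>> vit = ViTModelReg(ViTConfig(), num_reg_tokens=7)
+#     >>> out = vit(pixel_values=torch.zeros(1, 3, 224, 224)).last_hidden_state
+#     >>> out.shape  # [CLS] + 196 patch tokens + 7 register tokens
+#     torch.Size([1, 204, 768])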
+
+
+class DeiT(nn.Module):
+    """DeiT model.
+
+    Paper: Training data-efficient image transformers & distillation through attention
+        https://arxiv.org/abs/2012.12877
+    Huggingface Reference: https://huggingface.co/docs/transformers/en/model_doc/deit
+
+    Attributes:
+        model_name (str): name of the model.
+        pretrained (bool): whether to use pretrained weights.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "facebook/deit-small-patch16-224",
+        pretrained: bool = False,
+        image_size: int = 224,
+    ):
+        super().__init__()
+        self.image_size = image_size
+        model = AutoModel.from_pretrained(model_name)
+        if pretrained:
+            self.model = model
+        else:
+            deit_config = model.config
+            self.model = AutoModel.from_config(deit_config)
+            del model
+
+        self.model.pooler = nn.Identity()
+
+        self.processor = AutoProcessor.from_pretrained(model_name)
+
+    def get_feature_size(
+        self,
+        keep_spatial: bool = False,
+        return_torch_size: bool = False,
+    ) -> torch.Size | tuple[int, ...]:
+        """Get the size of the feature.
+
+        Args:
+            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
+            return_torch_size (bool): if true, return torch.Size type. Defaults to False.
+
+        Returns:
+            torch.Size | tuple[int, ...]: returned feature shape.
+        """
+        with torch.inference_mode():
+            image_size = (224, 224)
+            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
+            y = self.forward(x)[:, 1:]  # for getting feature size, discard cls token
+            size = y.size()[1:][::-1]
+            if keep_spatial:
+                assert math.isqrt(size[-1]) ** 2 == size[-1], "number of tokens should be a perfect square"
+                h = w = int(math.sqrt(size[-1]))
+                size = (size[0], h, w)
+                if return_torch_size:
+                    size = torch.Size(size)
+            return size
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        do_resize: bool = True,
+        interpolate_pos_encoding: Optional[bool] = None,
+        do_rescale: bool = True,
+        do_normalize: bool = True,
+    ) -> torch.Tensor:
+        """Forward pass of the model
+
+        Args:
+            x (torch.Tensor): model input.
+
+            - arguments for self.processor. Details can be found at
+                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
+            do_resize (bool): whether to resize in the processor. Defaults to True.
+            interpolate_pos_encoding (Optional[bool]): whether to interpolate positional embeddings. Defaults to None.
+            do_rescale (bool): whether to rescale (0-255 -> 0-1) in the processor. Defaults to True.
+            do_normalize (bool): whether to normalize in the processor. Defaults to True.
+
+        Returns:
+            torch.Tensor: model output.
+        """
+        input = self.processor(
+            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
+        ).to(self.model.device)
+        y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
+        return y.last_hidden_state
+
+
+class DeiTNoCLS(nn.Module):
+    """Modified DeiT model without CLS token."""
+
+    def __init__(
+        self, model_name: str = "nocls-facebook/deit-small-patch16-224", pretrained: bool = False, image_size: int = 224
+    ):
+        super().__init__()
+        self.image_size = image_size
+        pretrained_model_name = model_name.replace("nocls-", "")
+        deit_config = AutoConfig.from_pretrained(pretrained_model_name)
+        self.model = ViTModelNoCLS(deit_config)
+        if pretrained:
+            pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
+            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
+            self.load_state_dict(pretrained_dict, strict=False)
+            del pretrained_model, pretrained_dict
+
+        self.model.pooler = nn.Identity()
+        self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
+        self.no_cls = True
+
+    def get_feature_size(
+        self,
+        keep_spatial: bool = False,
+        return_torch_size: bool = False,
+    ) -> torch.Size | tuple[int, ...]:
+        """Get the size of the feature.
+
+        Args:
+            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
+            return_torch_size (bool): if true, return torch.Size type. Defaults to False.
+
+        Returns:
+            torch.Size | tuple[int, ...]: returned feature shape.
+        """
+        with torch.inference_mode():
+            image_size = (self.image_size, self.image_size)
+            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
+            y = self.forward(x)
+            size = y.size()[1:][::-1]
+            if keep_spatial:
+                assert math.isqrt(size[-1]) ** 2 == size[-1], "number of tokens should be a perfect square"
+                h = w = int(math.sqrt(size[-1]))
+                size = (size[0], h, w)
+                if return_torch_size:
+                    size = torch.Size(size)
+            return size
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        do_resize: bool = True,
+        interpolate_pos_encoding: Optional[bool] = None,
+        do_rescale: bool = True,
+        do_normalize: bool = True,
+    ) -> torch.Tensor:
+        """Forward pass of the model
+
+        Args:
+            x (torch.Tensor): model input.
+
+            - arguments for self.processor. Details can be found at
+                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
+            do_resize (bool): whether to resize in the processor. Defaults to True.
+            do_rescale (bool): whether to rescale (0-255 -> 0-1) in the processor. Defaults to True.
+            do_normalize (bool): whether to normalize in the processor. Defaults to True.
+
+            - argument for forward
+            interpolate_pos_encoding (Optional[bool]): whether to interpolate positional embeddings. Defaults to None.
+
+        Returns:
+            torch.Tensor: model output.
+        """
+        input = self.processor(
+            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
+        ).to(self.model.device)
+        y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
+        return y.last_hidden_state
+
+
+class DeiTReg(nn.Module):
+    """Modified DeiT model with register tokens."""
+
+    def __init__(
+        self,
+        model_name: str = "reg-facebook/deit-small-patch16-224",
+        pretrained: bool = False,
+        image_size: int = 224,
+        num_reg_tokens: int = 7,
+    ):
+        super().__init__()
+        self.image_size = image_size
+        pretrained_model_name = model_name.replace("reg-", "")
+        deit_config = AutoConfig.from_pretrained(pretrained_model_name)
+        self.model = ViTModelReg(deit_config, num_reg_tokens=num_reg_tokens)
+        if pretrained:
+            pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
+            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
+            self.load_state_dict(pretrained_dict, strict=False)
+            del pretrained_model, pretrained_dict
+
+        self.model.pooler = nn.Identity()
+        self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
+        self.num_reg_tokens = num_reg_tokens
+
+    def get_feature_size(
+        self,
+        keep_spatial: bool = False,
+        return_torch_size: bool = False,
+    ) -> torch.Size | tuple[int, ...]:
+        """Get the size of the feature.
+
+        Args:
+            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
+            return_torch_size (bool): if true, return torch.Size type. Defaults to False.
+
+        Returns:
+            torch.Size | tuple[int, ...]: returned feature shape.
+        """
+        with torch.inference_mode():
+            image_size = (self.image_size, self.image_size)
+            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
+            y = self.forward(x)[:, 1 : -self.num_reg_tokens]
+            size = y.size()[1:][::-1]
+            if keep_spatial:
+                assert math.isqrt(size[-1]) ** 2 == size[-1], "number of tokens should be a perfect square"
+                h = w = int(math.sqrt(size[-1]))
+                size = (size[0], h, w)
+                if return_torch_size:
+                    size = torch.Size(size)
+            return size
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        do_resize: bool = True,
+        interpolate_pos_encoding: Optional[bool] = None,
+        do_rescale: bool = True,
+        do_normalize: bool = True,
+    ) -> torch.Tensor:
+        """Forward pass of the model
+
+        Args:
+            x (torch.Tensor): model input.
+
+            - arguments for self.processor. Details can be found at
+                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
+            do_resize (bool): whether to resize in the processor. Defaults to True.
+            interpolate_pos_encoding (Optional[bool]): whether to interpolate positional embeddings. Defaults to None.
+            do_rescale (bool): whether to rescale (0-255 -> 0-1) in the processor. Defaults to True.
+            do_normalize (bool): whether to normalize in the processor. Defaults to True.
+
+        Returns:
+            torch.Tensor: model output.
+        """
+        input = self.processor(
+            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
+        ).to(self.model.device)
+        y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
+        return y.last_hidden_state
+
+
+def build_backbone(model_name: str, pretrained: bool = False, image_size: int = 224, **kwargs: Any) -> nn.Module:
+    """Build the backbone visual encoder of robot vision foundation model.
+
+    Args:
+        model_name (str): name of the model.
+        pretrained (bool): whether to use pretrained weights. Defaults to False.
+        image_size (int): size of the image. Assumes a square image. Defaults to 224.
+        kwargs (Any): any kwargs specific to some models. For example,
+            `num_reg_tokens` for `DeiTReg` when `"reg"` is in `model_name`.
+
+    Returns:
+        nn.Module: backbone network.
+    """
+    if "reg" in model_name:
+        return DeiTReg(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs)
+    elif "nocls" in model_name:
+        return DeiTNoCLS(model_name=model_name, pretrained=pretrained, image_size=image_size, **kwargs)
+    elif "deit" in model_name:
+        return DeiT(model_name=model_name, pretrained=pretrained, image_size=image_size)
+    else:
+        raise NotImplementedError(f"Requested {model_name} is not implemented.")
diff --git a/theia/models/feature_translators.py b/theia/models/feature_translators.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30422b08963638609fa2a0f96defd8e4e744e40
--- /dev/null
+++ b/theia/models/feature_translators.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import math
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+
+from theia.models.adapter_heads import ConvAdapterHead, LightConvAdapterHead, MLPAdapterHead, LinearAdapterHead
+
+
+class FeatureTranslator(nn.Module):
+    """Base class for the feature translator.
+
+    The flow is backbone_adapter -> translator_stem -> translator_heads.
+
+    Attributes:
+        backbone_feature_size (torch.Size): the size of features of the backbone.
+        target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
+        translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
+        target_model_names (list[str]): convenient attribute to hold all the names of the target models.
+
+        backbone_adapter (nn.Module): the adapter to map channel dim of backbone to the translator hidden dim.
+        translator_stem (nn.Module):  the shared stem for all target models.
+        translator_heads (nn.ModuleDict): specific heads for different target models.
+    """
+
+    def __init__(
+        self,
+        backbone_feature_size: torch.Size,
+        target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
+        translator_hidden_size: int = 1024,
+    ) -> None:
+        """Initalization function for FeatureTranslator.
+
+        Args:
+            backbone_feature_size (torch.Size): the size of features of the backbone.
+            target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
+            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
+        """
+        super().__init__()
+        self.backbone_feature_size = backbone_feature_size  # (C, H, W)
+        self.target_feature_sizes = target_feature_sizes  # [(C, H, W)]
+        self.translator_hidden_size = translator_hidden_size  # C
+        self.target_model_names = list(target_feature_sizes.keys())
+        self.legit_target_model_name_map: dict[str, str] = {t: t.replace(".", "_") for t in self.target_model_names}
+        self.translator_heads: Optional[nn.ModuleDict] = None
+
+        self.backbone_adapter = nn.Sequential(
+            nn.LayerNorm(self.backbone_feature_size[0]),  # do a pre-norm
+            nn.Linear(
+                self.backbone_feature_size[0],  # C in [C,H,W]
+                self.translator_hidden_size,
+            ),
+        )
+        self.translator_stem: nn.Module = nn.Identity()
+        self.build_translator_heads()
+
+    def build_translator_heads(self) -> None:
+        """Build translator heads to match the dimension of each target feature set.
+
+        Example:
+            translator_heads: dict[str, nn.Module] = ...
+            self.translator_heads = nn.ModuleDict(translator_heads)
+        """
+        raise NotImplementedError("build_translator_heads() should be overridden")
+
+    def forward(
+        self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False
+    ) -> torch.Tensor:
+        """Forward pass for a base feature translator.
+
+        Args:
+            x (torch.Tensor): input features from the backbone. [B, (1)+H*W, C].
+                (1) means optional CLS token. If `backbone_no_cls==True`, then [B, H*W, C].
+            target_model_names (Optional[list[str]]): names of the target models.
+            backbone_no_cls (bool): indicate backbone has cls token or not.
+                Can use it to customize whether to drop cls.
+
+        Returns:
+            dict[str, torch.Tensor]: predicted features for target models.
+        """
+        # x: [B, (1)+H*W, C]
+        x = self.backbone_adapter(x)
+        x = self.translator_stem(x)
+        target_model_names = target_model_names if target_model_names is not None else self.target_model_names
+        features = {
+            t: self.translator_heads[self.legit_target_model_name_map[t]](x, backbone_no_cls=backbone_no_cls)
+            for t in target_model_names
+        }
+        return features
+
+
+class MLPFeatureTranslator(FeatureTranslator):
+    def __init__(
+        self,
+        backbone_feature_size: torch.Size,
+        target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
+        translator_hidden_size: int = 1024,
+        translator_n_layer: int = 3,
+    ) -> None:
+        """Initalization function for MLPFeatureTranslator.
+
+        Args:
+            backbone_feature_size (torch.Size): the size of features of the backbone.
+            target_feature_sizes (dict[str, torch.Size  |  tuple[int, ...]]): the sizes of features of target models.
+            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
+            translator_n_layer (int): number of MLP layers. Defaults to 3.
+        """
+        self.translator_n_layer = translator_n_layer
+
+        super().__init__(
+            backbone_feature_size=backbone_feature_size,
+            target_feature_sizes=target_feature_sizes,
+            translator_hidden_size=translator_hidden_size,
+        )
+
+    def build_translator_heads(self) -> None:
+        """Build MLP translator heads to match the dimension of each target feature set."""
+        translator_heads = {}
+        source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
+        for target_model, target_size in self.target_feature_sizes.items():
+            head = MLPAdapterHead(source_size=source_size, target_size=target_size, num_layer=self.translator_n_layer)
+            translator_heads[self.legit_target_model_name_map[target_model]] = head
+        self.translator_heads = nn.ModuleDict(translator_heads)
+
+
+class ConvFeatureTranslator(FeatureTranslator):
+    def __init__(
+        self,
+        backbone_feature_size: torch.Size,
+        target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
+        translator_hidden_size: int = 1024,
+    ) -> None:
+        """Initalization function for ConvFeatureTranslator.
+
+        Args:
+            backbone_feature_size (torch.Size): the size of features of the backbone.
+            target_feature_sizes (dict[str, torch.Size  |  tuple[int, ...]]): the sizes of features of target models.
+            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
+        """
+        super().__init__(
+            backbone_feature_size=backbone_feature_size,
+            target_feature_sizes=target_feature_sizes,
+            translator_hidden_size=translator_hidden_size,
+        )
+
+    def build_translator_heads(self) -> None:
+        """Build translator heads to match the dimension of each target feature set."""
+        translator_heads = {}
+        source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
+        for target_model, target_size in self.target_feature_sizes.items():
+            head = ConvAdapterHead(source_size=source_size, target_size=target_size)
+            translator_heads[self.legit_target_model_name_map[target_model]] = head
+        self.translator_heads = nn.ModuleDict(translator_heads)
+
+
+class LightConvFeatureTranslator(FeatureTranslator):
+    def __init__(
+        self,
+        backbone_feature_size: torch.Size,
+        target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
+        translator_hidden_size: int = 1024,
+        hidden_size_factor: int | float = 1.0,
+    ) -> None:
+        """Initalization function for LightConvFeatureTranslator.
+            It's for a smaller translator compared to ConvFeatureTranslator.
+
+        Args:
+            backbone_feature_size (torch.Size): the size of features of the backbone.
+            target_feature_sizes (dict[str, torch.Size  |  tuple[int, ...]]): the sizes of features of target models.
+            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
+            hidden_size_factor (int | float): the size of the hidden dim of the feature translator
+                as a factor of the input feature hidden dim. Defaults to 1.0.
+        """
+        self.hidden_size_factor = hidden_size_factor
+        super().__init__(
+            backbone_feature_size=backbone_feature_size,
+            target_feature_sizes=target_feature_sizes,
+            translator_hidden_size=translator_hidden_size,
+        )
+        self.backbone_adapter = nn.Identity()
+
+    def build_translator_heads(self) -> None:
+        """Build translator heads to match the dimension of each target feature set."""
+        translator_heads = {}
+        for target_model, target_size in self.target_feature_sizes.items():
+            if "_cls" in target_model:
+                head = LinearAdapterHead(
+                    source_size=self.backbone_feature_size,
+                    target_size=target_size
+                )
+            else:
+                head = LightConvAdapterHead(
+                    source_size=self.backbone_feature_size,
+                    target_size=target_size,
+                    hidden_size_factor=self.hidden_size_factor,
+                )
+            translator_heads[self.legit_target_model_name_map[target_model]] = head
+        self.translator_heads = nn.ModuleDict(translator_heads)
+
+
+class TransformerFreatureTranslator(FeatureTranslator):
+    def __init__(
+        self,
+        backbone_feature_size: torch.Size,
+        target_feature_sizes: dict[str, torch.Size | tuple[int, int]],
+        translator_hidden_size: int = 1024,
+        translator_n_layers: int = 2,
+        translator_n_heads: int = 8,
+        translator_activation: str = "gelu",
+    ) -> None:
+        super().__init__(
+            backbone_feature_size=backbone_feature_size,
+            target_feature_sizes=target_feature_sizes,
+            translator_hidden_size=translator_hidden_size,
+        )
+
+        self.translator_stem = nn.TransformerDecoder(
+            nn.TransformerDecoderLayer(
+                d_model=translator_hidden_size,
+                nhead=translator_n_heads,
+                dim_feedforward=translator_hidden_size * 2,
+                activation=translator_activation,
+                batch_first=True,
+                norm_first=True,
+            ),
+            num_layers=translator_n_layers,
+        )
+
+        self.decode_tokens = nn.Parameter(
+            torch.randn((1, math.prod(self.backbone_feature_size[1:]), translator_hidden_size))
+        )
+
+        self.target_model_emb = nn.ParameterDict(
+            {
+                self.legit_target_model_name_map[t]: torch.randn(1, 1, translator_hidden_size)
+                for t in self.target_model_names
+            }
+        )
+
+    def build_translator_heads(self) -> None:
+        """Build Transformer translator heads to match the dimension of each target feature set."""
+        translator_heads = {}
+        for target_model, target_size in self.target_feature_sizes.items():
+            head = MLPAdapterHead(
+                source_size=(self.translator_hidden_size, *self.backbone_feature_size[1:]),
+                target_size=target_size,
+                num_layer=2,
+            )
+            translator_heads[self.legit_target_model_name_map[target_model]] = head
+        self.translator_heads = nn.ModuleDict(translator_heads)
+
+    def forward(
+        self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False
+    ) -> torch.Tensor:
+        """Forward pass for a simple linear translator.
+
+        Args:
+            x (torch.Tensor): input features from the backbone.
+            target_model_names (Optional[str]): names of the target models.
+            backbone_no_cls (bool): indicate backbone has cls token or not.
+                Can use it to customize whether to drop cls.
+
+        Returns:
+            dict[str, torch.Tensor]: predicted features for target models.
+        """
+        if not backbone_no_cls:
+            x = x[:, 1:]
+        x = self.backbone_adapter(x)
+        features = {}
+        target_model_names = target_model_names if target_model_names is not None else self.target_model_names
+        for t in target_model_names:
+            feature = self.translator_stem(
+                torch.cat(
+                    [
+                        self.decode_tokens.repeat(x.size(0), 1, 1),
+                        self.target_model_emb[self.legit_target_model_name_map[t]].repeat(x.size(0), 1, 1),
+                    ],
+                    dim=1,
+                ),
+                memory=x,
+            )[:, 1:, ...]
+            features[t] = self.translator_heads[self.legit_target_model_name_map[t]](feature)
+        return features
+
+
+def build_feature_translator(translator_type: str, **kwargs: Any) -> FeatureTranslator:
+    """Handy function to build feature translators given the type
+
+    Args:
+        translator_type (str): the type of the translator,
+            one in `"mlp"`, `"conv"`, `"lconv"`, `"transformer"` (or `"trans"`).
+            At the moment we are actively using `"lconv"`.
+
+    Returns:
+        FeatureTranslator: the corresponding FeatureTranslator
+    """
+    if translator_type == "mlp":
+        return MLPFeatureTranslator(**kwargs)
+    elif translator_type == "conv":
+        return ConvFeatureTranslator(**kwargs)
+    elif translator_type == "lconv":
+        return LightConvFeatureTranslator(**kwargs)
+    elif translator_type == "transformer" or translator_type == "trans":
+        return TransformerFreatureTranslator(**kwargs)
+    else:
+        raise NotImplementedError(f"Requested {translator_type} is not implemented yet.")
diff --git a/theia/models/rvfm.py b/theia/models/rvfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..958a9ba781382b1014ba2a2c08090d9923fee950
--- /dev/null
+++ b/theia/models/rvfm.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from omegaconf import OmegaConf
+
+from theia.models.backbones import build_backbone
+from theia.models.feature_translators import build_feature_translator
+from theia.models.utils import handle_feature_output
+
+
+class RobotVisionFM(nn.Module):
+    """Robot Vision Foundation Model (temporary name).
+
+    Attributes:
+        backbone (str | nn.Module): backbone network. Defaults to "deit-small-patch16-224".
+        pretrained (bool): whether to use pretrained weights. Default to False.
+        translator (str | nn.Module): feature translator module. Defaults to "lconv".
+        target_feature_sizes (Optional[dict[str, torch.Size | tuple[int, ...]]]):
+            a dict to hold target feature sizes.
+        translator_kwargs (Optional[dict[str, Any]]): other keyword arguments to the translator.
+        target_loss_weights (Optional[dict[str, float]]):
+            weights to balance loss from different target models. If not specified, use even weights.
+        checkpoint_path (Optional[str]): filename of pretrained weights to load.
+        feature_reduce_method (Optional[str]): how to reduce the feature in downstream applications.
+    """
+
+    def __init__(
+        self,
+        backbone: str | nn.Module = "deit-small-patch16-224",
+        pretrained: bool = False,
+        translator: str | nn.Module = "lconv",
+        target_feature_sizes: Optional[dict[str, torch.Size | tuple[int, ...]]] = None,
+        translator_kwargs: Optional[dict[str, Any]] = None,
+        target_loss_weights: Optional[dict[str, float]] = None,
+        checkpoint_path: Optional[str] = None,
+        feature_reduce_method: Optional[str] = None,
+        image_size: int = 224,
+        **kwargs: Any
+    ) -> None:
+        super().__init__()
+
+        self.target_feature_sizes = target_feature_sizes
+        self.preprocessor = None
+        self.pretrained = pretrained
+
+        # backbone
+        self.image_size = image_size
+        self.backbone: nn.Module = build_backbone(backbone, pretrained, image_size=image_size, **kwargs)
+        self.final_spatial = None
+        if hasattr(self.backbone, "final_spatial"):
+            self.final_spatial = self.backbone.final_spatial
+
+        # handle output feature (feature reduce)
+        self.feature_reduce_method = feature_reduce_method
+        self.no_cls = hasattr(self.backbone, "no_cls")
+        self.num_reg_tokens = self.backbone.num_reg_tokens if hasattr(self.backbone, "num_reg_tokens") else 0
+
+        # translator
+        backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True)
+        if self.target_feature_sizes:
+            translator_kwargs = {} if translator_kwargs is None else OmegaConf.to_container(translator_kwargs)
+            translator_kwargs["backbone_feature_size"] = backbone_feature_size
+            translator_kwargs["target_feature_sizes"] = target_feature_sizes
+            self.translator = build_feature_translator(translator, **translator_kwargs)
+
+        # loss
+        self.mse_loss = nn.MSELoss()
+        self.l1_loss = nn.SmoothL1Loss()
+        self.cos_loss = nn.CosineEmbeddingLoss()
+        self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False)
+        self.target_loss_weights = target_loss_weights
+
+    def load_pretrained_weights(self, checkpoint_path: str) -> None:
+        """Load pretrained weights.
+
+        Args:
+            checkpoint_path (str): path to checkpoint / weight.
+        """
+        if checkpoint_path:
+            weights_dict = torch.load(checkpoint_path, map_location="cpu")
+            # Filter out unnecessary keys
+            pretrained_dict = {k: v for k, v in weights_dict.items() if k in self.state_dict()}
+            self.load_state_dict(pretrained_dict, strict=False)
+
+    def freeze_translator(self) -> None:
+        """Freeze the feature translator."""
+        for param in self.translator.parameters():
+            param.requires_grad = False
+
+    def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+        """Forward RVFM feature only (before translators).
+
+        Args:
+            x (torch.Tensor): input image. By default it accepts images 
+                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
+            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
+                `do_resize` (bool) defaults to True.
+                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
+                `do_rescale` (bool) defaults to True.
+                `do_normalize` (bool) defaults to True.
+
+        Returns:
+            torch.Tensor: RVFM feature.
+        """
+        feature = self.backbone(x, **kwargs)
+        # [B, 1+H*W+N, C] if including both CLS and register tokens.
+        # [B, 1+H*W, C] for standard model (N=0).
+        # [B, H*W, C] for model without CLS.
+        return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens)
+
+    def forward(self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, **kwargs: Any) -> dict[str, torch.Tensor]:
+        """Forward pass of Robot Vision Foundation Model.
+
+        Args:
+            x (torch.Tensor): input image. By default it accepts images 
+                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
+            target_model_names (Optional[list[str]]): names of the target foundation models.
+            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
+                `do_resize` (bool) defaults to True.
+                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
+                `do_rescale` (bool) defaults to True.
+                `do_normalize` (bool) defaults to True.
+
+        Returns:
+            dict[str, torch.Tensor]: features that match to each foundation model.
+                Each feature is in [B, (H*W), C] or [B, C].
+        """
+        x = self.backbone(x, **kwargs)
+        if self.num_reg_tokens > 0:
+            x = x[:, :-self.num_reg_tokens]  # [B, (1)+H*W, C]
+        features = self.translator(x, target_model_names, backbone_no_cls=self.no_cls)  # each is [B, H*W, C] or [B, C]
+        return features
+
+    def get_loss(self, pred_features: dict[str, torch.Tensor], y: dict[str, torch.Tensor]) -> dict[str, Any]:
+        """Get loss terms given predictions and targets.
+
+        Args:
+            pred_features (dict[str, torch.Tensor]): predictions.
+            y (dict[str, torch.Tensor]): targets.
+
+        Returns:
+            dict[str, Any]: averaged and per-model MSE, cosine, and smooth L1 loss terms.
+        """
+        mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0
+        mse_losses_per_model = {}
+        cos_losses_per_model = {}
+        l1_losses_per_model = {}
+
+        for t in pred_features:
+            pred = pred_features[t]
+            target = y[t]
+
+            # mse loss
+            mse_loss = self.mse_loss(pred, target)
+            weight = self.target_loss_weights[t] if self.target_loss_weights else 1.0 / len(pred_features)
+
+            # l1 loss
+            l1_loss = self.l1_loss(pred, target)
+
+            # cos loss
+            pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2)
+            target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2)
+            cos_target = self.cos_target.repeat(pred.size(0)).to(pred.device)
+            cos_loss = self.cos_loss(pred_norm, target_norm, cos_target)
+
+            mse_loss_avg += mse_loss * weight
+            cos_loss_avg += cos_loss / len(pred_features)  # balance cos by default for meaningful eval
+            l1_loss_avg += l1_loss * weight
+
+            mse_losses_per_model[t] = mse_loss.item()
+            cos_losses_per_model[t] = cos_loss.item()
+            l1_losses_per_model[t] = l1_loss.item()
+
+        return {
+            "mse_loss": mse_loss_avg,
+            "cos_loss": cos_loss_avg,
+            "l1_loss": l1_loss_avg,
+            "mse_losses_per_model": mse_losses_per_model,
+            "cos_losses_per_model": cos_losses_per_model,
+            "l1_losses_per_model": l1_losses_per_model,
+        }
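A minimal forward-pass sketch for the model above (the backbone, target names, and dummy input are illustrative and mirror how theia/scripts/decoding/decoding_example.py constructs the model further below):

import torch
from theia.foundation_models.common import get_model_feature_size
from theia.models.rvfm import RobotVisionFM

targets = ["facebook/dinov2-large", "openai/clip-vit-large-patch14"]
model = RobotVisionFM(
    backbone="facebook/deit-tiny-patch16-224",
    translator="lconv",
    target_feature_sizes={t: get_model_feature_size(t, keep_spatial=True) for t in targets},
)
images = torch.randint(0, 256, (2, 224, 224, 3), dtype=torch.uint8)  # [B, H, W, C], pixels in [0, 255]
with torch.no_grad():
    pred = model(images, target_model_names=targets)  # dict with one feature per target model
# during training, the targets y come from the precomputed feature shards:
# losses = model.get_loss(pred, y)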
diff --git a/theia/models/utils.py b/theia/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1391c375e87c207c6917de3a50ed3a596065c08e
--- /dev/null
+++ b/theia/models/utils.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Optional
+
+import torch
+
+
+def handle_feature_output(
+    x: torch.Tensor, feature_reduce_method: Optional[str] = None, num_discard_tokens: int = 0
+) -> torch.Tensor:
+    """Handle feature output from transformer.
+
+    Args:
+        x (torch.Tensor): input feature to be handled. shape is
+            [B, 1+H*W+N, C] if including both CLS and register tokens.
+            [B, 1+H*W, C] for standard model (N=0).
+            [B, H*W, C] for model without CLS.
+        feature_reduce_method (Optional[str]): method to select token. Options:
+            - `mean_pooling`: average over spatial tokens (non CLS tokens), output shape = [B, C].
+            - `max_pooling`: max over spatial tokens, output shape = [B, C].
+            - `cls`: return CLS token only, output shape = [B, C].
+            - `identity`: return the feature without touching it, output shape = input shape.
+            - `None`: return spatial tokens, output shape = [B, H*W, C] (assuming input is [B, 1+H*W, C]).
+            The raw feature is assumed to be in shape [B, 1+H*W, C], where `1` corresponds to the CLS token.
+        num_discard_tokens (int):
+            number of tokens to be discarded. Assuming they are at the end of the sequence.
+    Returns:
+        torch.Tensor: selected feature tokens.
+    """
+
+    match feature_reduce_method:
+        case "mean_pooling":
+            return torch.mean(x[:, 1 : x.size(1) - num_discard_tokens], dim=1)  # [B, C]
+        case "max_pooling":
+            return torch.amax(x[:, 1 : x.size(1) - num_discard_tokens], dim=1)  # [B, C]
+        case "cls":
+            return x[:, 0]  # [B, C]
+        case "identity":
+            return x
+        case None:
+            return x[:, 1 : x.size(1) - num_discard_tokens]
+        case _:
+            raise NotImplementedError(f"feature_reduce_method {feature_reduce_method} is not implemented.")
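A small sketch of the reduction options above on a dummy ViT-style feature (shapes are illustrative):

import torch
from theia.models.utils import handle_feature_output

x = torch.randn(2, 1 + 196, 384)                    # B=2, CLS token + 14*14 spatial tokens, C=384
spatial = handle_feature_output(x)                  # [2, 196, 384]: spatial tokens only
pooled = handle_feature_output(x, "mean_pooling")   # [2, 384]: mean over spatial tokens
cls = handle_feature_output(x, "cls")               # [2, 384]: CLS token
assert spatial.shape == (2, 196, 384) and pooled.shape == cls.shape == (2, 384)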
diff --git a/theia/models/vfm.py b/theia/models/vfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b4e21970ff8f36c26b035f55c4a694f080df975
--- /dev/null
+++ b/theia/models/vfm.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+from theia.foundation_models import get_clip_model, get_deit_model, get_dinov2_model, get_sam_model, get_vit_model
+from transformers import AutoImageProcessor, AutoModel
+
+from theia.models.utils import handle_feature_output
+
+
+class VFMEncoder(nn.Module):
+    """Wrapper class of an individual VFM Encoder for feature extraction.
+
+    Attributes:
+        model_name (str): name of the model.
+        feature_reduce_method (str): how to select the output feature token and shape.
+        processor (AutoProcessor): input pre-processor.
+    """
+
+    def __init__(self, model_name: str, feature_reduce_method: Optional[str] = None, **kwargs: Any):
+        """Instantiate an (off-the-shelf) VFM encoder.
+
+        Args:
+            model_name (str): name of the encoder.
+            feature_reduce_method (Optional[str]): how to select the output feature token and shape. Defaults to None.
+            **kwargs (Any): any extra arguments are passed through.
+        """
+        super().__init__()
+        self.model_name = model_name
+        if "google/vit" in model_name:
+            model, processor = get_vit_model(model_name, device="cpu")
+        elif "facebook/dino" in model_name:
+            model, processor = get_dinov2_model(model_name, device="cpu")
+        elif "facebook/sam" in model_name:
+            model, processor = get_sam_model(model_name, device="cpu")
+        elif "openai/clip" in model_name:
+            model, processor = get_clip_model(model_name, device="cpu")
+        elif "facebook/deit" in model_name:
+            model, processor = get_deit_model(model_name, device="cpu")
+        elif "nvidia" in model_name:
+            model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+            processor = AutoImageProcessor.from_pretrained(model_name)
+        elif "mvp" in model_name:
+            import mvp
+
+            model_name_mvp = model_name.replace("mvp-", "")
+            model = mvp.load(model_name_mvp)
+            processor = None
+        elif "vip" in model_name:
+            from vip import load_vip
+
+            model = load_vip()
+            processor = None
+        elif "r3m" in model_name:
+            from r3m import load_r3m
+
+            model_name_r3m = model_name.replace("r3m-", "")
+            model = load_r3m(model_name_r3m)
+            processor = None
+        else:
+            raise NotImplementedError(f"{model_name} is not supported in theia.models.vfm.VFM")
+
+        self.model = model
+        self.processor = processor
+        self.feature_reduce_method = feature_reduce_method
+        if "image_size" in kwargs:
+            self.image_size = kwargs["image_size"]
+        if "final_spatial" in kwargs:
+            self.final_spatial = kwargs["final_spatial"]
+
+    def get_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+        """Return the feature from the VFM.
+
+        Args:
+            x (torch.Tensor): input image.
+            kwargs: any arguments pass-through (mainly for processor currently).
+                For example, `do_rescale`, `do_resize`, `interpolate_pos_encoding`
+                    to control image preprocessing pipeline.
+
+        Returns:
+            torch.Tensor: feature.
+        """
+        if (
+            "google/vit" in self.model_name
+            or "facebook/dinov2" in self.model_name
+            or "facebook/deit" in self.model_name
+        ):
+            inputs = self.processor(x, return_tensors="pt", **kwargs).to(self.model.device)
+            feature = self.model(**inputs).last_hidden_state
+        elif "openai/clip" in self.model_name:
+            inputs = self.processor(images=x, return_tensors="pt", **kwargs).to(self.model.device)
+            feature = self.model(**inputs).last_hidden_state
+        elif "facebook/sam" in self.model_name:
+            inputs = self.processor(x, return_tensors="pt", **kwargs).to(self.model.device)
+            feature = self.model(**inputs).image_embeddings
+        elif "nvidia" in self.model_name:
+            inputs = (
+                self.processor(images=x, return_tensors="pt", **kwargs)
+                .pixel_values.to(torch.bfloat16)
+                .to(self.model.device)
+            )
+            summary, feature = self.model(inputs)
+            if self.feature_reduce_method == "cls_identity":
+                feature = summary.to(torch.float32)
+            else:
+                feature = feature.to(torch.float32)
+        elif "mvp" in self.model_name:
+            feature = self.model(x)
+        elif "vip" in self.model_name:
+            feature = self.model(x)
+        elif "r3m" in self.model_name:
+            feature = self.model(x)
+        return feature
+
+    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+        """Forward method, including getting the feature and handling the output token / shape.
+
+        Args:
+            x (torch.Tensor): input image.
+
+        Returns:
+            torch.Tensor: output feature with token or shape handled.
+        """
+        feature = self.get_feature(x, **kwargs)  # [B, 1+H*W, C]
+        return handle_feature_output(feature, self.feature_reduce_method)
+
+    def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+        """Alias of forward() to accommodate some downstream usage.
+
+        Args:
+            x (torch.Tensor): input image.
+
+        Returns:
+            torch.Tensor: output feature with token or shape handled.
+        """
+        return self.forward(x, **kwargs)
+
+
+class ConcatVFMEncoder(nn.Module):
+    """Wrapper class that combines features from multiple VFM Encoders. The combination is channel-wise concatenation.
+
+    Attributes:
+        model_names (list[str]): names of the models.
+        feature_reduce_method (Optional[str]): how to select the output feature token and shape.
+        model (nn.ModuleDict): a dict to hold different VFM encoders.
+    """
+
+    def __init__(self, model_names: list[str], feature_reduce_method: Optional[str] = None, **kwargs: Any):
+        """Instantiate a set of (off-the-shelf) VFM encoders whose features will be concatenated.
+
+        Args:
+            model_names (list[str]): names of the encoders.
+            feature_reduce_method (Optional[str]): how to select the output feature token and shape. Defaults to None.
+            **kwargs (Any): any extra arguments are passed through.
+        """
+        super().__init__()
+        self.model_names = model_names
+        self.model = {}
+        for model_name in model_names:
+            model = VFMEncoder(model_name, feature_reduce_method=feature_reduce_method, **kwargs)
+            self.model[model_name] = model
+
+        self.model = nn.ModuleDict(self.model)
+        self.feature_reduce_method = feature_reduce_method
+
+    def get_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+        """Get different features from VFMs.
+
+        Args:
+            x (torch.Tensor): input image.
+
+        Returns:
+            torch.Tensor: features concatenated at channel dimension.
+        """
+        features = []
+        for model_name in self.model_names:
+            features.append(self.model[model_name](x, **kwargs))
+        features = torch.cat(features, dim=-1)
+        return features
+
+    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+        """Forward method, including getting the feature and handling the output token / shape.
+
+        Args:
+            x (torch.Tensor): input image.
+
+        Returns:
+            torch.Tensor: output feature with token or shape handled.
+        """
+        feature = self.get_feature(x, **kwargs)  # [B, 1+H*W, C]
+        return handle_feature_output(feature, self.feature_reduce_method)
+
+    def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+        """Alias of forward() to accommodate some downstream usage.
+
+        Args:
+            x (torch.Tensor): input image.
+
+        Returns:
+            torch.Tensor: output feature with token or shape handled.
+        """
+        return self.forward(x, **kwargs)
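A minimal usage sketch for the wrappers above (the model name is one of those handled in __init__; feeding uint8 [B, H, W, C] tensors through the huggingface processor is an assumption of this sketch):

import torch
from theia.models.vfm import VFMEncoder

encoder = VFMEncoder("facebook/dinov2-large", feature_reduce_method="mean_pooling")
images = torch.randint(0, 256, (2, 224, 224, 3), dtype=torch.uint8)  # [B, H, W, C]
with torch.no_grad():
    feature = encoder(images)  # [B, C] after mean pooling over the spatial tokens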
diff --git a/theia/optimizers/__init__.py b/theia/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/optimizers/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/optimizers/utils.py b/theia/optimizers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db1331a8fb940f18af7ebe48b49ba5285eec149
--- /dev/null
+++ b/theia/optimizers/utils.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Any, Iterable
+
+import torch.nn as nn
+
+
+def param_groups_weight_decay(
+    model: nn.Module, weight_decay: float = 1e-5, no_weight_decay_parameters: Iterable[str] = ()
+) -> list[dict[str, Any]]:
+    """Group parameters into sets with decay applied and no decay.
+
+    Args:
+        model (nn.Module): the model.
+        weight_decay (float): weight decay. Defaults to 1e-5.
+        no_weight_decay_parameters (Iterable[str]): parameters added to no weight decay
+            in addition to defaults. Defaults to ().
+
+    Returns:
+        list[dict[str, Any]]: parameter groups with different weight decay values.
+            Follow the format required by torch.optim.Optimizer.
+    """
+    no_weight_decay_parameters = set(no_weight_decay_parameters)
+    decay = []
+    no_decay = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_parameters:
+            no_decay.append(param)
+        else:
+            decay.append(param)
+
+    return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]
+
+
+def param_groups_lr_weight_decay(
+    model: nn.Module,
+    backbone_lr: float = 1e-3,
+    translator_lr: float = 1e-3,
+    weight_decay: float = 1e-5,
+    no_weight_decay_parameters: Iterable[str] = (),
+) -> list[dict[str, Any]]:
+    """Group parameters into sets with decay applied and no decay,
+        using separate learning rates for the backbone and the translator.
+
+    Args:
+        model (nn.Module): the model. It is expected to expose `model.module.backbone`
+            and `model.module.translator` (e.g., a DDP-wrapped RobotVisionFM).
+        backbone_lr (float): learning rate for backbone parameters. Defaults to 1e-3.
+        translator_lr (float): learning rate for translator parameters. Defaults to 1e-3.
+        weight_decay (float): weight decay. Defaults to 1e-5.
+        no_weight_decay_parameters (Iterable[str]): parameters added to no weight decay
+            in addition to defaults. Defaults to ().
+
+    Returns:
+        list[dict[str, Any]]: parameter groups with different weight decay values.
+            Follow the format required by torch.optim.Optimizer.
+    """
+    no_weight_decay_parameters = set(no_weight_decay_parameters)
+    decay_backbone = []
+    no_decay_backbone = []
+    decay_translator = []
+    no_decay_translator = []
+
+    for name, param in model.module.backbone.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_parameters:
+            no_decay_backbone.append(param)
+        else:
+            decay_backbone.append(param)
+
+    for name, param in model.module.translator.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_parameters:
+            no_decay_translator.append(param)
+        else:
+            decay_translator.append(param)
+
+    return [
+        {"params": no_decay_backbone, "weight_decay": 0.0, "lr": backbone_lr},
+        {"params": decay_backbone, "weight_decay": weight_decay, "lr": backbone_lr},
+        {"params": no_decay_translator, "weight_decay": 0.0, "lr": translator_lr},
+        {"params": decay_translator, "weight_decay": weight_decay, "lr": translator_lr},
+    ]
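A minimal sketch of wiring param_groups_weight_decay into an optimizer (the toy model and hyperparameters are illustrative):

import torch
import torch.nn as nn
from theia.optimizers.utils import param_groups_weight_decay

toy = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
groups = param_groups_weight_decay(toy, weight_decay=1e-4)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
# biases and 1-D parameters (e.g. LayerNorm weights) land in the zero-weight-decay group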
diff --git a/theia/preprocessing/feature_extraction_core/__init__.py b/theia/preprocessing/feature_extraction_core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d1f2b64556713b62ff6daceee10c1f2a43b8290
--- /dev/null
+++ b/theia/preprocessing/feature_extraction_core/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from .models import get_feature_outputs, get_model, get_models
+from .webdataset_utils import check_existing_shard, decode_image_npy_only, read_shard
diff --git a/theia/preprocessing/feature_extraction_core/models.py b/theia/preprocessing/feature_extraction_core/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d1b0d351bc01200a000c0d331f2c738b1a51f4
--- /dev/null
+++ b/theia/preprocessing/feature_extraction_core/models.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+from numpy.typing import NDArray
+from torch.nn.functional import interpolate
+from theia.foundation_models import (
+    get_clip_feature,
+    get_clip_model,
+    get_depth_anything_feature,
+    get_depth_anything_model,
+    get_dinov2_feature,
+    get_dinov2_model,
+    get_llava_vision_model,
+    get_llava_visual_feature,
+    get_sam_feature,
+    get_sam_model,
+    get_vit_feature,
+    get_vit_model,
+)
+
+
+def get_model(model_name: str, device: int | str | torch.device = "cpu") -> tuple[nn.Module, Any]:
+    if "google/vit" in model_name:
+        model, processor = get_vit_model(model_name, device=device)
+    elif "facebook/sam" in model_name:
+        model, processor = get_sam_model(model_name, device=device, with_upscaled=False)
+    elif "openai/clip" in model_name:
+        model, processor = get_clip_model(model_name, device=device)
+    elif "facebook/dinov2" in model_name:
+        model, processor = get_dinov2_model(model_name, device=device)
+    elif "llava" in model_name:
+        model, processor = get_llava_vision_model(model_name, device=device)
+    elif "depth-anything" in model_name:
+        model, processor = get_depth_anything_model(model_name, device=device, selected_feature="head")
+    else:
+        raise NotImplementedError(f"{model_name} is not implemented")
+    return model, processor
+
+
+def get_models(
+    model_names: list[str], device: int | str | torch.device = "cpu"
+) -> tuple[dict[str, nn.Module], dict[str, Any]]:
+    models: dict[str, nn.Module] = {}
+    processors: dict[str, Any] = {}
+    for model_name in model_names:
+        model, processor = get_model(model_name, device)
+        models[model_name.replace("/", "_")] = model
+        processors[model_name.replace("/", "_")] = processor
+    return models, processors
+
+
+def get_feature_outputs(
+    model_name: str, model: nn.Module, processor: Any, batch_images: list[NDArray], dtype: torch.dtype = torch.bfloat16
+) -> dict[str, dict[str, torch.Tensor]]:
+    features: dict[str, dict[str, torch.Tensor]] = {model_name: {}}
+    if "google_vit" in model_name:
+        cls_token, feature = get_vit_feature(model, processor, batch_images)
+        features[model_name] = {
+            "cls_token": cls_token.detach().cpu().to(dtype).contiguous(),
+            "embedding": feature.detach().cpu().to(dtype).contiguous()
+        }
+    elif "facebook_sam" in model_name:
+        feature, upscaled_feature = get_sam_feature(model, processor, batch_images)
+        features[model_name] = {"embedding": feature.detach().cpu().to(dtype).contiguous()}
+        features[model_name + "_32"] = {
+            "embedding": interpolate(feature, (32, 32)).detach().cpu().to(dtype).contiguous()
+        }
+
+        if upscaled_feature:
+            features[model_name]["upscaled_embedding"] = upscaled_feature.detach().cpu().to(dtype).contiguous()
+    elif "openai_clip" in model_name:
+        cls_token, visual_tokens, pooled_cls_token = get_clip_feature(model, processor, batch_images)
+        features[model_name] = {
+            "embedding": visual_tokens.detach().cpu().to(dtype).contiguous(),
+            "cls_token": cls_token.detach().cpu().to(dtype).contiguous(),
+            "pooled_cls_token": pooled_cls_token.detach().cpu().to(dtype).contiguous(),
+        }
+    elif "facebook_dinov2" in model_name:
+        cls_token, visual_tokens, pooled_cls_token = get_dinov2_feature(model, processor, batch_images)
+        features[model_name] = {
+            "embedding": visual_tokens.detach().cpu().to(dtype).contiguous(),
+            "cls_token": cls_token.detach().cpu().to(dtype).contiguous(),
+            "pooled_cls_token": pooled_cls_token.detach().cpu().to(dtype).contiguous(),
+        }
+    elif "llava" in model_name:
+        feature = get_llava_visual_feature(model, processor, batch_images)
+        features[model_name] = {"embedding": feature.detach().cpu().to(dtype).contiguous()}
+    elif "depth-anything" in model_name:
+        feature = get_depth_anything_feature(model, processor, batch_images)
+        features[model_name] = {"embedding": interpolate(feature, (64, 64)).detach().cpu().to(dtype).contiguous()}
+    else:
+        raise NotImplementedError(f"model {model_name} is not supported")
+
+    return features
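A minimal sketch of extracting features for a small batch with the helpers above (the chosen model and the dummy images are illustrative):

import numpy as np
from theia.preprocessing.feature_extraction_core import get_feature_outputs, get_models

models, processors = get_models(["facebook/dinov2-large"], device="cpu")
name = "facebook_dinov2-large"  # dictionary keys use "/" replaced by "_"
batch = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(2)]
features = get_feature_outputs(name, models[name], processors[name], batch)
# features[name] holds "embedding", "cls_token", and "pooled_cls_token" as bfloat16 CPU tensors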
diff --git a/theia/preprocessing/feature_extraction_core/webdataset_utils.py b/theia/preprocessing/feature_extraction_core/webdataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9400eb973c329256c37306cbfb0b811f4fe1d567
--- /dev/null
+++ b/theia/preprocessing/feature_extraction_core/webdataset_utils.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import os
+import tarfile
+from io import BytesIO
+
+import cv2
+import numpy as np
+from numpy.typing import NDArray
+
+
+def check_existing_shard(path: str, keys: list[str]) -> tuple[int, dict]:
+    """Check the integrity of a shard given its path.
+
+    Args:
+        path (str): path to the shard (tar file).
+        keys (list[str]): keys (columns) expected in the shard.
+
+    Returns:
+        tuple[int, dict]:
+            code (int): 1 if the shard exists and is readable, 0 otherwise.
+            count_per_key (dict): if the shard is ok, how many samples it contains per key.
+    """
+    count_per_key = {k: 0 for k in keys}
+    if os.path.exists(path):
+        try:
+            with tarfile.open(path, "r") as tarf:
+                tar_members = tarf.getmembers()
+                tar_members = sorted(tar_members, key=lambda x: x.name)
+                for tar_mem in tar_members:
+                    for k in keys:
+                        if k in tar_mem.name:
+                            count_per_key[k] += 1
+            return 1, count_per_key
+        except tarfile.TarError:
+            return 0, count_per_key
+    else:
+        return 0, count_per_key
+
+
+def read_shard(path: str) -> dict[str, bytes]:
+    """Read a (possibly half-processed) tar shard and return its file contents as bytes.
+
+    The tar must be complete to be readable.
+
+    Args:
+        path (str): path to the tar file.
+
+    Returns:
+        dict[str, bytes]: tarfile content in a dictionary where key is the tarinfo.name of each member
+    """
+    samples = {}
+    with tarfile.open(path, "r") as tarf:
+        tar_members = tarf.getmembers()
+        tar_members = sorted(tar_members, key=lambda x: x.name)
+        for tar_mem in tar_members:
+            f = tarf.extractfile(tar_mem.name)
+            if f:
+                samples[tar_mem.name] = f.read()
+    return samples
+
+
+def decode_image_npy_only(key: str, data: bytes) -> NDArray | bytes:
+    if "image" in key:
+        image = np.load(BytesIO(data))
+        if len(image.shape) == 2:
+            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        elif len(image.shape) == 3 and image.shape[-1] == 4:
+            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+        else:
+            return image
+    else:
+        return data
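A minimal sketch of checking a shard before (re)writing it (the path is a placeholder):

from theia.preprocessing.feature_extraction_core import check_existing_shard

ok, counts = check_existing_shard("/path/to/shards/imagenet_train-000000-train.tar", keys=["image"])
if ok:
    print("shard is complete:", counts)
else:
    print("shard is missing or corrupted and needs to be regenerated")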
diff --git a/theia/scripts/decoding/decoding_example.py b/theia/scripts/decoding/decoding_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..da3a5313782472b38dec17f64f035772863232cc
--- /dev/null
+++ b/theia/scripts/decoding/decoding_example.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+"""
+Example script to decode features from theia model to corresponding visual task output,
+    including DINOv2 visualization, SAM segmentation masks, and Depth Anything predicted depths.
+"""
+
+import argparse
+import os
+
+import cv2
+import numpy as np
+import torch
+import transformers
+
+from PIL import Image
+from theia.foundation_models.common import get_model_feature_size
+from theia.decoding import decode_everything, load_feature_stats, prepare_depth_decoder, prepare_mask_generator
+from theia.models.rvfm import RobotVisionFM
+from theia.utils.seed import seed_everything
+from torchvision.io import read_video, write_video
+
+transformers.logging.set_verbosity_error()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--backbone", type=str, default="facebook/deit-tiny-patch16-224", help="name of the backbone")
+    parser.add_argument("--checkpoint-path", type=str, help="path to the model weights")
+    parser.add_argument("--feature-stat-dir", type=str, help="the directory to find feature stats")
+    parser.add_argument("--media-to-vis-path", type=str, help="the location of source video / image for visualization")
+    parser.add_argument(
+        "--vis-output-dir", type=str, default="./vis_output/", help="output dir to save visualization result"
+    )
+    args = parser.parse_args()
+    seed_everything(0)
+    device = 0
+
+    target_model_names = [
+        "google/vit-huge-patch14-224-in21k",
+        "facebook/dinov2-large",
+        "openai/clip-vit-large-patch14",
+        "facebook/sam-vit-huge",
+        "LiheYoung/depth-anything-large-hf",
+    ]
+    target_feature_sizes = {t: get_model_feature_size(t, keep_spatial=True) for t in target_model_names}
+    theia_model = RobotVisionFM(
+        translator="lconv", target_feature_sizes=target_feature_sizes, backbone=args.backbone, pretrained=False
+    )
+
+    theia_model.load_pretrained_weights(args.checkpoint_path)
+    theia_model = theia_model.to(device)
+    feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root=args.feature_stat_dir)
+
+    mask_generator, sam_model = prepare_mask_generator(device)
+    depth_anything_model_name = "LiheYoung/depth-anything-large-hf"
+    depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device)
+
+    if args.media_to_vis_path.lower().endswith((".mp4")):
+        video, _, _ = read_video(args.media_to_vis_path, pts_unit="sec", output_format="THWC")
+        video = video.numpy()
+        images = [Image.fromarray(cv2.resize(im, (224, 224))) for im in video]
+    elif args.media_to_vis_path.lower().endswith((".jpg", ".png", ".jpeg", ".bmp")):
+        images = [Image.open(args.media_to_vis_path).resize((224, 224))]
+
+    theia_decode_results, gt_decode_results = decode_everything(
+        theia_model=theia_model,
+        feature_means=feature_means,
+        feature_vars=feature_vars,
+        images=images,
+        mask_generator=mask_generator,
+        sam_model=sam_model,
+        depth_anything_decoder=depth_anything_decoder,
+        pred_iou_thresh=0.5,
+        stability_score_thresh=0.7,
+        gt=True,
+        device=device,
+    )
+
+
+    if not os.path.exists(args.vis_output_dir):
+        os.makedirs(args.vis_output_dir)
+    if len(images) > 1:
+        vis_output_save_fn = (
+            f"{args.media_to_vis_path.split('/')[-1].split('.')[0]}_{args.checkpoint_path.split('/')[-1].replace('.pth', '')}.mp4"
+        )
+        vis_video = np.stack(
+            [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)]
+        )
+        vis_video = torch.from_numpy(vis_video * 255.0).to(torch.uint8)
+        
+        vis_save_path = os.path.join(args.vis_output_dir, vis_output_save_fn)
+        write_video(vis_save_path, vis_video, fps=10)
+    else:
+        vis_output_save_fn = (
+            f"{args.media_to_vis_path.split('/')[-1].split('.')[0]}_{args.checkpoint_path.split('/')[-1].replace('.pth', '')}.png"
+        )
+        vis_image = np.stack(
+            [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)]
+        )
+        vis_image = Image.fromarray((vis_image * 255.0).astype(np.uint8)[0])
+        vis_save_path = os.path.join(args.vis_output_dir, vis_output_save_fn)
+        vis_image.save(vis_save_path)
+
+
+if __name__ == "__main__":
+    main()
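For reference, a hypothetical invocation of this script (all paths are placeholders):

    python theia/scripts/decoding/decoding_example.py \
        --backbone facebook/deit-tiny-patch16-224 \
        --checkpoint-path /path/to/theia_checkpoint.pth \
        --feature-stat-dir /path/to/feature_stats \
        --media-to-vis-path /path/to/input_video.mp4 \
        --vis-output-dir ./vis_output/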
diff --git a/theia/scripts/preprocessing/calc_feature_mean.py b/theia/scripts/preprocessing/calc_feature_mean.py
new file mode 100644
index 0000000000000000000000000000000000000000..f27bb8a4c495138b74df1f8641c89f172643ab15
--- /dev/null
+++ b/theia/scripts/preprocessing/calc_feature_mean.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+"""
+Calculate the channel-wise mean and variance of extracted features on the ImageNet dataset.
+The resulting statistics are used during the distillation process.
+"""
+
+import argparse
+import glob
+import os
+from io import BytesIO
+
+import numpy as np
+import torch
+import webdataset as wds
+from einops import rearrange
+from safetensors.torch import load as sft_load
+from torch.utils.data import default_collate
+
+
+def decode_dataset_sample(key: str, data: bytes) -> bytes | torch.Tensor:
+    """
+    Decode a feature / column of a webdataset sample from bytes to its original format.
+
+    Args:
+        key (str): name of the feature / column.
+        data (bytes): data in bytes.
+
+    Returns:
+        bytes | torch.Tensor: decoded feature.
+    """
+    if ".safetensors" in key:
+        sft = sft_load(data)
+        return rearrange(sft["embedding"], "c h w -> (h w) c")
+    elif key == ".image":
+        return torch.from_numpy(np.load(BytesIO(data)))
+    else:
+        return data
+
+
+def main() -> None:
+    """Entry point of this script for calculating mean and var."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset-path", type=str)
+    parser.add_argument("--output-path", type=str)
+    args = parser.parse_args()
+
+    all_datasets = {}
+    all_datasets.update({"imagenet": {"steps": 1_281_167}})
+    ds_dir = args.dataset_path
+    models = [m for m in os.listdir(ds_dir) if os.path.isdir(os.path.join(ds_dir, m))]
+    for model in models:
+        print(model)
+        if model == "images" or model == "image" or model == "images_val":
+            continue
+        if os.path.exists(f"{args.output_path}/imagenet_mean_{model}.npy"):
+            continue
+        model_mean: torch.Tensor | None = None
+        model_var_sum: torch.Tensor | None = None
+        n = 0
+        ds = (
+            wds.WebDataset(
+                sorted(glob.glob(f"{ds_dir}/{model}/*.tar")),
+                shardshuffle=False,
+            )
+            .decode(decode_dataset_sample)
+            .batched(256, collation_fn=default_collate)
+        )
+
+        key = f"{model}.safetensors".lower()
+        for batch_idx, batch in enumerate(ds):
+            if model_mean is None:
+                model_mean = torch.zeros((batch[key].size(-1)))
+            new_n = np.prod(batch[key].size()[:2])
+            batch_mean = batch[key].float().mean((0, 1))
+            model_mean = (model_mean * n + batch_mean * new_n) / (n + new_n)
+            n += new_n
+            print(f"calc {model} mean {batch_idx*256:07d}\r", end="")
+
+        model_mean_npy = model_mean.numpy()
+        np.save(f"{args.output_path}/imagenet_mean_{model}.npy", model_mean_npy)
+
+        # var
+        for i, b in enumerate(ds):
+            if model_var_sum is None:
+                model_var_sum = torch.zeros((b[key].size(-1)))
+            model_var_sum += ((b[key].float() - model_mean) ** 2).sum((0, 1))
+            print(f"calc {model} var {i*256:07d}\r", end="")
+
+        model_var = torch.sqrt(model_var_sum / (n - 1))
+        np.save(f"{args.output_path}/imagenet_var_{model}.npy", model_var.numpy())
+
+
+if __name__ == "__main__":
+    main()
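A sketch of how the saved statistics could be consumed downstream to normalize a feature (the path and model tag are placeholders; note that, as computed above, the imagenet_var_* file actually stores the square root of the variance, i.e. the standard deviation):

import numpy as np
import torch

mean = torch.from_numpy(np.load("/path/to/imagenet_mean_facebook_dinov2-large.npy"))  # [C]
std = torch.from_numpy(np.load("/path/to/imagenet_var_facebook_dinov2-large.npy"))    # [C], see note above
feature = torch.randn(4, 196, mean.shape[0])  # dummy [B, H*W, C] feature
normalized = (feature - mean) / (std + 1e-6)  # channel-wise normalization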
diff --git a/theia/scripts/preprocessing/check_feature.py b/theia/scripts/preprocessing/check_feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..f19403eee16668e2dca8cee8c1c5ba46c9bd36ca
--- /dev/null
+++ b/theia/scripts/preprocessing/check_feature.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import argparse
+import json
+import os
+import tarfile
+from io import BytesIO
+from typing import Any
+
+import cv2
+import numpy as np
+import torch
+from numpy.typing import NDArray
+from PIL import Image
+from safetensors.torch import load as sft_load
+
+from theia.dataset import ALL_IMAGE_DATASETS, ALL_VIDEO_DATASETS
+from theia.foundation_models.common import MODELS
+from theia.preprocessing.feature_extraction_core import (
+    get_feature_outputs,
+    get_model,
+)
+from theia.utils.seed import seed_everything
+
+
+def decode_oxe_sample(data: bytes, data_type: str) -> Any:
+    """Decode the sample from bytes.
+
+    Args:
+        data (bytes): data to be decoded.
+        data_type (str): the type of the data.
+            Usually is part of the key (filename of the sample) in the webdataset.
+
+    Returns:
+        Any: decoded data, or the raw bytes passed through untouched.
+    """
+    if ".safetensors" in data_type:
+        sftensor = sft_load(data)
+        return sftensor["embedding"]
+    elif data_type == ".image":
+        image = np.load(BytesIO(data))
+        if len(image.shape) == 2:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        elif len(image.shape) == 3 and image.shape[-1] == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+        # return torch.from_numpy(image)
+        return image
+    else:
+        return data
+
+
+def get_tar_sample(tarf: tarfile.TarFile, sample_index: int) -> bytes:
+    """Get bytes of a sample with index `sample_index` in tarfile `tarf`.
+
+    Args:
+        tarf (tarfile.TarFile): tar file.
+        sample_index (int): index of the sample
+
+    Returns:
+        bytes: content of the sample in bytes
+    """
+    tar_members = tarf.getmembers()
+    tar_members = sorted(tar_members, key=lambda x: x.name)
+    tar_mem = tar_members[sample_index]
+    f = tarf.extractfile(tar_mem.name)
+    if f:
+        return f.read()
+    else:
+        raise IOError(f"failed to read tarfile {tarf}.")
+
+
+def get_tar_sample_name(tarf: tarfile.TarFile, sample_index: int) -> str:
+    """Get the name of the sample with index `sample_index` in the tarfile `tarf`.
+
+    Args:
+        tarf (tarfile.TarFile): tar file.
+        sample_index (int): index of the sample
+
+    Returns:
+        str: name of the file
+    """
+    tar_members = tarf.getmembers()
+    tar_members = sorted(tar_members, key=lambda x: x.name)
+    tar_mem = tar_members[sample_index]
+    return tar_mem.name
+
+
+def check_feature(
+    args: argparse.Namespace,
+    dataset: str,
+    modelnames_to_check: list[str],
+    models: dict[str, Any],
+    processors: dict[str, Any],
+    shard_idx: int,
+    sample_indices: list[int] | NDArray,
+    split: str = "train",
+    dtype: torch.dtype = torch.bfloat16,
+) -> dict[str, bool]:
+    """Check feature consistency given a dataset, names of models to check,
+        shard index and sample indices within that shard.
+
+    Args:
+        args (argparse.Namespace): arguments.
+        dataset (str): name of the dataset
+        modelnames_to_check (list[str]): names of the features (models) to check.
+        models (dict[str, Any]): original models to produce features on the fly.
+        processors (dict[str, Any]): original processor of the models.
+        shard_idx (int): index of the shard.
+        sample_indices (list[int] | NDArray): indices of samples to be checked.
+        split (str, optional): name of the split of the dataset. Defaults to "train".
+        dtype (torch.dtype, optional): dtype of the generated feature. Defaults to torch.bfloat16.
+
+    Returns:
+        dict[str, bool]: check result. The keys are model names. True means passing the check.
+    """
+    data_dir = os.path.join(args.dataset_root, dataset, "images")
+    shard_filenames = sorted([filename for filename in os.listdir(data_dir) if f"{split}.tar" in filename])
+    image_tar = tarfile.open(os.path.join(data_dir, shard_filenames[shard_idx]), "r")
+    images = [
+        decode_oxe_sample(get_tar_sample(image_tar, sample_index), data_type=".image")
+        for sample_index in sample_indices
+    ]
+    for image, sample_index in zip(images, sample_indices, strict=False):
+        if args.save_image:
+            if not os.path.exists(args.image_save_dir):
+                os.makedirs(args.image_save_dir)
+            image = Image.fromarray(image)
+            image.save(os.path.join(args.image_save_dir, f"image_{shard_idx}_{sample_index}.jpg"))
+    image_names = [get_tar_sample_name(image_tar, sample_index).split(".")[0] for sample_index in sample_indices]
+
+    model_check_pass = {m: False for m in modelnames_to_check}
+    for model_name in modelnames_to_check:
+        legit_model = model_name.replace("/", "_")
+        data_dir = os.path.join(args.dataset_root, dataset, legit_model)
+        shard_filenames = sorted([filename for filename in os.listdir(data_dir) if f"{split}.tar" in filename])
+        feature_tar = tarfile.open(os.path.join(data_dir, shard_filenames[shard_idx]), "r")
+        features = torch.stack(
+            [
+                decode_oxe_sample(get_tar_sample(feature_tar, sample_index), data_type=".safetensors")
+                for sample_index in sample_indices
+            ]
+        )
+        gt_features = get_feature_outputs(
+            legit_model, models[legit_model], processors[legit_model], images, dtype=dtype
+        )[legit_model]["embedding"]
+        print(torch.sum(torch.abs(features - gt_features)), torch.max(torch.abs(features - gt_features)))
+        model_check_pass[model_name] = torch.all((features - gt_features) == 0)
+        if args.check_feature_name:
+            names = [get_tar_sample_name(feature_tar, sample_index).split(".")[0] for sample_index in sample_indices]
+            model_check_pass[model_name] = (
+                all([imname == filename for imname, filename in zip(image_names, names, strict=False)])
+                and model_check_pass[model_name]
+            )
+    return model_check_pass
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset-root", type=str)
+    parser.add_argument("--dataset", type=str)
+    parser.add_argument("--split", type=str, default="val")
+    parser.add_argument("--samples-per-shard", type=int, default=1000, help="number of samples per webdataset shard.")
+    parser.add_argument("--check-feature-name", action="store_true")
+    parser.add_argument("--save-image", action="store_true")
+    parser.add_argument("--image-save-dir", type=str, default="./tmp")
+    parser.add_argument("--seed", type=int, default=0)
+    args = parser.parse_args()
+
+    seed_everything(args.seed)
+
+    all_datasets = {}
+    all_datasets.update(ALL_IMAGE_DATASETS)
+    all_datasets.update(ALL_VIDEO_DATASETS)
+
+    with open(os.path.join(args.dataset_root, args.dataset, "splits.json"), "r") as f:
+        dataset_len = json.load(f)[args.split]
+
+    n_shards = dataset_len // args.samples_per_shard
+
+    model_names = [model_name for model_name in MODELS if "llava" not in model_name]
+    models, processors = {}, {}
+    for model_name in model_names:
+        legit_model_name = model_name.replace("/", "_")
+        model, processor = get_model(model_name, device=0)
+        models[legit_model_name] = model
+        processors[legit_model_name] = processor
+
+    shard_indices = np.random.permutation(n_shards)[:5]
+    print(f"randomly check {args.dataset} shards {shard_indices}")
+    model_check_pass: dict[str, list[bool]] = {model_name: [] for model_name in model_names}
+    for shard_idx in shard_indices:
+        sample_indices = np.random.permutation(1000)[:8]
+        print(f"randomly check {args.dataset} shard {shard_idx} sample_indices {sample_indices}")
+        check_result = check_feature(
+            args, args.dataset, model_names, models, processors, shard_idx, sample_indices, split=args.split
+        )
+        for model_name in model_check_pass:
+            model_check_pass[model_name].append(check_result[model_name])
+    for model_name in model_check_pass:
+        if not all(model_check_pass[model_name]):
+            print(f"{args.dataset} {args.split} {model_name} check failed!!!")
+
+
+if __name__ == "__main__":
+    main()
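For reference, a hypothetical invocation of this consistency check (paths are placeholders):

    python theia/scripts/preprocessing/check_feature.py \
        --dataset-root /path/to/processed_datasets \
        --dataset imagenet \
        --split val \
        --check-feature-name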
diff --git a/theia/scripts/preprocessing/feature_extraction.py b/theia/scripts/preprocessing/feature_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c671009d78107d70c79eff46ad8cb314671f722
--- /dev/null
+++ b/theia/scripts/preprocessing/feature_extraction.py
@@ -0,0 +1,401 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import argparse
+import gc
+import glob
+import json
+import math
+import multiprocessing
+import os
+from io import BytesIO
+from os.path import join
+from typing import Any, Generator, Iterable, Optional
+
+import cv2
+import numpy as np
+import torch
+import webdataset as wds
+from numpy.typing import NDArray
+from safetensors.torch import save as safe_torch_save
+
+try:
+    import tensorflow_datasets as tfds
+    from tensorflow.python.ops.numpy_ops import np_config
+except ImportError as e:
+    print(e)
+    print("TensorFlow is not usable. This is fine if you are not processing OXE datasets.")
+
+from theia.dataset import ALL_IMAGE_DATASETS, ALL_OXE_DATASETS, ALL_VIDEO_DATASETS
+from theia.dataset.oxe.oxe_common import oxe_dsname2path
+from theia.preprocessing.feature_extraction_core import (
+    check_existing_shard,
+    decode_image_npy_only,
+    get_feature_outputs,
+    get_model,
+)
+from torch.utils.data import IterableDataset
+
+
+def get_dataset(dataset_name: str, split: str, dataset_root: Optional[str] = None) -> tuple[Iterable, list[str]]:
+    """Get the dataset and its subset keys (if has) given a dataset name.
+
+    Args:
+        dataset_name (str): name of the dataset.
+        split (str): split of the dataset.
+        dataset_root (Optional[str]): root dir of the dataset, if the dataset is stored locally.
+            Defaults to None (remote dataset).
+
+    Returns:
+        tuple[Iterable, list[str]]: dataset and its subset keys
+    """
+    if dataset_name in ALL_OXE_DATASETS:
+        builder = tfds.builder_from_directory(builder_dir=oxe_dsname2path(dataset_name))
+        split = f"{split}[0:]"  # keep the full range; do not change this to skip samples
+        dataset = builder.as_dataset(split=split)
+        visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
+        return dataset, visual_observation_keys
+    elif dataset_name in ALL_VIDEO_DATASETS or dataset_name in ALL_IMAGE_DATASETS:
+        if dataset_root is None:
+            raise ValueError("`dataset_root` is not given.")
+        dataset_dir = os.path.join(dataset_root, dataset_name, "images")
+        if not os.path.exists(dataset_dir) or not os.path.isdir(dataset_dir):
+            raise ValueError(f"{dataset_dir} is not found or is not a directory.")
+        print("dataset shards", sorted(glob.glob(f"{dataset_dir}/*-{split}.tar")))
+        dataset = wds.WebDataset(
+            sorted(glob.glob(f"{dataset_dir}/*-{split}.tar")),
+            shardshuffle=False,
+        ).decode(decode_image_npy_only)
+        return dataset, ["__self__"]
+    else:
+        raise NotImplementedError(f"{dataset_name} is not available")
+
+
+def get_episode(ds: Any) -> Generator[tuple[Any, int], Any, Any]:
+    """Get an episode / a trajectory / a segment from the dataset.
+
+    Args:
+        ds (Any): oxe dataset in tfds format or image/video dataset in webdataset format.
+
+    Yields:
+        Generator[tuple[Any, int], Any, Any]: a trajectory with its length.
+    """
+    if isinstance(ds, IterableDataset):
+        it = iter(ds)
+        while True:
+            sample_buff = []
+            try:
+                for _ in range(1000):
+                    sample = next(it)
+                    sample_buff.append(sample)
+                yield sample_buff, len(sample_buff)
+            except StopIteration:
+                yield sample_buff, len(sample_buff)
+                break
+    else:
+        for ep in ds:
+            yield ep, len(ep["steps"])
+
+
+def get_images(ep: Any, subset: str) -> tuple[list[NDArray], Optional[list[str]]]:
+    """Get images from an episode / a trajectory.
+
+    Args:
+        ep (Any): an episode / a trajectory.
+        subset (str): subset name.
+
+    Returns:
+        tuple[list[NDArray], Optional[list[str]]]: extracted images with optional info.
+    """
+    if isinstance(ep, list):  # for image / video dataset, no subsets
+        return [step["image"] for step in ep], [step["__key__"] for step in ep]
+    else:  # for oxe dataset, subset means multiple camera views
+        images: list[NDArray] = []
+        for step in ep["steps"]:
+            image = cv2.resize(step["observation"][subset].numpy(), (224, 224))
+            images.append(image)
+        return images, None
+
+
+def get_shard_dir(root: str, subset: str, key: str) -> str:
+    """Get the directory to hold shards.
+
+    Args:
+        root (str): root directory.
+        subset (str): subset name.
+        key (str): key (column) name of the processed dataset. Usually it is the name of the feature / input.
+
+    Returns:
+        str: directory to hold the shards.
+    """
+    if subset == "__self__":
+        return os.path.join(root, key)
+    else:
+        return os.path.join(root, subset, key)
+
+
+def get_shard_filename(dataset_name: str, subset: str, split: str, shard_idx: int) -> str:
+    """Get file name of the shard.
+
+    Args:
+        dataset_name (str): name of the dataset.
+        subset (str): name of the subset.
+        split (str): name of the split.
+        shard_idx (int): index of this shard.
+
+    Returns:
+        str: shard file name.
+    """
+    if dataset_name in ALL_OXE_DATASETS:
+        if subset == "__self__":
+            return f"{dataset_name}_{split}-{shard_idx:06d}.tar"
+        else:
+            return f"{dataset_name}_{subset}_{split}-{shard_idx:06d}.tar"
+    else:
+        if subset == "__self__":
+            return f"{dataset_name}_{split}-{shard_idx:06d}-{split}.tar"
+        else:
+            return f"{dataset_name}_{subset}_{split}-{shard_idx:06d}-{split}.tar"
+
+
+def feature_extractor(
+    args: argparse.Namespace,
+    shard_queue: multiprocessing.Queue,
+    worker_id: int,
+    dataset_len: int = 0,
+) -> None:
+    """Feature extractor, operating on each `worker_id`.
+
+    Args:
+        args (argparse.Namespace): configurations.
+        shard_queue (multiprocessing.Queue): queue to get shard index to work on.
+        worker_id (int): id of this worker.
+        dataset_len (int): length of the entire dataset to be processed.
+    """
+    if args.model != "image":
+        model, processor = get_model(args.model, device=worker_id)
+    else:
+        model, processor = None, None
+    dataset, subsets = get_dataset(args.dataset, args.split, args.dataset_root)
+    dataset_output_root = join(args.output_path, args.dataset)
+
+    cum_traj_len, traj_index = 0, 0
+    shard_idx = shard_queue.get()
+    data_iter = get_episode(dataset)
+    episode, traj_len = next(data_iter)
+    remain_traj_len = traj_len
+    while shard_idx is not None:
+        print(f"{args.dataset} {args.model} shard {shard_idx:04d} worker {worker_id} " f"Subsets: {subsets}")
+        # navigate (stream) the dataset to the correct trajectory
+        while (cum_traj_len + remain_traj_len) <= shard_idx * args.samples_per_shard:
+            cum_traj_len += remain_traj_len
+            try:
+                episode, traj_len = next(data_iter)
+                remain_traj_len = traj_len
+                traj_index += 1
+            except StopIteration:
+                break
+
+        # check shard
+        model_names_legit = args.model.replace("/", "_")
+        shard_keys = [model_names_legit]
+        subset_check_codes = {subset: {k: 0 for k in shard_keys} for subset in subsets}
+
+        for subset in subsets:
+            for k in shard_keys:
+                shard_dir = get_shard_dir(dataset_output_root, subset, k)
+                shard_filename = get_shard_filename(args.dataset, subset, args.split, shard_idx)
+                shard_path = os.path.join(shard_dir, shard_filename)
+                shard_check_code, _ = check_existing_shard(shard_path, shard_keys)
+                subset_check_codes[subset][k] = shard_check_code
+
+        # generate data to the shard buffers
+        subset_shard_buffers: dict[str, dict[str, list[dict[str, str | bytes]]]] = {
+            subset: {k: [] for k in shard_keys} for subset in subsets
+        }
+        while cum_traj_len < min((shard_idx + 1) * args.samples_per_shard, dataset_len):
+            for subset in subsets:
+                images, info = None, None
+
+                start_frame_index = traj_len - remain_traj_len
+                if start_frame_index >= traj_len:
+                    raise ValueError("start frame index calculation error; more trajectories are needed")
+                # end of the trajectory
+                end_frame_index = min((shard_idx + 1) * args.samples_per_shard - cum_traj_len, traj_len)
+
+                # generate shard data per key, including images and model features
+                # skip any indices that are completed
+                for k in shard_keys:
+                    if subset_check_codes[subset][k] == 1:
+                        print(f"{args.dataset} {subset} {k} shard {shard_idx:04d} check pass")
+                        continue
+                    if k == "image":
+                        if images is None:
+                            # read all the images in the trajectory
+                            images, info = get_images(episode, subset)
+                        for frame_index in range(start_frame_index, end_frame_index):
+                            if args.dataset in ALL_OXE_DATASETS:
+                                basename = (
+                                    f"{args.dataset}"
+                                    f"{'' if subset=='__self__' else '_'+subset}_seq{traj_index:06d}_{frame_index:06d}"
+                                )
+                            else:
+                                basename = info[frame_index] if info else ""
+                            if not args.dry_run:
+                                image_out = BytesIO()
+                                np.save(image_out, images[frame_index])
+                                subset_shard_buffers[subset][k].append({"__key__": basename, k: image_out.getvalue()})
+                    else:
+                        if images is None:
+                            images, info = get_images(episode, subset)
+                        processed = start_frame_index
+                        # batch processing images
+                        while processed < end_frame_index:
+                            # take a batch
+                            batch_images = images[processed : processed + args.batch_size]
+                            if not args.dry_run:
+                                effective_batch_size = len(batch_images)
+                                features = get_feature_outputs(k, model, processor, batch_images)
+                                for frame_index in range(processed, processed + effective_batch_size):
+                                    if args.dataset in ALL_OXE_DATASETS:
+                                        basename = (
+                                            f"{args.dataset}"
+                                            f"{'' if subset=='__self__' else '_'+subset}"
+                                            f"_seq{traj_index:06d}_{frame_index:06d}"
+                                        )
+                                    else:
+                                        basename = info[frame_index] if info else ""
+                                    tensor_sample_buffer = {}
+                                    for feature_key in features[k]:
+                                        tensor_sample_buffer[feature_key] = features[k][feature_key][
+                                            frame_index - processed
+                                        ]
+                                    subset_shard_buffers[subset][k].append(
+                                        {"__key__": basename, f"{k}.safetensors": safe_torch_save(tensor_sample_buffer)}
+                                    )
+
+                            # next batch
+                            processed += args.batch_size
+
+            cum_traj_len += (
+                end_frame_index - start_frame_index
+            )  # only increase processed traj len by the actual number of frames processed
+            remain_traj_len -= end_frame_index - start_frame_index
+            print(f"{args.dataset} {args.model} shard {shard_idx:04d} traj {traj_index:06d} remains {remain_traj_len}")
+            # if the trajectory is exhausted, get the next one
+            if remain_traj_len == 0:
+                try:
+                    episode, traj_len = next(data_iter)
+                    remain_traj_len = traj_len
+                    traj_index += 1
+                except StopIteration:
+                    break
+
+        # shard buffers are complete; write out the shards
+        if not args.dry_run:
+            for subset in subsets:
+                for k in shard_keys:
+                    if subset_check_codes[subset][k] == 1:
+                        continue
+                    shard_dir = get_shard_dir(dataset_output_root, subset, k)
+                    shard_filename = get_shard_filename(args.dataset, subset, args.split, shard_idx)
+                    shard_path = os.path.join(shard_dir, shard_filename)
+                    if not os.path.exists(shard_dir):
+                        os.makedirs(shard_dir)
+                    print(f"shard {shard_idx:04d} {subset} {k}: {len(subset_shard_buffers[subset][k])} samples")
+                    with wds.TarWriter(shard_path) as tar_writer:
+                        for sample in subset_shard_buffers[subset][k]:
+                            tar_writer.write(sample)
+
+        print(f"{args.dataset} {args.model} shard {shard_idx:04d} done")
+        del subset_shard_buffers
+        gc.collect()
+        # get a new shard to process
+        shard_idx = shard_queue.get()
+
+
+def main() -> None:
+    """Main entry of feature extraction"""
+    np_config.enable_numpy_behavior()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str)
+    parser.add_argument("--dataset-root", type=str)
+    parser.add_argument("--output-path", type=str)
+    parser.add_argument("--model", type=str)
+    parser.add_argument("--split", default="train")
+    parser.add_argument("--start", type=int, default=0, help="start index (form 0) of **steps** to process")
+    parser.add_argument(
+        "--num-to-process",
+        type=int,
+        default=-1,
+        help="number of **steps** to process based on start. -1 means all remaining from the start.",
+    )
+    parser.add_argument("--batch-size", type=int, default=8, help="batch size for the model forward pass")
+    parser.add_argument("--force", action="store_true", help="force overwrite existing feature files.")
+    parser.add_argument("--dry-run", action="store_true", help="do not do model forward pass and write out.")
+    parser.add_argument(
+        "--samples-per-shard", type=int, default=1000, help="number of samples per webdataset shard. Rarely changed."
+    )
+    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus to parallel")
+    args = parser.parse_args()
+
+    if torch.cuda.is_available():
+        args.num_gpus = min(args.num_gpus, torch.cuda.device_count())
+    else:
+        args.num_gpus = 0
+
+    # make directories
+    dataset_output_root = os.path.join(args.output_path, args.dataset)
+    if not os.path.exists(dataset_output_root):
+        os.makedirs(dataset_output_root)
+
+    # align the start index to the beginning of a shard
+    start_fi = args.start // args.samples_per_shard * args.samples_per_shard
+    start_shard_idx = start_fi // args.samples_per_shard
+
+    all_datasets = {}
+    all_datasets.update(ALL_OXE_DATASETS)
+    all_datasets.update(ALL_IMAGE_DATASETS)
+    all_datasets.update(ALL_VIDEO_DATASETS)
+    dataset_dir = os.path.join(args.dataset_root, args.dataset)
+
+    if args.dataset in ALL_IMAGE_DATASETS or args.dataset in ALL_VIDEO_DATASETS:
+        with open(os.path.join(dataset_dir, "splits.json"), "r") as f:
+            splits = json.load(f)
+            dataset_len = splits[args.split]
+    else:
+        dataset_len = all_datasets[args.dataset]["steps"]
+
+    # calculate how many shards to create
+    if args.num_to_process > 0:
+        end_sample_index = args.start + args.num_to_process
+    else:
+        end_sample_index = dataset_len
+
+    if end_sample_index % args.samples_per_shard == 0:
+        end_shard_idx = end_sample_index // args.samples_per_shard
+    else:
+        end_shard_idx = math.ceil((end_sample_index) / args.samples_per_shard)
+    shards = list(range(start_shard_idx, end_shard_idx))
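+    # e.g. with samples_per_shard=1000, --start 2500 and --num-to-process 3000 (illustrative numbers),
+    # the start aligns to shard 2 and shards 2 through 5 are queued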
+
+    # create a queue to hold shards
+    shard_queue: multiprocessing.Queue = multiprocessing.Queue()
+    for shard_idx in shards:
+        shard_queue.put(shard_idx)
+    for _ in range(args.num_gpus * 2 + 1):
+        shard_queue.put(None)
+
+    # create workers
+    workers = [
+        multiprocessing.Process(target=feature_extractor, args=(args, shard_queue, worker_id, dataset_len))
+        for worker_id in range(max(args.num_gpus, 1))
+    ]
+
+    for w in workers:
+        w.start()
+    for w in workers:
+        w.join()
+
+
+if __name__ == "__main__":
+    torch.multiprocessing.set_start_method("spawn")
+    main()
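+
+# Example invocation (an illustrative sketch; the dataset name and paths are placeholders, not part of this script):
+#   python feature_extraction.py --dataset imagenet --dataset-root /path/to/datasets \
+#       --output-path /path/to/features --model facebook/dinov2-large --split train --num-gpus 4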
diff --git a/theia/scripts/preprocessing/image_datasets/organize_imagenet_webdataset.py b/theia/scripts/preprocessing/image_datasets/organize_imagenet_webdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3bf47191be0043aacf8811c13017f714cd5a224
--- /dev/null
+++ b/theia/scripts/preprocessing/image_datasets/organize_imagenet_webdataset.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+"""Organize imagefolder-like images (ImageNet) to webdataset format."""
+
+import argparse
+import glob
+import os
+import shutil
+import tarfile
+from io import BytesIO
+
+import numpy as np
+import webdataset as wds
+from numpy.typing import NDArray
+from PIL import Image
+from torchvision.transforms.v2 import Compose, Resize
+
+
+def check_existing_shard(path: str) -> bool:
+    """Check the integrity of the existing webdataset shard.
+
+    Args:
+        path (str): path to the webdataset shard.
+
+    Returns:
+        bool: True for complete shard.
+            False for non-existing or broken shard.
+    """
+    if not os.path.exists(path):
+        return False
+    try:
+        with tarfile.open(path) as tarf:
+            for _ in tarf.getmembers():
+                pass
+    except (ValueError, tarfile.ReadError, tarfile.CompressionError) as e:
+        print(e)
+        return False
+    return True
+
+
+def create_shard(
+    args: argparse.Namespace,
+    shard_idx: int,
+    shard_path: str | None,
+    remote_shard_path: str,
+    frames: list[tuple[NDArray, str]],
+) -> None:
+    """Create a webdataset shard.
+
+    Args:
+        args (argparse.Namespace): arguments.
+        shard_idx (int): index of this shard.
+        shard_path (str): (local) path to save the shard.
+        remote_shard_path (str): final destination (remote) to save the shard.
+        frames (list[tuple[NDArray, str]]): images to save in this shard.
+    """
+    if check_existing_shard(remote_shard_path):
+        print(f"creating {args.dataset} shard {shard_idx:06d} - check pass, skip\r", end="")
+        return
+    print(f"creating {args.dataset} shard {shard_idx:06d}\r", end="")
+    if shard_path is None:
+        shard_path = remote_shard_path
+    with wds.TarWriter(shard_path) as tar_writer:
+        for i, (image, basename) in enumerate(frames):
+            image_out = BytesIO()
+            np.save(image_out, image)
+            sample = {"__key__": basename, "image": image_out.getvalue()}
+            tar_writer.write(sample)
+            if (i + 1) % 20 == 0:
+                print(f"creating {args.dataset} shard {shard_idx:06d} - {(i+1) * 100 // len(frames):02d}%\r", end="")
+    if shard_path != remote_shard_path:
+        shutil.move(shard_path, remote_shard_path)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str)
+    parser.add_argument("--output-path", type=str)
+    parser.add_argument("--imagenet-raw-path", type=str)
+    parser.add_argument("--tmp-shard-path", type=str, default="None")
+    parser.add_argument("--split", type=str, default="train")
+    parser.add_argument("--samples-per-shard", type=int, default=1000)
+    args = parser.parse_args()
+
+    match args.dataset:
+        case "imagenet":
+            IMAGE_DATASET_RAW_DIR = args.imagenet_raw_path
+        case _:
+            raise NotImplementedError(f"{args.dataset} is not supported")
+
+    if args.tmp_shard_path == "None":
+        TMP_SHARD_PATH = None
+    else:
+        TMP_SHARD_PATH = os.path.join(args.tmp_shard_path, args.dataset)
+        if not os.path.exists(TMP_SHARD_PATH):
+            os.makedirs(TMP_SHARD_PATH)
+
+    OUTPUT_SHARD_PATH = os.path.join(args.output_path, args.dataset)
+    if not os.path.exists(OUTPUT_SHARD_PATH):
+        os.makedirs(OUTPUT_SHARD_PATH, exist_ok=True)
+
+    if args.split == "train":
+        image_paths = sorted(glob.glob(f"{IMAGE_DATASET_RAW_DIR}/{args.split}/*/*.JPEG"))
+    else:
+        image_paths = sorted(glob.glob(f"{IMAGE_DATASET_RAW_DIR}/{args.split}/*.JPEG"))
+
+    transform = Compose([Resize((224, 224), antialias=True)])
+
+    shard_idx = 0
+    shard_buffer: list[tuple[NDArray, str]] = []
+    for image_path in image_paths:
+        basename = image_path.split("/")[-1].split(".")[0]
+        image = np.array(transform(Image.open(image_path)))
+        shard_buffer.append((image, basename))
+        if len(shard_buffer) % 20 == 0:
+            print(f"shard {shard_idx: 04d} frames {len(shard_buffer)}\r", end="")
+        if len(shard_buffer) == args.samples_per_shard:
+            shard_fn = f"{args.dataset}_{args.split}-{shard_idx:06d}-{args.split}.tar"
+            local_shard_path = os.path.join(TMP_SHARD_PATH, shard_fn) if TMP_SHARD_PATH else None
+            remote_shard_path = os.path.join(OUTPUT_SHARD_PATH, shard_fn)
+            create_shard(args, shard_idx, local_shard_path, remote_shard_path, shard_buffer)
+            shard_buffer = []
+            shard_idx += 1
+
+    shard_fn = f"{args.dataset}_{args.split}-{shard_idx:06d}-{args.split}.tar"
+    local_shard_path = os.path.join(TMP_SHARD_PATH, shard_fn) if TMP_SHARD_PATH else None
+    remote_shard_path = os.path.join(OUTPUT_SHARD_PATH, shard_fn)
+    if len(shard_buffer) > 0:
+        create_shard(args, shard_idx, local_shard_path, remote_shard_path, shard_buffer)
+
+
+if __name__ == "__main__":
+    main()
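+
+# Example invocation (an illustrative sketch; paths are placeholders):
+#   python organize_imagenet_webdataset.py --dataset imagenet --imagenet-raw-path /path/to/imagenet \
+#       --output-path /path/to/webdatasets --split train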
diff --git a/theia/scripts/preprocessing/iv_feature_extraction.sh b/theia/scripts/preprocessing/iv_feature_extraction.sh
new file mode 100644
index 0000000000000000000000000000000000000000..258cde177b80856c41903e30d3c2980e0d6b4cb4
--- /dev/null
+++ b/theia/scripts/preprocessing/iv_feature_extraction.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+dataset=$1
+numgpus=$2
+
+# modify models below
+models=(facebook/dinov2-large google/vit-huge-patch14-224-in21k openai/clip-vit-large-patch14 LiheYoung/depth-anything-large-hf) # facebook/sam-vit-huge
+for model in ${models[@]}
+do
+    (
+        python feature_extraction.py --dataset $dataset --output-path /storage/nfs/datasets/jshang/ --model $model --split train --num-gpus $numgpus; \
+        python feature_extraction.py --dataset $dataset --output-path /storage/nfs/datasets/jshang/ --model $model --split val --num-gpus $numgpus
+    ) &
+done
+wait
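+
+# Example usage (an illustrative sketch; the dataset name is a placeholder):
+#   bash iv_feature_extraction.sh imagenet 4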
diff --git a/theia/scripts/preprocessing/split_dataset.py b/theia/scripts/preprocessing/split_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c79d429812a16902dc119e649a294002562389
--- /dev/null
+++ b/theia/scripts/preprocessing/split_dataset.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import argparse
+import json
+import math
+import os
+import tarfile
+from collections import OrderedDict
+
+from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS
+from theia.dataset.video import ALL_VIDEO_DATASETS
+
+DATASET_RATIOS = OrderedDict({"train": 0.8, "val": 0.05, "test": 0.15})
+
+all_datasets = {}
+all_datasets.update(ALL_OXE_DATASETS)
+all_datasets.update(ALL_VIDEO_DATASETS)
+# all_datasets.update(ALL_IMAGE_DATASETS)  # imagenet has its own splits and can be handled separately
+
+
+def count_steps(tar_path: str) -> int:
+    """Count how many samples are in the shard
+
+    Args:
+        tar_path (str): path to the shard
+    """
+    with tarfile.open(tar_path) as tarf:
+        return len(list(set([x.name.split(".")[0] for x in tarf.getmembers()])))
+
+
+def do_dataset_split(args: argparse.Namespace, dataset_name: str) -> None:
+    """Split the dataset given a dataset name.
+    The dataset will be split based on shards in the lexical order of their filenames.
+    The first part goes to `training` set, the second part goes to `validation` set,
+        and the last part goes to `test` set.
+
+    Args:
+        args (argparse.Namespace): arguments, including dataset_root, samples_per_shard, and dry_run.
+        dataset_name (str): name of the dataset.
+    """
+    dataset_dir = os.path.join(args.dataset_root, dataset_name)
+    split_json_file = os.path.join(dataset_dir, "splits.json")
+
+    if os.path.exists(split_json_file):
+        return
+
+    # only apply to images
+    # then feature extraction script will handle splits for features
+    shard_dirs = [os.path.join(dataset_dir, "images")]
+    for shard_dir in shard_dirs:
+        shard_names = sorted(
+            [filename for filename in os.listdir(shard_dir) if filename.endswith(".tar") and "-" in filename]
+        )
+        n_shards = len(shard_names)
+        print(f"{dataset_name} total {n_shards} shards")
+
+        cum_n_shards = 0
+        split_steps_count = {}
+        for _, split in enumerate(DATASET_RATIOS):
+            ratio = DATASET_RATIOS[split]
+            split_n_shards = math.ceil(n_shards * ratio)
+            split_steps_count[split] = 0
+            print(f"{dataset_name} {split} {split_n_shards} shards")
+
+            for shard_idx in range(cum_n_shards, min(cum_n_shards + split_n_shards, n_shards)):
+                original_path = os.path.join(shard_dir, shard_names[shard_idx])
+                if shard_idx == n_shards - 1:
+                    split_steps_count[split] += count_steps(original_path)
+                else:
+                    split_steps_count[split] += args.samples_per_shard
+                split_shard_filename = shard_names[shard_idx].replace(".tar", f"-{split}.tar")
+                split_shard_path = os.path.join(shard_dir, split_shard_filename)
+
+                if not args.dry_run:
+                    os.rename(original_path, split_shard_path)
+            cum_n_shards += split_n_shards
+
+        with open(os.path.join(dataset_dir, "splits.json"), "w") as f:
+            json.dump(split_steps_count, f, indent=4)
+        print(split_steps_count)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset-root", type=str)
+    parser.add_argument("--dry-run", action="store_true")
+    parser.add_argument(
+        "--samples-per-shard",
+        type=int,
+        default=1000,
+        help="Number of samples per webdataset shard. Rarely changed. Replace with your actual setting.",
+    )
+    args = parser.parse_args()
+    for dataset in all_datasets:
+        if dataset in ALL_OXE_DATASETS:
+            if "_sim" in dataset:
+                continue
+            if "uiuc_d3field" in dataset or "cmu_playing_with_food" in dataset or "robot_vqa" in dataset:
+                continue
+        do_dataset_split(args, dataset)
+
+
+if __name__ == "__main__":
+    main()
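+
+# Example invocation (an illustrative sketch; the path is a placeholder):
+#   python split_dataset.py --dataset-root /path/to/datasets --dry-run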
diff --git a/theia/scripts/preprocessing/video_datasets/subsampling_videos.py b/theia/scripts/preprocessing/video_datasets/subsampling_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..eda7d4d250fc52970bf7c7d6c5026b60170746c1
--- /dev/null
+++ b/theia/scripts/preprocessing/video_datasets/subsampling_videos.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import argparse
+import os
+import shutil
+import tarfile
+from io import BytesIO
+
+import numpy as np
+import torch
+import webdataset as wds
+from numpy.typing import NDArray
+from PIL import Image
+from torchvision.io import VideoReader, read_video
+from torchvision.transforms import Compose, Resize, ToPILImage
+
+# torchvision.set_video_backend("video_reader")
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset", type=str)
+parser.add_argument(
+    "--dataset-path",
+    type=str,
+    help="please provide the dataset path directly contains videos (.mp4, .webm) or frames (.tar for epic_kitchen)",
+)
+parser.add_argument("--output-path", type=str, help="will create a subfolder within this output path")
+parser.add_argument("--subsampling-rate", type=int, default=-1)
+parser.add_argument("--samples-per-shard", type=int, default=1000)
+args = parser.parse_args()
+
+
+if args.dataset == "ego4d":
+    # default subsampling rate for ego4d, used when --subsampling-rate is not given
+    SUBSAMPLING_RATE = args.subsampling_rate if args.subsampling_rate > 0 else 150
+    video_ext = ".mp4"
+elif args.dataset == "ssv2":
+    # default subsampling rate for ssv2
+    SUBSAMPLING_RATE = args.subsampling_rate if args.subsampling_rate > 0 else 32
+    video_ext = ".webm"
+elif args.dataset == "epic_kitchen":
+    # default subsampling rate for epic_kitchen
+    SUBSAMPLING_RATE = args.subsampling_rate if args.subsampling_rate > 0 else 60
+    video_ext = ".tar"
+else:
+    raise NotImplementedError(f"{args.dataset} is not supported.")
+
+print(f"subsampling {args.dataset} by 1/{SUBSAMPLING_RATE}")
+
+RAW_VIDEO_PATH = args.dataset_path
+TMP_SAMPLED_FRAMES_PATH = f"/storage/nvme/tmp_video_subsampling/{args.dataset}_1in{SUBSAMPLING_RATE}_images"
+SAMPLED_FRAMES_PATH = os.path.join(args.output_path, f"{args.dataset}_1in{SUBSAMPLING_RATE}_images")
+os.makedirs(SAMPLED_FRAMES_PATH, exist_ok=True)
+os.makedirs(TMP_SAMPLED_FRAMES_PATH, exist_ok=True)
+
+SAMPLES_PER_SHARD = args.samples_per_shard
+
+video_fns = sorted([fn for fn in os.listdir(RAW_VIDEO_PATH) if video_ext in fn])
+
+transform = Compose([Resize((224, 224), antialias=True), ToPILImage()])
+
+
+def check_existing_shard(path: str) -> bool:
+    """
+    Check the integrity of a shard given path.
+
+    Returns:
+        bool: True if the shard exists and is complete.
+    """
+    if os.path.exists(path):
+        try:
+            tarf = tarfile.open(path)
+            for _ in tarf.getmembers():
+                pass
+        except tarfile.TarError:
+            return False
+    else:
+        return False
+    return True
+
+
+def create_shard(shard_idx: int, frames: list[tuple[NDArray, str]]) -> None:
+    """Create a shard given index and frame list.
+
+    Args:
+        shard_idx (int): index of this shard. Used to determine file paths.
+        frames (list[tuple[NDArray, str]]): frames to write to this shard.
+    """
+    shard_fn = f"{args.dataset}_1in{SUBSAMPLING_RATE}-{shard_idx:06d}.tar"
+    local_shard_path = os.path.join(TMP_SAMPLED_FRAMES_PATH, shard_fn)
+    remote_shard_path = os.path.join(SAMPLED_FRAMES_PATH, shard_fn)
+    if check_existing_shard(remote_shard_path):
+        print(f"creating {args.dataset} shard {shard_idx:06d} - check pass, skip\r", end="")
+        return
+    print(f"creating {args.dataset} shard {shard_idx:06d}\r", end="")
+    with wds.TarWriter(local_shard_path) as tar_writer:
+        for i, (image, basename) in enumerate(frames):
+            image_out = BytesIO()
+            np.save(image_out, image)
+            sample = {"__key__": basename, "image": image_out.getvalue()}
+            tar_writer.write(sample)
+            if (i + 1) % 20 == 0:
+                print(
+                    f"creating {args.dataset} shard {shard_idx:06d} - {int((i+1) / len(frames) * 100):02d}%\r", end=""
+                )
+
+    # move from local to remote
+    shutil.move(local_shard_path, remote_shard_path)
+
+
+shard_idx = 0
+shard_buffer: list[tuple[NDArray, str]] = []
+cum_video_len = 0
+for vfn in video_fns:
+    if args.dataset == "ego4d":
+        print(vfn)
+    video_path = os.path.join(RAW_VIDEO_PATH, vfn)
+    if video_ext == ".mp4":  # for ego4d
+        video = VideoReader(video_path, stream="video", num_threads=32)
+        metadata = video.get_metadata()
+        fps = metadata["video"]["fps"][0]
+        duration = metadata["video"]["duration"][0]
+        fi = 0
+        while fi < (duration * fps):
+            frame = next(video.seek(fi / fps))
+            basename = f"{vfn.replace(video_ext, '')}_{fi:06d}"
+            image = np.array(transform(frame["data"]))
+            # print (image.dtype, image.shape)
+            shard_buffer.append((image, basename))
+            if len(shard_buffer) % 20 == 0:
+                print(f"shard {shard_idx: 04d} frames {len(shard_buffer)}\r", end="")
+            if len(shard_buffer) == SAMPLES_PER_SHARD:
+                create_shard(shard_idx, shard_buffer)
+                shard_buffer = []
+                shard_idx += 1
+            fi += SUBSAMPLING_RATE
+
+    elif video_ext == ".webm":  # for ssv2
+        video, _, info = read_video(video_path, output_format="TCHW")
+        video_len = video.size(0)  # for webm, only fps is available; 12 fps for ssv2
+        for fi in range(video_len):
+            if (fi + cum_video_len) % SUBSAMPLING_RATE == 0:
+                frame = video[fi]
+                basename = f"{vfn.replace(video_ext, '')}_{fi:06d}"
+                image = np.array(transform(frame))
+                shard_buffer.append((image, basename))
+                if len(shard_buffer) % 20 == 0:
+                    print(f"shard {shard_idx: 04d} frames {len(shard_buffer)} - file progress {vfn} - {fi}\r", end="")
+                if len(shard_buffer) == SAMPLES_PER_SHARD:
+                    create_shard(shard_idx, shard_buffer)
+                    shard_buffer = []
+                    shard_idx += 1
+        cum_video_len += video_len
+
+    elif video_ext == ".tar":  # for epic_kitchen
+        tar = tarfile.open(video_path)
+        frame_fns = sorted([tinfo.name for tinfo in tar.getmembers() if ".jpg" in tinfo.name])
+        video_len = len(frame_fns)
+        for fi in range(video_len):
+            if (fi + cum_video_len) % SUBSAMPLING_RATE == 0:
+                frame_tarf = tar.extractfile(frame_fns[fi])
+                if frame_tarf:
+                    frame_bytes = frame_tarf.read()
+                else:
+                    continue
+                image = np.array(
+                    transform(torch.from_numpy(np.array(Image.open(BytesIO(frame_bytes)))).permute(-1, 0, 1))
+                )
+                basename = f"{vfn.replace(video_ext, '')}_{fi:06d}"
+                shard_buffer.append((image, basename))
+                if len(shard_buffer) % 20 == 0:
+                    print(f"shard {shard_idx: 04d} frames {len(shard_buffer)} - file progress {vfn} - {fi}\r", end="")
+                if len(shard_buffer) == SAMPLES_PER_SHARD:
+                    create_shard(shard_idx, shard_buffer)
+                    shard_buffer = []
+                    shard_idx += 1
+        cum_video_len += video_len
+
+# create a shard for the final remaining frames
+if len(shard_buffer) > 0:
+    create_shard(shard_idx, shard_buffer)
+    shard_buffer = []
+    shard_idx += 1
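+
+# Example invocation (an illustrative sketch; paths are placeholders):
+#   python subsampling_videos.py --dataset ssv2 --dataset-path /path/to/ssv2/videos \
+#       --output-path /path/to/subsampled --subsampling-rate 32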
diff --git a/theia/scripts/train/sanity_check_train_rvfm.sh b/theia/scripts/train/sanity_check_train_rvfm.sh
new file mode 100755
index 0000000000000000000000000000000000000000..b0aae3572708cdd321a829f431cd453719677032
--- /dev/null
+++ b/theia/scripts/train/sanity_check_train_rvfm.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+torchrun --nproc_per_node=1 --nnodes 1 --rdzv_backend c10d --rdzv_endpoint localhost:0 scripts/train/train_rvfm.py \
+  +logging.note=sanitycheck +dataset.data_portion=0.001
diff --git a/theia/scripts/train/train_rvfm.py b/theia/scripts/train/train_rvfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..176f0a9885f989e16f2262d335c82cd17dbca39b
--- /dev/null
+++ b/theia/scripts/train/train_rvfm.py
@@ -0,0 +1,349 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+"""
+Training script for Theia, also called the robot visual foundation model (RVFM)
+in the code.
+This training script uses Hydra; to change configurations, see theia/configs.
+"""
+
+import math
+import os.path as osp
+import random
+import warnings
+from typing import Any, Callable
+
+import hydra
+import wandb
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.optim.lr_scheduler import LRScheduler
+from torchvision.transforms.v2 import Compose
+from torch.nn.parallel import DistributedDataParallel as DDP
+from tqdm import tqdm
+from omegaconf import DictConfig, OmegaConf
+
+from theia.models.rvfm import RobotVisionFM
+from theia.optimizers.utils import param_groups_weight_decay
+from theia.utils.logging import create_meters, log_metrics
+from theia.utils.seed import seed_everything
+from theia.foundation_models.common import MODEL_FEATURE_SIZES, get_model_feature_size
+from theia.dataset.data_utils import get_frame_dataloader, get_frame_iterator, get_image_video_dataset
+from theia.dataset.oxe.oxe_transforms import totensor
+
+
+warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
+
+
+def train(
+    rvfm: nn.Module,
+    target_model_names: list[str],
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: LRScheduler,
+    train_dataset: Any,
+    eval_dataset: Any,
+    cfg: DictConfig,
+    device: int = 0,
+    train_epoch_steps: int = 0,
+    eval_epoch_steps: int = 0,
+    total_train_steps: int = 0,
+    warmup_steps: int = 0,
+) -> None:
+    """Training and evaluation for robot visual foundation model (rvfm).
+
+    Args:
+        rvfm (nn.Module): model to train.
+        target_model_names (list[str]): list of teacher model names.
+        optimizer (torch.optim.Optimizer): optimizer.
+        lr_scheduler (LRScheduler): learning rate scheduler.
+        train_dataset (Any): train dataset.
+        eval_dataset (Any): eval dataset.
+        cfg (DictConfig): train config
+        device (int, optional): device (of this process). Defaults to 0.
+        train_epoch_steps (int, optional): steps per training epoch. Defaults to 0.
+        eval_epoch_steps (int, optional): steps per eval epoch. Defaults to 0.
+        total_train_steps (int, optional): total training steps. Defaults to 0.
+        warmup_steps (int, optional): warmup steps. Defaults to 0.
+    """
+    epochs = cfg.training.epochs
+    steps = 0
+    # re-create the dataloaders every epoch so that keeping the loaders in sync is straightforward
+    for ep in range(epochs):
+
+        train_loaders = get_frame_dataloader(
+            train_dataset,
+            batch_size=cfg.training.batch_size,
+            pin_memory=True,
+            num_workers=cfg.training.num_workers,
+            shuffle=cfg.dataset.shuffle,
+            shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
+            seed=cfg.seed + device * 100 + ep,  # either cfg.seed or cfg.seed + rank
+        )
+        eval_loaders = get_frame_dataloader(
+            eval_dataset,
+            batch_size=cfg.training.batch_size,
+            pin_memory=True,
+            num_workers=cfg.training.num_workers,
+            shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
+            seed=cfg.seed,  # either cfg.seed or cfg.seed + rank
+        )
+        train_iter = get_frame_iterator(train_loaders)
+
+        metric_meters = create_meters(target_model_names)
+        rvfm.train()
+        train_tqdm = tqdm(range(train_epoch_steps), ncols=80) if device == 0 else range(train_epoch_steps)
+        for _ in train_tqdm:
+            try:
+                batch = next(train_iter)
+            except StopIteration:
+                train_iter = get_frame_iterator(train_loaders)
+                batch = next(train_iter)
+            images_batch = batch["image"].to(device, non_blocking=True)
+            if cfg.training.random_target_models > 0:
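+                # distill a randomly chosen subset (2 here) of the teacher models on this batch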
+                batch_target_model_names = random.sample(target_model_names, 2)
+            else:
+                batch_target_model_names = target_model_names
+
+            target_features_batch = {}
+            for t in batch_target_model_names:
+                base_name = t.replace("_cls", "")
+                cls = "_cls" in t
+                if cls:
+                    target_features_batch[t] = batch[base_name]["cls"].to(device, non_blocking=True).float()
+                else:
+                    target_features_batch[t] = batch[base_name]["embedding"].to(device, non_blocking=True).float()
+
+            pred = rvfm(images_batch)
+            losses = rvfm.module.get_loss(pred, target_features_batch)
+
+            if cfg.training.main_loss == "mse" or cfg.training.main_loss is None:
+                main_loss = losses["mse_loss"]
+            elif cfg.training.main_loss == "cos_l1":
+                main_loss = 0.9 * losses["cos_loss"] + 0.1 * losses["l1_loss"]
+
+            optimizer.zero_grad()
+            main_loss.backward()
+            if cfg.training.grad_clip:
+                nn.utils.clip_grad_norm_(
+                    rvfm.parameters(),
+                    cfg.training.grad_clip_norm_warmup if steps < warmup_steps else cfg.training.grad_clip_norm,
+                )
+            optimizer.step()
+
+            lr_scheduler.step()
+
+            steps += 1
+            batch_size = images_batch.size(0)
+
+            log_metrics(
+                metric_meters,
+                target_model_names=target_model_names,
+                device=device,
+                batch_size=batch_size,
+                mode="train",
+                upload_wandb=True,
+                main_loss=main_loss,
+                **losses,
+            )
+
+            if cfg.training.freeze_translator:
+                if steps == int(cfg.training.freeze_translator_start_steps_ratio * total_train_steps):
+                    rvfm.module.freeze_translator()
+
+            if steps % cfg.logging.save_ckpt_interval == 0 and device == 0:
+                model_save_fn = f"{cfg.logging.run_identifier_prefix}_step{steps:08d}.pth"
+                save_path = osp.join(cfg.logging.model_path, model_save_fn)
+                torch.save(rvfm.module.state_dict(), save_path)
+
+        dist.barrier()
+        rvfm.eval()
+        eval_iter = get_frame_iterator(eval_loaders)
+        eval_tqdm = tqdm(range(eval_epoch_steps), ncols=80) if device == 0 else range(eval_epoch_steps)
+        with torch.no_grad():
+            for _ in eval_tqdm:
+                batch = next(eval_iter)
+                images_batch = batch["image"]
+                target_features_batch = {}
+                for t in target_model_names:
+                    base_name = t.replace("_cls", "")
+                    cls = "_cls" in t
+                    if cls:
+                        target_features_batch[t] = batch[base_name]["cls"].to(device, non_blocking=True).float()
+                    else:
+                        target_features_batch[t] = batch[base_name]["embedding"].to(device, non_blocking=True).float()
+
+                pred = rvfm(images_batch)
+                losses = rvfm.module.get_loss(pred, target_features_batch)
+                if cfg.training.main_loss == "mse" or cfg.training.main_loss is None:
+                    main_loss = losses["mse_loss"]
+                elif cfg.training.main_loss == "cos_l1":
+                    main_loss = 0.9 * losses["cos_loss"] + 0.1 * losses["l1_loss"]
+
+                batch_size = images_batch.size(0)
+                log_metrics(
+                    metric_meters,
+                    target_model_names=target_model_names,
+                    device=device,
+                    batch_size=batch_size,
+                    mode="eval",
+                    upload_wandb=False,
+                    main_loss=main_loss,
+                    **losses,
+                )
+
+        log_metrics(
+            metric_meters,
+            mode="eval",
+            upload_wandb=True,
+            only_upload=True,
+            target_model_names=target_model_names,
+            device=device,
+        )
+
+        if device == 0:
+            model_save_fn = f"{cfg.logging.run_identifier_prefix}_step{steps:08d}.pth"
+            save_path = osp.join(cfg.logging.model_path, model_save_fn)
+            torch.save(rvfm.module.state_dict(), save_path)
+
+        dist.barrier()
+
+
+def ddp_setup() -> None:
+    """Initialize stuff for DDP."""
+    dist.init_process_group("nccl")
+
+
+def ddp_cleanup() -> None:
+    """Clean up stuff for DDP."""
+    dist.destroy_process_group()
+
+
+def ddp_main(cfg: DictConfig) -> None:
+    """Entry point of DDP.
+
+    Args:
+        cfg (DictConfig): settings for training.
+    """
+    ddp_setup()
+    rank, world_size = dist.get_rank(), dist.get_world_size()
+
+    target_model_names = (
+        cfg.training.target_models.target_model_names
+        if len(cfg.training.target_models.target_model_names) > 0
+        else list(MODEL_FEATURE_SIZES.keys())
+    )
+    target_model_names = [t for t in target_model_names if "llava" not in t]  # llava is currently not supported
+    target_feature_sizes = {t: get_model_feature_size(t, keep_spatial=True) for t in target_model_names}
+
+    target_model_names_wocls = target_model_names[:]
+    if hasattr(cfg.training, "distill_cls") and cfg.training.distill_cls == True:
+        target_model_names_copy = target_model_names[:]
+        for t in target_model_names:
+            if "google/vit" in t or "facebook/dino" in t or "openai/clip" in t:
+                target_feature_sizes[t+"_cls"] = get_model_feature_size(t, keep_spatial=True)[:1]
+                target_model_names_copy.append(t+"_cls")
+
+        target_model_names = target_model_names_copy
+
+    rvfm = RobotVisionFM(
+        translator=cfg.model.translator.type,
+        translator_kwargs=cfg.model.translator.kwargs,
+        target_feature_sizes=target_feature_sizes,
+        target_loss_weights=cfg.training.target_models.target_model_weights,
+        **cfg.model.backbone,
+    )
+
+    rvfm.to(rank)
+
+    rvfm_ddp = DDP(rvfm, device_ids=[rank], find_unused_parameters=False)
+
+    image_transform: Compose | Callable = totensor  # currently just ndarray to tensor
+
+    train_dataset, train_dataset_expected_length = get_image_video_dataset(
+        dataset_root=cfg.dataset.dataset_root,
+        dataset_mix=cfg.dataset.dataset_mix,
+        split="train",
+        dataset_ratio=cfg.dataset.dataset_ratio,
+        feature_models=target_model_names_wocls,
+        image_transform=image_transform,
+        feature_norm=cfg.dataset.feature_norm,
+        rank=rank,
+        world_size=world_size,
+        shuffle=cfg.dataset.shuffle,
+        seed=cfg.seed,
+        shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
+        num_workers=cfg.training.num_workers,
+    )
+
+    eval_dataset, eval_dataset_expected_length = get_image_video_dataset(
+        dataset_root=cfg.dataset.dataset_root,
+        dataset_mix=cfg.dataset.dataset_mix,
+        split="val",
+        dataset_ratio=0.1,
+        feature_models=target_model_names_wocls,
+        image_transform=image_transform,
+        feature_norm=cfg.dataset.feature_norm,
+        rank=rank,
+        world_size=world_size,
+        shuffle=False,
+        seed=cfg.seed,
+        shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
+        num_workers=cfg.training.num_workers,
+    )
+
+    train_epoch_steps = math.ceil(train_dataset_expected_length / cfg.training.batch_size / world_size)
+    eval_epoch_steps = math.ceil(eval_dataset_expected_length / cfg.training.batch_size / world_size)
+    total_train_steps = train_epoch_steps * cfg.training.epochs
+
+    rvfm_param_groups = param_groups_weight_decay(rvfm_ddp, cfg.training.weight_decay)
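+    # linear learning-rate scaling: scale base_lr by the ratio of the global batch size
+    # (batch_size * world_size) to the reference global batch size (base_batch_size * base_world_size),
+    # e.g. doubling the global batch size doubles the learning rate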
+    lr = cfg.training.base_lr * (
+        (cfg.training.batch_size * world_size) / (cfg.training.base_batch_size * cfg.training.base_world_size)
+    )
+    optimizer = hydra.utils.instantiate(cfg.training.optimizer, rvfm_param_groups, lr=lr)
+    lr_scheduler = hydra.utils.instantiate(
+        cfg.training.lr_scheduler,
+        optimizer=optimizer,
+        warm_up_steps=int(cfg.training.warm_up_steps_ratio * total_train_steps),
+        cos_lrs_T_0=int(total_train_steps * (1 - cfg.training.warm_up_steps_ratio)),
+    )
+
+    if rank == 0:
+        print(OmegaConf.to_yaml(cfg))
+        wandb.init(project=cfg.logging.project, name=cfg.logging.run_identifier_prefix, config=OmegaConf.to_object(cfg))
+
+    train(
+        rvfm_ddp,
+        target_model_names,
+        optimizer,
+        lr_scheduler,
+        train_dataset,
+        eval_dataset,
+        cfg=cfg,
+        device=rank,
+        train_epoch_steps=train_epoch_steps,
+        eval_epoch_steps=eval_epoch_steps,
+        total_train_steps=total_train_steps,
+        warmup_steps=int(cfg.training.warm_up_steps_ratio * total_train_steps),
+    )
+
+    ddp_cleanup()
+
+
+@hydra.main(version_base=None, config_path="../../configs", config_name="train_rvfm_imagenet")
+def main(cfg: DictConfig) -> None:
+    """Main. Dealing with arguments and call DDP."""
+
+    backbone_fn = f"_{cfg.model.backbone.backbone.replace('/', '-')}"
+    notes_fn = f"_{cfg.logging.notes}" if cfg.logging.notes else ""
+    translator_fn = f"_{cfg.model.translator.type}"
+    pretrained_fn = "_pretrained" if cfg.model.backbone.pretrained else ""
+    dp_fn = f"_dp{cfg.dataset.dataset_ratio:.3f}"
+    cfg.logging.run_identifier_prefix = f"rvfm{dp_fn}{backbone_fn}{translator_fn}{pretrained_fn}{notes_fn}"
+
+    seed_everything(cfg.seed)
+
+    ddp_main(cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/theia/utils/__init__.py b/theia/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/utils/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/utils/cortexbench/__init__.py b/theia/utils/cortexbench/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/utils/cortexbench/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/utils/cortexbench/load_model.py b/theia/utils/cortexbench/load_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c074ff0b17c9a31df59e75d3a3a202ac04abe0a
--- /dev/null
+++ b/theia/utils/cortexbench/load_model.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import math
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+from torchvision.transforms import Compose
+
+
+def load_model(
+    model: nn.Module, transform: Compose, metadata: Any, **kwargs: Any
+) -> tuple[nn.Module, torch.Size, Compose, Any]:
+    """Helper function for loading model for cortexbench.
+
+    Args:
+        model (nn.Module): model.
+        transform (torchvision.transforms.Compose): transform applied to input image.
+        metadata (Any): any metadata embedded in the model.
+        kwargs (Any): any parameters for loading the model. Including
+            `checkpoint_path` for loading weights for rvfm.
+
+    Returns:
+        tuple[nn.Module, torch.Size, Compose, Any]: return model, size of the embedding, transform, and the metadata.
+    """
+
+    if kwargs.get("checkpoint_path"):
+        model.load_pretrained_weights(kwargs["checkpoint_path"])
+
+    with torch.inference_mode():
+        zero_img = np.array(Image.new("RGB", (100, 100)))  # for getting the embedding shape
+        transformed_img = transform(zero_img).unsqueeze(0)
+        embedding_dim = model.forward_feature(transformed_img).size()[1:]  # [H*W, C]
+        if len(embedding_dim) > 1:
+            h = w = int(math.sqrt(embedding_dim[0]))
+            embedding_dim = torch.Size((embedding_dim[1], h, w))  # [C, H, W]
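+            # e.g. a (196, 1024) patch embedding (14x14 grid, illustrative) becomes torch.Size((1024, 14, 14))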
+
+    return model, embedding_dim, transform, metadata
diff --git a/theia/utils/cortexbench/policy_heads.py b/theia/utils/cortexbench/policy_heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3fe7f31ae8cf43589b9ae58f9d4dc9ac0167103
--- /dev/null
+++ b/theia/utils/cortexbench/policy_heads.py
@@ -0,0 +1,240 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from typing import Any, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+# Since this code is only used for running cortexbench,
+# the following dependency is not added to the project by default.
+from mjrl.policies.gaussian_mlp import BatchNormMLP
+from numpy.typing import NDArray
+
+
+class ConvBatchNormMLP(BatchNormMLP):
+    """Convolution followed with a BatchNormMLP (BatchNormMLP is from mjrl).
+
+    Attributes:
+        embedding_dim (tuple[int, ...] | list[int, ...] | torch.Size): dimension of the representation.
+        proprio_dim (tuple[int, ...] | list[int, ...] | torch.Size):
+            dimension of the proprio information from the environment.
+        history_window (int): the number of history observations considered.
+        model (nn.ModuleDict): dict holding the original BatchNormMLP (the "head") and the new conv layers (the "neck").
+        device (str | torch.device): track the device that the model is on.
+    """
+
+    def __init__(
+        self,
+        env_spec: Any,
+        hidden_sizes: str = "(64, 64)",  # str is to adapt with mjrl side
+        min_log_std: float = -3.0,
+        init_log_std: float = 0.0,
+        seed: Optional[int] = None,
+        nonlinearity: str = "relu",
+        dropout: float = 0.0,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            env_spec (gym.EnvSpec): specs of the environment that this policy will run on.
+            hidden_sizes (tuple): size of hidden layers of MLP. Defaults to (64,64).
+            min_log_std (float): minimum log std value for action. This is to match mjrl. Defaults to -3.
+            init_log_std (float): initial log std value for action. This is to match mjrl. Defaults to 0.
+            seed (Optional[int]): seed. Defaults to None.
+            nonlinearity (str): kind of non-linear activation function. Defaults to 'relu'.
+            dropout (float): dropout rate. Defaults to 0.
+        """
+        self.embedding_dim = kwargs["embedding_dim"]  # [C, H, W]
+        self.proprio_dim = kwargs["proprio_dim"]
+        self.history_window = kwargs["history_window"]
+        hidden_sizes = eval(hidden_sizes)  # hack to match mjrl
+        env_spec.observation_dim = hidden_sizes[0] + self.proprio_dim
+        super().__init__(
+            env_spec, hidden_sizes, min_log_std, init_log_std, seed, nonlinearity, dropout, *args, **kwargs
+        )
+
+        neck = nn.Sequential(
+            nn.Conv2d(self.embedding_dim[0] * self.history_window, 256, kernel_size=4, stride=2, padding=1),
+            nn.LayerNorm([256, 7, 7]),
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),  # 14x14 -> 7x7  # just to keep the same as super class
+            nn.Conv2d(256, 256, kernel_size=3, stride=2),
+            nn.LayerNorm([256, 3, 3]),
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),  # 7x7 -> 3x3
+            nn.Conv2d(256, 256, kernel_size=3, stride=1),
+            nn.LayerNorm([256, 1, 1]),
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),  # 3x3 -> 1x1
+            nn.Flatten(),
+        )
+
+        # re-encapsulate so that all nn parts are in self.model,
+        # so that explicit operations on self.model by cortexbench are applied to all nn parts,
+        # e.g. policy.model.eval()
+        head: nn.Module = self.model  # type:ignore [has-type]
+        self.model = nn.ModuleDict({"neck": neck, "head": head})
+        self.device: Optional[str | torch.device] = None
+
+    def to(self, device: str | torch.device) -> None:
+        """Put the model on the `device`.
+
+        Args:
+            device (str | torch.device): the device to put the model
+        """
+        for k in self.model:
+            self.model[k].to(device)
+        self.device = device
+
+    def eval(self) -> None:
+        """Set the model in eval mode."""
+        for k in self.model:
+            self.model[k].eval()
+
+    def train(self) -> None:
+        """:Set the model in train mode."""
+        for k in self.model:
+            self.model[k].train()
+
+    def get_action_mean(self, observation: torch.Tensor) -> torch.Tensor:
+        """Get the mean action given the observation.
+
+        Args:
+            observation (torch.Tensor): observation.
+
+        Returns:
+            torch.Tensor : mean action.
+        """
+        if len(self.embedding_dim) > 0:
+            # observation (B, T*H*W*C + C_proprio)
+            if self.proprio_dim > 0:
+                emb_obs, proprio_obs = observation[..., : -self.proprio_dim], observation[..., -self.proprio_dim :]
+                emb_obs = rearrange(
+                    emb_obs,
+                    "b (t h w c) -> b (c t) h w",
+                    t=self.history_window,
+                    c=self.embedding_dim[0],
+                    h=self.embedding_dim[1],
+                    w=self.embedding_dim[2],
+                )
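+                # e.g. with history_window=3 and embedding_dim (C, H, W) = (1024, 14, 14) (illustrative),
+                # the flat (B, 3*14*14*1024) observation becomes (B, 3072, 14, 14) for the conv neck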
+                emb_obs = self.model["neck"](emb_obs)
+                self.obs_var = torch.cat([emb_obs, proprio_obs], dim=1)
+            else:
+                emb_obs = rearrange(
+                    observation,
+                    "b (t h w c) -> b (c t) h w",
+                    t=self.history_window,
+                    c=self.embedding_dim[0],
+                    h=self.embedding_dim[1],
+                    w=self.embedding_dim[2],
+                )
+                self.obs_var = self.model["neck"](emb_obs)
+        else:
+            raise ValueError(f"input observation {observation.size()} is not from a valid spatial embedding.")
+        mean = self.model["head"](self.obs_var)
+        return mean
+
+    def forward(self, observation: torch.Tensor) -> torch.Tensor:
+        """Model forward. Wrapper for get_action_mean() used during training.
+
+        Args:
+            observation (torch.Tensor): observation.
+
+        Returns:
+            torch.Tensor: mean action.
+        """
+        return self.get_action_mean(observation)
+
+    def get_action(self, observation: NDArray) -> tuple[NDArray, dict[str, Any]]:
+        """Get action with some noise used in evaluation / rollout. No gradient.
+
+        Args:
+            observation (NDArray): observation.
+
+        Returns:
+            tuple[NDArray, dict[str, Any]]: action and some statistics (required by mjrl)
+        """
+        with torch.no_grad():
+            observation = torch.from_numpy(observation.astype(np.float32)).unsqueeze(0).to(self.device)
+            mean = self.get_action_mean(observation).detach().cpu().numpy().ravel()
+            noise = np.exp(self.log_std_val) * np.random.randn(self.m)
+            action = mean + noise
+            return (action, {"mean": mean, "log_std": self.log_std_val, "evaluation": mean})
+
+    def get_action_deterministic(self, observation: NDArray) -> tuple[NDArray, dict[str, Any]]:
+        """Get action without noise (using mean) used in evaluation / rollout. No gradient.
+
+        Args:
+            observation (NDArray): observation.
+
+        Returns:
+            tuple[NDArray, dict[str, Any]]: action and some statistics (required by mjrl)
+        """
+        with torch.no_grad():
+            observation = torch.from_numpy(observation.astype(np.float32)).unsqueeze(0).to(self.device)
+            action = self.get_action_mean(observation).detach().cpu().numpy().ravel()
+            return (action, {"mean": action, "log_std": 0, "evaluation": action})
+
+
+class ConvPolicyHead(ConvBatchNormMLP):
+    """A smaller Convolution followed with a smaller BatchNormMLP (BatchNormMLP is from mjrl).
+
+    Attributes:
+        embedding_dim (tuple[int, ...] | list[int, ...] | torch.Size): dimension of the representation.
+        proprio_dim (tuple[int, ...] | list[int, ...] | torch.Size):
+            dimension of the proprio information from the environment.
+        history_window (int): the number of history observations considered.
+        model (nn.ModuleDict): dict holding the MLP (the "head") and the conv layers (the "neck").
+        device (str | torch.device): track the device that the model is on.
+    """
+
+    def __init__(
+        self,
+        env_spec: Any,
+        hidden_sizes: str = "(64, 64)",  # str is to adapt with mjrl side
+        min_log_std: float = -3.0,
+        init_log_std: float = 0.0,
+        seed: Optional[int] = None,
+        nonlinearity: str = "relu",
+        dropout: float = 0.0,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            env_spec (gym.EnvSpec): specs of the environment that this policy will run on.
+            hidden_sizes (tuple): size of hidden layers of MLP. Defaults to (64,64).
+            min_log_std (float): minimum log std value for action. This is to match mjrl. Defaults to -3.
+            init_log_std (float): initial log std value for action. This is to match mjrl. Defaults to 0.
+            seed (Optional[int]): seed. Defaults to None.
+            nonlinearity (str): kind of non-linear activation function. Defaults to 'relu'.
+            dropout (float): dropout rate. Defaults to 0.
+        """
+        self.embedding_dim = kwargs["embedding_dim"]  # [C, H, W]
+        self.proprio_dim = kwargs["proprio_dim"]
+        self.history_window = kwargs["history_window"]
+        hidden_sizes = eval(hidden_sizes)  # hack to match mjrl
+        env_spec.observation_dim = hidden_sizes[0] + self.proprio_dim
+        super().__init__(
+            env_spec, hidden_sizes, min_log_std, init_log_std, seed, nonlinearity, dropout, *args, **kwargs
+        )
+
+        del self.model
+
+        neck = nn.Sequential(
+            nn.Conv2d(self.embedding_dim[0] * self.history_window, 60, kernel_size=4, stride=2, padding=1),
+            nn.LayerNorm([60, 7, 7]),
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),  # 14x14-> 7x7  # just to keep the same as super class
+            nn.Conv2d(60, 60, kernel_size=3, stride=2),
+            nn.LayerNorm([60, 3, 3]),
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),  # 3x3
+            nn.Flatten(),
+        )
+        head = nn.Sequential(
+            nn.Linear(60 * 3 * 3 + self.proprio_dim, 256),
+            nn.LayerNorm(256),
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),
+            nn.Linear(256, self.m),
+        )
+        self.model = nn.ModuleDict({"neck": neck, "head": head})
+        self.device = None
diff --git a/theia/utils/cortexbench/transforms.py b/theia/utils/cortexbench/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcdaba491da065b1c05a6082fa6b63f52700c130
--- /dev/null
+++ b/theia/utils/cortexbench/transforms.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import torch
+import torchvision.transforms.v2 as T
+from torchvision.transforms import InterpolationMode
+
+
+def rvfm_image_transforms(output_size: int = 224) -> T.Transform:
+    """Image transform used by RVFM.
+
+    Args:
+        output_size (int): output size of the image.
+
+    Returns:
+        T.Compose: the transform
+    """
+    return T.Compose(
+        [
+            T.ToImage(),
+            T.Resize(output_size, interpolation=InterpolationMode.BICUBIC),
+        ]
+    )
+
+
+def vit_transforms(resize_size: int = 256, output_size: int = 224) -> T.Transform:
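+    """ViT-style preprocessing: resize, center crop, scale to [0, 1], and normalize with ImageNet statistics.
+
+    Args:
+        resize_size (int): size of the shorter side after resizing.
+        output_size (int): output size of the center crop.
+
+    Returns:
+        T.Transform: the transform.
+    """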
+    return T.Compose(
+        [
+            T.ToImage(),
+            T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),
+            T.CenterCrop(output_size),
+            T.ToDtype(torch.float32, scale=True),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+
+
+def r3m_transforms(resize_size: int = 256, output_size: int = 224) -> T.Transform:
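+    """R3M-style preprocessing: resize, center crop, and convert to float without rescaling (no division by 255).
+
+    Args:
+        resize_size (int): size of the shorter side after resizing.
+        output_size (int): output size of the center crop.
+
+    Returns:
+        T.Transform: the transform.
+    """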
+    return T.Compose(
+        [
+            T.ToImage(),
+            T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),
+            T.CenterCrop(output_size),
+            T.ToDtype(torch.float32, scale=False),
+        ]
+    )
diff --git a/theia/utils/cortexbench/trifinger/__init__.py b/theia/utils/cortexbench/trifinger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022d0690688cf230f4c18896e9a6f027ef025fb9
--- /dev/null
+++ b/theia/utils/cortexbench/trifinger/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
diff --git a/theia/utils/cortexbench/trifinger/policy.py b/theia/utils/cortexbench/trifinger/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed4aba893c77b316d51df620c3f658fa00803e03
--- /dev/null
+++ b/theia/utils/cortexbench/trifinger/policy.py
@@ -0,0 +1,123 @@
+# File modified. Modifications Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This source code is licensed under the CC-BY-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+from typing import Any
+
+import torch
+import torch.nn as nn
+from einops.layers.torch import Rearrange
+
+
+class ConvBatchNormMLPDeterministicPolicy(nn.Module):
+    def __init__(
+        self,
+        in_dim: tuple[int, ...],
+        extra_dim: int,
+        out_dim: int,
+        max_a: Any = None,
+        hidden_size: int = 256,
+        nonlinearity: str = "relu",
+        device: str | int | torch.device = "cpu",
+    ) -> None:
+        super().__init__()
+        self.extra_dim = extra_dim
+        self.in_dim = in_dim
+        self.neck = nn.Sequential(
+            Rearrange("b (h w c) -> b c h w", h=14, w=14),
+            nn.Conv2d(in_dim[0], 256, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),
+            nn.Conv2d(256, 256, kernel_size=3, stride=2),  # 7x7 -> 3x3
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),
+            nn.Conv2d(256, 256, kernel_size=3, stride=1),  # 3x3 -> 1x1
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),
+            nn.Flatten(),
+        )
+        self.policy = nn.Sequential(
+            nn.Linear(256 + extra_dim, hidden_size),
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ReLU() if nonlinearity == "relu" else nn.Tanh(),
+            nn.Linear(hidden_size, out_dim),
+        )
+        self.neck.to(device)
+        self.policy.to(device)
+        self.device = device
+
+        self.init_state = copy.deepcopy(self.policy.state_dict())
+        self.neck_init_state = copy.deepcopy(self.neck.state_dict())
+
+        self.max_a = max_a
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+
+    def forward(self, state: torch.Tensor) -> torch.Tensor:
+        visual_state = state[..., : -self.extra_dim]
+        feature = self.neck(visual_state)
+        if self.extra_dim > 0:
+            feature = torch.cat([feature, state[..., -self.extra_dim :]], dim=1)
+        action = self.policy(feature)
+        return action
+
+    def reset(self) -> None:
+        self.policy.load_state_dict(self.init_state)
+        self.neck.load_state_dict(self.neck_init_state)
+
+    def clip_action(self, a: torch.Tensor) -> torch.Tensor:
+        if self.max_a is not None:
+            a = torch.where(a > self.max_a, torch.tensor([self.max_a]).to(self.device), a)
+            a = torch.where(a < -self.max_a, -torch.tensor([self.max_a]).to(self.device), a)
+        return a
+
+    def scale_to_range(self, a: torch.Tensor) -> torch.Tensor:
+        """Does not do anything; just returns a"""
+        return a
+
+
+def construct_policy(
+    type: str,
+    task_state_type: str,
+    train_ft_state_shape: int,
+    pretrained_dim: tuple[int, ...],
+    task_goal_type: str,
+    out_dim: int,
+    max_a: Any,
+    device: str | int | torch.device,
+    hidden_size: int = 256,
+    nonlinearity: str = "relu",
+    **kwargs: Any,
+) -> ConvBatchNormMLPDeterministicPolicy:
+    in_dim = pretrained_dim
+    extra_dim = 0
+    if task_state_type == "obj":
+        extra_dim += 0
+    elif task_state_type in ["ftpos_obj", "ftpos"]:
+        extra_dim += train_ft_state_shape
+    else:
+        raise NameError("Invalid state_type")
+
+    if task_goal_type == "goal_none":
+        in_dim = pretrained_dim
+    elif task_goal_type == "goal_cond":
+        in_dim = (pretrained_dim[0] * 2, *pretrained_dim[1:])
+    elif task_goal_type == "goal_o_pos":
+        extra_dim += 3
+    else:
+        raise NameError("Invalid goal_type")
+
+    if type == "ConvBatchNormMLP":
+        policy = ConvBatchNormMLPDeterministicPolicy(
+            in_dim=in_dim,
+            extra_dim=extra_dim,
+            out_dim=out_dim,
+            max_a=max_a,
+            hidden_size=hidden_size,
+            nonlinearity=nonlinearity,
+            device=device,
+        )
+    else:
+        raise NotImplementedError(f"Policy network {type} is not supported.")
+    return policy
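+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch with assumed shapes (a 14x14 feature map with 64 channels,
+    # 9 fingertip-state dims, a 9-dim action); the real values come from the benchmark config.
+    policy = construct_policy(
+        type="ConvBatchNormMLP",
+        task_state_type="ftpos",
+        train_ft_state_shape=9,
+        pretrained_dim=(64, 14, 14),
+        task_goal_type="goal_none",
+        out_dim=9,
+        max_a=0.1,
+        device="cpu",
+    )
+    # forward() expects the flattened visual features first, with any extra
+    # (fingertip / goal) dims appended at the end of the state vector.
+    dummy_state = torch.randn(2, 64 * 14 * 14 + 9)
+    action = policy.clip_action(policy(dummy_state))
+    print(action.shape)  # torch.Size([2, 9])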
diff --git a/theia/utils/logging.py b/theia/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f16b3d5c473edb39b378b7e0bf55964048fbf70
--- /dev/null
+++ b/theia/utils/logging.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+from enum import Enum
+from typing import Any
+
+import torch
+import torch.distributed as dist
+import wandb
+
+
+class SummaryType(Enum):
+    NONE = 0
+    AVERAGE = 1
+    SUM = 2
+    COUNT = 3
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value
+    Attributes:
+        name (str): name of the meter.
+        fmt (str): format string. Defaults to ':f'.
+        summary_type (Enum): reduce method. Defaults to Summary.AVERAGE.
+
+        val (float): last mean value over batch.
+        avg (float): average value since meter creation.
+        sum (float): sum of all the values = self.avg * self.count.
+        count (int): number of values considered since meter creation.
+    """
+
+    def __init__(self, name: str, fmt: str = ":f", summary_type: SummaryType = SummaryType.AVERAGE) -> None:
+        """Initialize an average meter."""
+        self.name = name
+        self.fmt = fmt
+        self.summary_type = summary_type
+        self.reset()
+
+    def reset(self) -> None:
+        """Reset the meter."""
+        self.val: float = 0.0
+        self.avg: float = 0.0
+        self.sum: float = 0.0
+        self.count: int = 0
+
+    def update(self, val: float, n: int = 1) -> None:
+        """Update the meter.
+
+        Args:
+            val (float): (mean) value over n samples.
+            n (int): number of samples. Defaults to 1.
+        """
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def all_reduce(self) -> None:
+        """Reduce meters across ranks."""
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
+        # Block until the reduction completes before reading the result.
+        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
+        self.sum, self.count = total.tolist()
+        self.avg = self.sum / self.count if self.count > 0 else 0.0
+
+    def __str__(self) -> str:
+        """String representation of the meter."""
+        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
+        return fmtstr.format(**self.__dict__)
+
+    def summary(self) -> str:
+        """Print the summary of the meter status."""
+        fmtstr = ""
+        match self.summary_type:
+            case SummaryType.NONE:
+                fmtstr = ""
+            case SummaryType.AVERAGE:
+                fmtstr = "{name} {avg:.3f}"
+            case SummaryType.SUM:
+                fmtstr = "{name} {sum:.3f}"
+            case SummaryType.COUNT:
+                fmtstr = "{name} {count:.3f}"
+            case _:
+                raise ValueError("invalid summary type %r" % self.summary_type)
+
+        return fmtstr.format(**self.__dict__)
+
+
+def create_meters(target_model_names: list[str]) -> dict[str, AverageMeter]:
+    """Create meters for logging statistics, including individual meters for each target model.
+
+    Args:
+        target_model_names (list[str]): names of the target models.
+
+    Returns:
+        dict[str, AverageMeter]: meters created
+    """
+    meters = {}
+    for loss in ["mse", "cos", "l1"]:
+        meters[f"train_{loss}_loss"] = AverageMeter(f"train_{loss}_loss")
+        meters[f"eval_{loss}_loss"] = AverageMeter(f"eval_{loss}_loss")
+
+    for t in target_model_names:
+        for loss in ["mse", "cos", "l1"]:
+            for mode in ["train", "eval"]:
+                meters[f"{mode}_{t}_{loss}_loss"] = AverageMeter(f"{mode}_{t}_{loss}_loss")
+
+    return meters
+
+
+def log_metrics(meters: dict[str, AverageMeter], **kwargs: Any) -> None:
+    """log metrics to wandb.
+
+    Args:
+        meters (dict[str, AverageMeter]): _description_
+    """
+    metrics = {}
+
+    mode = kwargs["mode"]
+    batch_size = kwargs.get("batch_size", 0)
+
+    if not kwargs.get("only_upload", False):
+        # update meters
+        meters[f"{mode}_mse_loss"].update(kwargs["mse_loss"].item(), n=batch_size)
+        meters[f"{mode}_cos_loss"].update(kwargs["cos_loss"].item(), n=batch_size)
+        meters[f"{mode}_l1_loss"].update(kwargs["l1_loss"].item(), n=batch_size)
+
+        for t in kwargs["target_model_names"]:
+            for loss in ["mse", "cos", "l1"]:
+                meters[f"{mode}_{t}_{loss}_loss"].update(kwargs[f"{loss}_losses_per_model"][t], n=batch_size)
+
+    # read out from meters or the raw for logging
+    if kwargs["upload_wandb"]:
+        if mode == "train":
+            metrics["loss"] = kwargs["main_loss"].item()
+            metrics["mse_loss"] = kwargs["mse_loss"].item()
+            metrics["cos_loss"] = kwargs["cos_loss"].item()
+            metrics["l1_loss"] = kwargs["l1_loss"].item()
+
+        metrics[f"avg_{mode}_mse_loss"] = meters[f"{mode}_mse_loss"].avg
+        metrics[f"avg_{mode}_cos_loss"] = meters[f"{mode}_cos_loss"].avg
+        metrics[f"avg_{mode}_l1_loss"] = meters[f"{mode}_l1_loss"].avg
+        for t in kwargs["target_model_names"]:
+            for loss in ["mse", "cos", "l1"]:
+                metrics[f"avg_{mode}_{t}_{loss}_loss"] = meters[f"{mode}_{t}_{loss}_loss"].avg
+
+        if kwargs["device"] == 0:
+            wandb.log(metrics)
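+
+
+if __name__ == "__main__":
+    # Illustrative offline sketch (upload_wandb=False, so no wandb run is required); the
+    # target model name is a placeholder, and the loss values are arbitrary.
+    target_models = ["facebook/dinov2-large"]
+    meters = create_meters(target_models)
+    log_metrics(
+        meters,
+        mode="train",
+        batch_size=32,
+        upload_wandb=False,
+        mse_loss=torch.tensor(0.12),
+        cos_loss=torch.tensor(0.03),
+        l1_loss=torch.tensor(0.25),
+        target_model_names=target_models,
+        mse_losses_per_model={"facebook/dinov2-large": 0.12},
+        cos_losses_per_model={"facebook/dinov2-large": 0.03},
+        l1_losses_per_model={"facebook/dinov2-large": 0.25},
+    )
+    print(meters["train_mse_loss"])  # e.g. "train_mse_loss 0.120000 (0.120000)"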
diff --git a/theia/utils/seed.py b/theia/utils/seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bf644b7b76946906c5a1bdd9e0028a309869c1
--- /dev/null
+++ b/theia/utils/seed.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
+
+import os
+import random
+from typing import Any, Optional
+
+import numpy as np
+import torch
+
+max_seed_value = np.iinfo(np.uint32).max
+min_seed_value = np.iinfo(np.uint32).min
+
+
+def seed_everything(seed: Optional[Any] = None, workers: bool = False) -> int:
+    """Seed everything adopted from lightning_fabric.utilities.seed.seed_everything.
+
+    Avoid using lightning only for seeding.
+
+    Args:
+        seed (Optional[Any]): seed, preferably an integer, or other stuff can be converted to an integer.
+
+    Returns:
+        int: the actual seed used. It should be the same as input seed in most of the cases.
+    """
+    if seed is None:
+        env_seed = os.environ.get("PL_GLOBAL_SEED")
+        if env_seed is None:
+            seed = 0
+        else:
+            try:
+                seed = int(env_seed)
+            except ValueError:
+                seed = 0
+    elif not isinstance(seed, int):
+        seed = int(seed)
+
+    if not (min_seed_value <= seed <= max_seed_value):
+        seed = 0
+
+    os.environ["PL_GLOBAL_SEED"] = str(seed)
+    os.environ["PYTHON_SEED"] = str(seed)  # add python seed
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
+
+    return seed
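+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch: a typical call at the top of a training script;
+    # 42 is an arbitrary example seed, not a value prescribed by this repo.
+    used_seed = seed_everything(42, workers=True)
+    print(f"seeded with {used_seed}, PL_GLOBAL_SEED={os.environ['PL_GLOBAL_SEED']}")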