diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f6b1f326ca4ab7cf0c8798856f8fe0020ff82d58 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index a8cc3865dcc72ed09a6108308354af98f9eeffeb..04ac80b975ace53c2738b1053f5d4d9cf54d303f 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,11 @@
 ---
-title: LVM
-emoji: 🔥
-colorFrom: yellow
-colorTo: gray
+title: VQLM Demo
+emoji: 🎨
+colorFrom: "yellow"
+colorTo: "blue"
 sdk: gradio
-sdk_version: 4.36.1
+sdk_version: "4.29.0"
 app_file: app.py
 pinned: false
-license: apache-2.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d34083aa00b5137708967eba4245149d4545eb6
--- /dev/null
+++ b/app.py
@@ -0,0 +1,244 @@
+import gradio as gr
+import numpy as np
+import mlxu
+import os
+import re
+import torch
+
+from io import BytesIO
+from natsort import natsorted
+from PIL import Image
+
+from inference import LocalInferenceModel
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    host='0.0.0.0',
+    port=5000,
+    dtype='float16',
+    checkpoint='Emma02/LVM_ckpts',
+    torch_devices='',
+    context_frames=16,
+)
+
+def natural_sort_key(s):
+    return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]
+
+def load_example_image_groups(directory):
+    example_groups = {}
+    for subdir in os.listdir(directory):
+        subdir_path = os.path.join(directory, subdir)
+        if os.path.isdir(subdir_path):
+            example_groups[subdir] = []
+            images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+            images = natsorted(images, key=natural_sort_key)
+            for filename in images:
+                img = Image.open(os.path.join(subdir_path, filename))
+                example_groups[subdir].append(img)
+    return example_groups
+
+def main(_):
+    assert FLAGS.checkpoint != ''
+
+    model = LocalInferenceModel(
+        checkpoint=FLAGS.checkpoint,
+        torch_device=torch.device("cuda"),
+        dtype=FLAGS.dtype,
+        context_frames=FLAGS.context_frames,
+        use_lock=False,
+    )
+
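+    # 256x24 checkerboard strip used as a visual separator between generated
+    # frames when they are concatenated side by side.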
+    checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
+    checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
+    checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)
+
+    def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
+        assert len(input_images) > 0
+        input_images = [
+            np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
+            for img in input_images
+        ]
+        input_images = np.stack(input_images, axis=0)
+        output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]
+
+        generated_images = []
+        for candidate in output_images:
+            concatenated_image = []
+            for i, img in enumerate(candidate):
+                concatenated_image.append(img)
+                if i < len(candidate) - 1:
+                    concatenated_image.append(checkerboard)
+            generated_images.append(
+                Image.fromarray(
+                    (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
+                )
+            )
+
+        return generated_images
+
+    with gr.Blocks(css="""
+        .small-button {
+            padding: 5px 10px;
+            min-width: 80px;
+        }
+        .large-gallery img {
+            width: 100%;
+            height: auto;
+            max-height: 150px;
+        }
+    """) as demo:
+        with gr.Column():
+            image_list = gr.State([])
+            gr.Markdown('# VQLM Demo')
+            gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
+            gr.Markdown('## Inputs')
+            with gr.Row():
+                upload_drag = gr.File(
+                    type='binary',
+                    file_types=['image'],
+                    file_count='multiple',
+                )
+                with gr.Column():
+                    gen_length_slider = gr.Slider(
+                        label='Generation length',
+                        minimum=1,
+                        maximum=32,
+                        value=1,
+                        step=1,
+                        interactive=True,
+                    )
+                    n_candidates_slider = gr.Slider(
+                        label='Number of candidates',
+                        minimum=1,
+                        maximum=10,
+                        value=1,
+                        step=1,
+                        interactive=True,
+                    )
+                    temp_slider = gr.Slider(
+                        label='Temperature',
+                        minimum=0,
+                        maximum=2.0,
+                        value=1.0,
+                        interactive=True,
+                    )
+                    top_p_slider = gr.Slider(
+                        label='Top p',
+                        minimum=0,
+                        maximum=1.0,
+                        value=0.9,
+                        interactive=True,
+                    )
+                    clear_btn = gr.Button(
+                        value='Clear',
+                        elem_classes=['small-button'],
+                    )
+                    generate_btn = gr.Button(
+                        value='Generate',
+                        interactive=False,
+                        elem_classes=['small-button'],
+                    )
+            input_gallery = gr.Gallery(
+                columns=7,
+                rows=1,
+                object_fit='scale-down',
+                label="Input image sequence"
+            )
+            gr.Markdown('## Outputs')
+            output_gallery = gr.Gallery(
+                columns=4,
+                object_fit='scale-down',
+                label="Output image"
+            )
+
+        def upload_image_fn(files, images):
+            for file in files:
+                images.append(Image.open(BytesIO(file)))
+
+            return {
+                upload_drag: None,
+                image_list: images,
+                input_gallery: images,
+                generate_btn: gr.update(interactive=True),
+            }
+
+        def clear_fn():
+            return {
+                image_list: [],
+                input_gallery: [],
+                generate_btn: gr.update(interactive=False),
+                output_gallery: [],
+            }
+
+        def disable_generate_btn():
+            return {
+                generate_btn: gr.update(interactive=False),
+            }
+
+        def generate_fn(images, n_candidates, gen_length, temperature, top_p):
+            new_images = generate_images(
+                images,
+                gen_length,
+                n_candidates=n_candidates,
+                temperature=temperature,
+                top_p=top_p,
+            )
+            return {
+                output_gallery: new_images,
+                generate_btn: gr.update(interactive=True),
+            }
+
+        upload_drag.upload(
+            upload_image_fn,
+            inputs=[upload_drag, image_list],
+            outputs=[upload_drag, image_list, input_gallery, generate_btn],
+        )
+        clear_btn.click(
+            clear_fn,
+            inputs=None,
+            outputs=[image_list, input_gallery, generate_btn, output_gallery],
+        )
+        generate_btn.click(
+            disable_generate_btn,
+            inputs=None,
+            outputs=[generate_btn],
+        ).then(
+            generate_fn,
+            inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
+            outputs=[output_gallery, generate_btn],
+        )
+
+        example_groups = load_example_image_groups('prompts')
+
+        def add_image_group_fn(group_name, images):
+            new_images = images + example_groups[group_name]
+            return {
+                image_list: new_images,
+                input_gallery: new_images,
+                generate_btn: gr.update(interactive=True),
+            }
+
+        for group_name, group_images in example_groups.items():
+            with gr.Row():
+                with gr.Column(scale=3):
+                    add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
+                with gr.Column(scale=7):
+                    group_gallery = gr.Gallery(
+                        value=[Image.fromarray(np.array(img)) for img in group_images],
+                        columns=5,
+                        rows=1,
+                        object_fit='scale-down',
+                        label=group_name,
+                        elem_classes=['large-gallery'],
+                    )
+
+                add_button.click(
+                    add_image_group_fn,
+                    inputs=[gr.State(group_name), image_list],
+                    outputs=[image_list, input_gallery, generate_btn],
+                )
+
+    demo.launch()
+
+if __name__ == "__main__":
+    mlxu.run(main)
+
diff --git a/batch_generation.py b/batch_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c044af5d192ad050be4199463085cb8df214dc8
--- /dev/null
+++ b/batch_generation.py
@@ -0,0 +1,223 @@
+"""
+Batch generation for sequnce of images. This script accept a jsonl file
+as input. Each line of the jsonl file representing a dictionary. Each line
+represents one example in the evaluation set. The dictionary should have two key:
+
+    input: a list of paths to the input images as context to the model.
+    output: a string representing the path to the output of generation to be saved.
+
+Ths script runs the mode to generate the output images, and concatenate the
+input and output images together and save them to the output path.
+"""
+
+import os
+import json
+from PIL import Image
+import numpy as np
+import mlxu
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+import einops
+import torch
+
+from .inference import MultiProcessInferenceModel
+from .utils import read_image_to_tensor, MultiProcessImageSaver
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    input_file='',
+    checkpoint='',
+    input_base_dir='',
+    output_base_dir='',
+    evaluate_mse=False,
+    json_input_key='input',
+    json_output_key='output',
+    json_target_key='target',
+    n_new_frames=1,
+    n_candidates=2,
+    context_frames=16,
+    temperature=1.0,
+    top_p=1.0,
+    n_workers=8,
+    dtype='float16',
+    torch_devices='',
+    batch_size_factor=4,
+    max_examples=0,
+    resize_output='',
+    include_input=False,
+)
+
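+# Typical invocation, assuming the package is importable as vqlm_demo (as in
+# the other scripts) and that mlxu exposes these as command-line flags:
+#   python -m vqlm_demo.batch_generation --checkpoint=/path/to/ckpt \
+#       --input_file=examples.jsonl --output_base_dir=outputs
+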
+# Dataset over the records parsed from the input jsonl file.
+class MultiFrameDataset(torch.utils.data.Dataset):
+    def __init__(self, input_files, output_files, target_files=None):
+        assert len(input_files)
+        self.input_files = input_files
+        self.output_files = output_files
+        self.target_files = target_files
+
+    def __len__(self):
+        return len(self.input_files)
+
+    def __getitem__(self, idx):
+        original_size = Image.open(self.input_files[idx][-1]).size
+        input_images = np.stack(
+            [read_image_to_tensor(f) for f in self.input_files[idx]],
+            axis=0
+        )
+
+        if self.target_files is not None:
+            target_images = np.stack(
+                [read_image_to_tensor(f) for f in self.target_files[idx]],
+                axis=0
+            )
+        else:
+            target_images = None
+        return input_images, target_images, self.output_files[idx], np.array(original_size)
+
+
+def main(_):
+    assert FLAGS.checkpoint != ''
+
+    print(f'Loading checkpoint from {FLAGS.checkpoint}')
+    print(f'Evaluating input file from {FLAGS.input_file}')
+
+    # build a model.
+
+    model = MultiProcessInferenceModel(
+        checkpoint=FLAGS.checkpoint,
+        torch_devices=FLAGS.torch_devices,
+        dtype=FLAGS.dtype,
+        context_frames=FLAGS.context_frames,
+        use_lock=True,
+    )
+
+    # input_files and output_files are read from the input jsonl file; see the module docstring.
+    input_files = []
+    output_files = []
+
+    if FLAGS.evaluate_mse:
+        target_files = []
+    else:
+        target_files = None
+
+    with mlxu.open_file(FLAGS.input_file, 'r') as f:
+        for line in f:
+            record = json.loads(line)
+            input_files.append(record[FLAGS.json_input_key])
+            output_files.append(record[FLAGS.json_output_key])
+            if FLAGS.evaluate_mse:
+                target_files.append(record[FLAGS.json_target_key])
+
+
+    if FLAGS.max_examples > 0:
+        input_files = input_files[:FLAGS.max_examples]
+        output_files = output_files[:FLAGS.max_examples]
+        if FLAGS.evaluate_mse:
+            target_files = target_files[:FLAGS.max_examples]
+
+    if FLAGS.input_base_dir != '':
+        input_files = [
+            [os.path.join(FLAGS.input_base_dir, x) for x in y]
+            for y in input_files
+        ]
+        if FLAGS.evaluate_mse:
+            target_files = [
+                [os.path.join(FLAGS.input_base_dir, x) for x in y]
+                for y in target_files
+            ]
+
+    if FLAGS.output_base_dir != '':
+        os.makedirs(FLAGS.output_base_dir, exist_ok=True)
+        output_files = [
+            os.path.join(FLAGS.output_base_dir, x)
+            for x in output_files
+        ]
+
+    dataset = MultiFrameDataset(input_files, output_files, target_files)
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=FLAGS.batch_size_factor * model.n_processes,
+        shuffle=False,
+        num_workers=FLAGS.n_workers,
+    )
+
+    image_saver = MultiProcessImageSaver(FLAGS.n_workers)
+
+    mses = []
+
+    for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0):
+        # batch_images holds the input context frames: [batch, seq, h, w, c].
+        batch_images = batch_images.numpy()
+        context_length = batch_images.shape[1]
+
+        generated_images = model(
+            batch_images,
+            FLAGS.n_new_frames,
+            FLAGS.n_candidates,
+            temperature=FLAGS.temperature,
+            top_p=FLAGS.top_p
+        )
+
+        repeated_batch = einops.repeat(
+            batch_images,
+            'b s h w c -> b n s h w c',
+            n=FLAGS.n_candidates,
+        )
+        generated_images = np.array(generated_images)
+
+        if FLAGS.evaluate_mse:
+            batch_targets = einops.repeat(
+                batch_targets.numpy(),
+                'b s h w c -> b n s h w c',  # batch, candidate, sequence, height, width, channel
+                n=FLAGS.n_candidates,
+            )
+            channels = batch_targets.shape[-1]
+            # MSE between generated frames and targets, per example.
+            mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5))
+
+            mses.append(mse * channels)
+
+
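+        # Tile the results into one image per example: candidates are stacked
+        # vertically (n h) and frames are laid out horizontally (s w).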
+        if FLAGS.include_input:
+            combined = einops.rearrange(
+                np.concatenate([repeated_batch, generated_images], axis=2),
+                'b n s h w c -> b (n h) (s w) c'
+            )
+        else:
+            combined = einops.rearrange(
+                generated_images,
+                'b n s h w c -> b (n h) (s w) c'
+            )
+        combined = (combined * 255).astype(np.uint8)
+
+        n_frames = FLAGS.n_new_frames
+        if FLAGS.include_input:
+            n_frames += context_length
+
+        if FLAGS.resize_output == '':
+            resizes = None
+
+        elif FLAGS.resize_output == 'original':
+            resizes = batch_sizes.numpy()
+            resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
+        else:
+            resize = tuple(int(x) for x in FLAGS.resize_output.split(','))
+            resizes = np.array([resize] * len(batch_sizes))
+            resizes = resizes * np.array([[n_frames, FLAGS.n_candidates]])
+
+        image_saver(combined, batch_output_files, resizes)
+
+    if FLAGS.evaluate_mse:
+        mses = np.concatenate(mses, axis=0)
+        print(f'MSE: {np.mean(mses)}')
+
+    image_saver.close()
+
+if __name__ == "__main__":
+    mlxu.run(main)
\ No newline at end of file
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..57060a49b8f8b92ef4243154e841268c5a3fbc28
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,263 @@
+from io import BytesIO
+import gradio as gr
+import uvicorn
+from fastapi import FastAPI
+from PIL import Image
+import numpy as np
+import mlxu
+import os
+import re
+from natsort import natsorted
+
+from .inference import MultiProcessInferenceModel
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    host='0.0.0.0',
+    port=5007,
+    dtype='float16',
+    checkpoint='',
+    torch_devices='',
+    context_frames=16,
+)
+
+def natural_sort_key(s):
+    return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]
+
+def load_example_image_groups(directory):
+    example_groups = {}
+    for subdir in os.listdir(directory):
+        subdir_path = os.path.join(directory, subdir)
+        if os.path.isdir(subdir_path):
+            example_groups[subdir] = []
+            images = [f for f in os.listdir(subdir_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+            images = natsorted(images, key=natural_sort_key)  # Natural sorting
+            for filename in images:
+                img = Image.open(os.path.join(subdir_path, filename))
+                example_groups[subdir].append(img)
+    return example_groups
+
+def main(_):
+    assert FLAGS.checkpoint != ''
+
+    model = MultiProcessInferenceModel(
+        checkpoint=FLAGS.checkpoint,
+        torch_devices=FLAGS.torch_devices,
+        dtype=FLAGS.dtype,
+        context_frames=FLAGS.context_frames,
+        use_lock=True,
+    )
+
+    checkerboard_r1 = np.concatenate([np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.zeros((8, 8, 3))], axis=1)
+    checkerboard_r2 = np.concatenate([np.ones((8, 8, 3)), np.zeros((8, 8, 3)), np.ones((8, 8, 3))], axis=1)
+    checkerboard = np.concatenate([checkerboard_r1, checkerboard_r2] * 16, axis=0).astype(np.float32)
+
+    def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
+        assert len(input_images) > 0
+        input_images = [
+            np.array(img.convert('RGB').resize((256, 256)), dtype=np.float32) / 255.0
+            for img in input_images
+        ]
+        input_images = np.stack(input_images, axis=0)
+        output_images = model([input_images], n_new_frames, n_candidates, temperature, top_p)[0]
+
+        generated_images = []
+        for candidate in output_images:
+            concatenated_image = []
+            for i, img in enumerate(candidate):
+                concatenated_image.append(img)
+                if i < len(candidate) - 1:
+                    concatenated_image.append(checkerboard)
+            generated_images.append(
+                Image.fromarray(
+                    (np.concatenate(concatenated_image, axis=1) * 255).astype(np.uint8)
+                )
+            )
+
+        return generated_images
+
+    with gr.Blocks(css="""
+        .small-button {
+            padding: 5px 10px; 
+            min-width: 80px;
+        }
+        .large-gallery img {
+            width: 100%; 
+            height: auto; 
+            max-height: 150px;
+        }
+    """) as demo:
+        with gr.Column():
+            image_list = gr.State([])
+            gr.Markdown('# LVM Demo')
+            gr.Markdown(f'Serving model: {FLAGS.checkpoint}')
+            gr.Markdown('## Inputs')
+            with gr.Row():
+                upload_drag = gr.File(
+                    type='binary',
+                    file_types=['image'],
+                    file_count='multiple',
+                )
+                with gr.Column():
+                    gen_length_slider = gr.Slider(
+                        label='Generation length',
+                        minimum=1,
+                        maximum=32,
+                        value=1,
+                        step=1,
+                        interactive=True,
+                    )
+                    n_candidates_slider = gr.Slider(
+                        label='Number of candidates',
+                        minimum=1,
+                        maximum=10,
+                        value=1,
+                        step=1,
+                        interactive=True,
+                    )
+                    temp_slider = gr.Slider(
+                        label='Temperature',
+                        minimum=0,
+                        maximum=2.0,
+                        value=1.0,
+                        interactive=True,
+                    )
+                    top_p_slider = gr.Slider(
+                        label='Top p',
+                        minimum=0,
+                        maximum=1.0,
+                        value=0.9,
+                        interactive=True,
+                    )
+                    clear_btn = gr.Button(
+                        value='Clear',
+                        elem_classes=['small-button'],
+                    )
+                    generate_btn = gr.Button(
+                        value='Generate',
+                        interactive=False,
+                        elem_classes=['small-button'],
+                    )
+            input_gallery = gr.Gallery(
+                columns=7,
+                rows=1,
+                object_fit='scale-down',
+            )
+            gr.Markdown('## Outputs')
+            output_gallery = gr.Gallery(
+                columns=4,
+                object_fit='scale-down',
+            )
+
+        def upload_image_fn(files, images):
+            for file in files:
+                images.append(Image.open(BytesIO(file)))
+
+            return {
+                upload_drag: None,
+                image_list: images,
+                input_gallery: images,
+                generate_btn: gr.update(interactive=True),
+            }
+
+        def clear_fn():
+            return {
+                image_list: [],
+                input_gallery: [],
+                generate_btn: gr.update(interactive=False),
+                output_gallery: [],
+            }
+
+        def disable_generate_btn():
+            return {
+                generate_btn: gr.update(interactive=False),
+            }
+
+        def generate_fn(images, n_candidates, gen_length, temperature, top_p):
+            new_images = generate_images(
+                images,
+                gen_length,
+                n_candidates=n_candidates,
+                temperature=temperature,
+                top_p=top_p,
+            )
+            return {
+                output_gallery: new_images,
+                generate_btn: gr.update(interactive=True),
+            }
+
+        upload_drag.upload(
+            upload_image_fn,
+            inputs=[upload_drag, image_list],
+            outputs=[upload_drag, image_list, input_gallery, generate_btn],
+        )
+        clear_btn.click(
+            clear_fn,
+            inputs=None,
+            outputs=[image_list, input_gallery, generate_btn, output_gallery],
+        )
+        generate_btn.click(
+            disable_generate_btn,
+            inputs=None,
+            outputs=[generate_btn],
+        ).then(
+            generate_fn,
+            inputs=[image_list, n_candidates_slider, gen_length_slider, temp_slider, top_p_slider],
+            outputs=[output_gallery, generate_btn],
+        )
+
+        example_groups = load_example_image_groups('/home/yutongbai/demo_images')
+
+        def add_image_group_fn(group_name, images):
+            new_images = images + example_groups[group_name]
+            return {
+                image_list: new_images,
+                input_gallery: new_images,
+                generate_btn: gr.update(interactive=True),
+            }
+
+        for group_name, group_images in example_groups.items():
+            with gr.Row():
+                with gr.Column(scale=3):
+                    add_button = gr.Button(value=f'Add {group_name}', elem_classes=['small-button'])
+                with gr.Column(scale=7):
+                    group_gallery = gr.Gallery(
+                        value=[Image.fromarray(np.array(img)) for img in group_images],
+                        columns=5,
+                        rows=1,
+                        object_fit='scale-down',
+                        label=group_name,
+                        elem_classes=['large-gallery'],
+                    )
+                
+                add_button.click(
+                    add_image_group_fn,
+                    inputs=[gr.State(group_name), image_list],
+                    outputs=[image_list, input_gallery, generate_btn],
+                )
+
+    app = FastAPI()
+    app = gr.mount_gradio_app(app, demo, '/')
+    uvicorn.run(app, host=FLAGS.host, port=FLAGS.port)
+
+if __name__ == "__main__":
+    mlxu.run(main)
diff --git a/eval_perplexity.py b/eval_perplexity.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bdefb882fc039d294209f0bb5e254ede100f340
--- /dev/null
+++ b/eval_perplexity.py
@@ -0,0 +1,127 @@
+"""
+Evaluating the perplexity on few shot tasks. This script accept a jsonl file
+as input. Each line of the jsonl file representing a dictionary. Each line
+represents one example in the evaluation set. The dictionary should have two key:
+
+    input: a list of paths to the input images as context to the model. This
+        list should include the few shot examples.
+    target: a list of paths to the target images to evaluate perplexity
+
+Ths script should run the model and compute the average perplexity on the
+evaluation set.
+"""
+
+import os
+import json
+from PIL import Image
+import numpy as np
+import mlxu
+from tqdm import tqdm, trange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+
+from .inference import MultiProcessInferenceModel
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    input_file='',
+    checkpoint='',
+    input_base_dir='',
+    batch_size=2,
+    json_input_key='input',
+    json_target_key='target',
+    dtype='float16',
+    torch_devices='',
+    n_workers=4,
+    max_examples=0,
+)
+
+
+def read_image_to_tensor(path):
+    pil_im = Image.open(path).convert('RGB')
+    input_img = pil_im.resize((256, 256))
+    input_img = np.array(input_img) / 255.0
+    input_img = input_img.astype(np.float32)
+    return input_img
+
+
+class MultiFrameDataset(torch.utils.data.Dataset):
+    def __init__(self, input_files, target_files):
+        assert len(input_files) == len(target_files)
+        self.input_files = input_files
+        self.target_files = target_files
+
+    def __len__(self):
+        return len(self.input_files)
+
+    def __getitem__(self, idx):
+        input_list = np.stack(
+            [read_image_to_tensor(f) for f in self.input_files[idx]],
+            axis=0
+        )
+        target_list = np.stack(
+            [read_image_to_tensor(f) for f in self.target_files[idx]],
+            axis=0
+        )
+        return input_list, target_list
+
+
+def main(_):
+    assert FLAGS.checkpoint != ''
+
+    print(f'Loading checkpoint from {FLAGS.checkpoint}')
+    print(f'Evaluating input file from {FLAGS.input_file}')
+
+    model = MultiProcessInferenceModel(
+        checkpoint=FLAGS.checkpoint,
+        torch_devices=FLAGS.torch_devices,
+        dtype=FLAGS.dtype,
+        use_lock=True,
+        perplexity_batch_size=FLAGS.batch_size,
+    )
+
+    input_files = []
+    target_files = []
+
+    with mlxu.open_file(FLAGS.input_file, 'r') as f:
+        for line in f:
+            record = json.loads(line)
+            input_files.append(record[FLAGS.json_input_key])
+            target_files.append(record[FLAGS.json_target_key])
+
+    if FLAGS.input_base_dir != '':
+        input_files = [
+            [os.path.join(FLAGS.input_base_dir, x) for x in y]
+            for y in input_files
+        ]
+        target_files = [
+            [os.path.join(FLAGS.input_base_dir, x) for x in y]
+            for y in target_files
+        ]
+
+    if FLAGS.max_examples > 0:
+        input_files = input_files[:FLAGS.max_examples]
+        target_files = target_files[:FLAGS.max_examples]
+
+    dataset = MultiFrameDataset(input_files, target_files)
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=FLAGS.batch_size * model.n_processes,
+        shuffle=False,
+        num_workers=FLAGS.n_workers
+    )
+
+    perplexities = []
+
+    for input_images, target_images in tqdm(data_loader, ncols=0):
+        perplexity = model.compute_perplexity(input_images, target_images)
+        perplexities.append(perplexity)
+
+    perplexities = np.concatenate(perplexities, axis=0)
+    print(f'Perplexity: {np.mean(perplexities)}')
+
+
+if __name__ == "__main__":
+    mlxu.run(main)
\ No newline at end of file
diff --git a/eval_video_perplexity.py b/eval_video_perplexity.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ba70ee8f53089c9bc456dcf9ca4dda0faade1f
--- /dev/null
+++ b/eval_video_perplexity.py
@@ -0,0 +1,134 @@
+
+import os
+import glob
+from functools import partial
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+from PIL import Image
+import cv2
+import mlxu
+from natsort import natsorted
+import numpy as np
+import einops
+import torch
+
+from vqlm_demo.inference import MultiProcessInferenceModel
+from vqlm_demo.utils import (
+    is_video, random_square_crop,
+    read_frames_from_dir, read_frames_from_video
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    checkpoint='',
+    input_files='',
+    frame_input=False,
+    read_file_list='',
+    center_crop=1.0,
+    n_context_frames=15,
+    n_target_frames=1,
+    n_workers=8,
+    stride=8,
+    batch_size=2,
+    torch_devices='',
+    shuffle=False,
+    random_start=True,
+    max_examples=0,
+)
+
+
+class VideoDataset(torch.utils.data.Dataset):
+
+    def __init__(self, videos, frame_input=False, n_context_frames=15,
+                 n_target_frames=1, stride=1):
+        self.videos = videos
+        self.frame_input = frame_input
+        self.n_context_frames = n_context_frames
+        self.n_target_frames = n_target_frames
+        self.stride = stride
+
+    def __getitem__(self, index):
+        if self.frame_input:
+            frames = read_frames_from_dir(
+                self.videos[index],
+                self.n_context_frames + self.n_target_frames,
+                self.stride,
+                center_crop=FLAGS.center_crop,
+                random_start=FLAGS.random_start,
+            )
+        else:
+            frames = read_frames_from_video(
+                self.videos[index],
+                self.n_context_frames + self.n_target_frames,
+                self.stride,
+                center_crop=FLAGS.center_crop,
+                random_start=FLAGS.random_start,
+            )
+        if frames is None:
+            return self[np.random.randint(0, len(self))]
+        return frames[:self.n_context_frames], frames[self.n_context_frames:]
+
+    def __len__(self):
+        return len(self.videos)
+
+
+
+def main(_):
+    assert FLAGS.checkpoint != ''
+    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
+
+    model = MultiProcessInferenceModel(
+        checkpoint=FLAGS.checkpoint,
+        torch_devices=FLAGS.torch_devices,
+        perplexity_batch_size=FLAGS.batch_size,
+    )
+
+    if FLAGS.read_file_list != '':
+        with open(FLAGS.read_file_list, 'r') as f:
+            videos = [x.strip() for x in f.readlines()]
+    else:
+        videos = glob.glob(FLAGS.input_files)
+
+    if FLAGS.frame_input:
+        videos = [x for x in videos if os.path.isdir(x)]
+    else:
+        videos = [x for x in videos if is_video(x)]
+
+    if FLAGS.shuffle:
+        np.random.shuffle(videos)
+
+    if FLAGS.max_examples > 0:
+        videos = videos[:FLAGS.max_examples]
+
+    dataset = VideoDataset(
+        videos,
+        frame_input=FLAGS.frame_input,
+        n_context_frames=FLAGS.n_context_frames,
+        n_target_frames=FLAGS.n_target_frames,
+        stride=FLAGS.stride
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=FLAGS.batch_size * model.n_processes * 4,
+        shuffle=False,
+        num_workers=FLAGS.n_workers,
+        prefetch_factor=4,
+        drop_last=True,
+    )
+
+    perplexities = []
+
+    for batch_context_frames, batch_target_frames in tqdm(dataloader, ncols=0):
+        batch_context_frames = batch_context_frames.numpy()
+        batch_target_frames = batch_target_frames.numpy()
+        perplexity = model.compute_perplexity(
+            batch_context_frames, batch_target_frames
+        )
+        perplexities.append(perplexity)
+
+    perplexities = np.concatenate(perplexities, axis=0)
+    print(f'Perplexity: {np.mean(perplexities)}')
+
+
+if __name__ == '__main__':
+    mlxu.run(main)
\ No newline at end of file
diff --git a/eval_videos.py b/eval_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..4822e54d1de6e622eb5e2ccd6212f12c5130ea7a
--- /dev/null
+++ b/eval_videos.py
@@ -0,0 +1,160 @@
+import os
+import glob
+from functools import partial
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+from PIL import Image
+import cv2
+import mlxu
+from natsort import natsorted
+import numpy as np
+import einops
+import torch
+
+from vqlm_demo.inference import MultiProcessInferenceModel
+from vqlm_demo.utils import (
+    is_video, random_square_crop,
+    read_frames_from_dir, read_frames_from_video
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    checkpoint='',
+    input_files='',
+    frame_input=False,
+    read_file_list='',
+    output_dir='',
+    center_crop=1.0,
+    n_context_frames=12,
+    n_new_frames=4,
+    n_candidates=8,
+    temperature=1.0,
+    top_p=1.0,
+    n_workers=8,
+    stride=8,
+    batch_size=32,
+    torch_devices='',
+    shuffle=False,
+    max_examples=0,
+)
+
+
+def save_image(args):
+    image, filename = args
+    base = FLAGS.input_files.split('*')[0]
+    filename = filename[len(base):].replace('/', '_') + '.png'
+    Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
+
+
+class VideoDataset(torch.utils.data.Dataset):
+
+    def __init__(self, videos, frame_input=False, n_frames=8, stride=1, n_new_frames=1):
+        self.videos = videos
+        self.frame_input = frame_input
+        self.n_frames = n_frames
+        self.stride = stride
+        self.n_new_frames = n_new_frames
+
+    def __getitem__(self, index):
+        if self.frame_input:
+            frames = read_frames_from_dir(
+                self.videos[index], self.n_frames, self.stride,
+                center_crop=FLAGS.center_crop,
+            )
+        else:
+            # frames has shape 's h w c'
+            frames = read_frames_from_video(
+                self.videos[index], self.n_frames, self.stride,
+                center_crop=FLAGS.center_crop,
+            )
+
+        if frames is None:
+            # Unreadable example: fall back to a random other example.
+            return self[np.random.randint(0, len(self))]
+
+        # The last n_new_frames frames serve as the generation targets.
+        target_frames = frames[self.n_frames - self.n_new_frames:self.n_frames, :, :, :]
+        return frames, target_frames, self.videos[index]
+
+    def __len__(self):
+        return len(self.videos)
+
+
+
+def main(_):
+    assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
+    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
+    os.makedirs(FLAGS.output_dir, exist_ok=True)
+
+    if FLAGS.read_file_list != '':
+        with open(FLAGS.read_file_list, 'r') as f:
+            videos = [x.strip() for x in f.readlines()]
+    else:
+        videos = glob.glob(FLAGS.input_files)
+
+    if FLAGS.frame_input:
+        videos = [x for x in videos if os.path.isdir(x)]
+    else:
+        videos = [x for x in videos if is_video(x)]
+
+    if FLAGS.shuffle:
+        np.random.shuffle(videos)
+
+    if FLAGS.max_examples > 0:
+        videos = videos[:FLAGS.max_examples]
+
+    dataset = VideoDataset(
+        videos,
+        frame_input=FLAGS.frame_input,
+        n_frames=FLAGS.n_context_frames,
+        stride=FLAGS.stride
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=FLAGS.batch_size,
+        shuffle=False,
+        num_workers=FLAGS.n_workers,
+        prefetch_factor=4,
+        drop_last=True,
+    )
+
+    if FLAGS.torch_devices == '':
+        torch_devices = None
+    else:
+        torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
+
+    model = MultiProcessInferenceModel(
+        checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
+    )
+
+    save_img_pool = Pool(FLAGS.n_workers)
+
+    for batch, batch_targets, filenames in tqdm(dataloader, ncols=0):
+        batch = batch.numpy()  # 'b s h w c'
+
+        generated = model(
+            batch,
+            n_new_frames=FLAGS.n_new_frames,
+            n_candidates=FLAGS.n_candidates,
+            temperature=FLAGS.temperature,
+            top_p=FLAGS.top_p,
+        )
+
+        generated = np.array(generated)
+
+        batch_targets = einops.repeat(
+            batch_targets.numpy(),
+            'b s h w c -> b n s h w c', # batch, candidate, sequence, h, w, c. 
+            n=FLAGS.n_candidates,
+        )
+
+
+if __name__ == '__main__':
+    mlxu.run(main)
\ No newline at end of file
diff --git a/generate_videos.py b/generate_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f6d129adb843cd53723f32e16555a3108cda76
--- /dev/null
+++ b/generate_videos.py
@@ -0,0 +1,168 @@
+
+import os
+import glob
+from functools import partial
+from tqdm import tqdm, trange
+from multiprocessing import Pool
+from PIL import Image
+import cv2
+import mlxu
+from natsort import natsorted
+import numpy as np
+import einops
+import torch
+
+from vqlm_demo.inference import MultiProcessInferenceModel
+from vqlm_demo.utils import (
+    is_video, random_square_crop,
+    read_frames_from_dir, read_frames_from_video
+)
+
+
+FLAGS, _ = mlxu.define_flags_with_default(
+    checkpoint='',
+    input_files='',
+    frame_input=False,
+    read_file_list='',
+    output_dir='',
+    center_crop=1.0,
+    n_context_frames=12,
+    n_new_frames=4,
+    n_candidates=8,
+    temperature=1.0,
+    top_p=1.0,
+    n_workers=8,
+    stride=8,
+    batch_size=32,
+    torch_devices='',
+    shuffle=False,
+    max_examples=0,
+)
+
+
+def save_image(args):
+    image, filename = args
+    base = FLAGS.input_files.split('*')[0]
+    filename = filename[len(base):].replace('/', '_') + '.png'
+    Image.fromarray(image).save(os.path.join(FLAGS.output_dir, filename))
+
+
+class VideoDataset(torch.utils.data.Dataset):
+
+    def __init__(self, videos, frame_input=False, n_frames=8, stride=1):
+        self.videos = videos
+        self.frame_input = frame_input
+        self.n_frames = n_frames
+        self.stride = stride
+
+    def __getitem__(self, index):
+        if self.frame_input:
+            frames = read_frames_from_dir(
+                self.videos[index], self.n_frames, self.stride,
+                center_crop=FLAGS.center_crop,
+            )
+        else:
+            frames = read_frames_from_video(
+                self.videos[index], self.n_frames, self.stride,
+                center_crop=FLAGS.center_crop,
+            )
+        if frames is None:
+            return self[np.random.randint(0, len(self))]
+        return frames, self.videos[index]
+
+    def __len__(self):
+        return len(self.videos)
+
+
+
+def main(_):
+    assert FLAGS.checkpoint != '' and FLAGS.output_dir != ''
+    assert FLAGS.read_file_list != '' or FLAGS.input_files != ''
+    os.makedirs(FLAGS.output_dir, exist_ok=True)
+
+    if FLAGS.read_file_list != '':
+        with open(FLAGS.read_file_list, 'r') as f:
+            videos = [x.strip() for x in f.readlines()]
+    else:
+        videos = glob.glob(FLAGS.input_files)
+
+    if FLAGS.frame_input:
+        videos = [x for x in videos if os.path.isdir(x)]
+    else:
+        videos = [x for x in videos if is_video(x)]
+
+    if FLAGS.shuffle:
+        np.random.shuffle(videos)
+
+    if FLAGS.max_examples > 0:
+        videos = videos[:FLAGS.max_examples]
+
+    dataset = VideoDataset(
+        videos,
+        frame_input=FLAGS.frame_input,
+        n_frames=FLAGS.n_context_frames,
+        stride=FLAGS.stride
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=FLAGS.batch_size,
+        shuffle=False,
+        num_workers=FLAGS.n_workers,
+        prefetch_factor=4,
+        drop_last=True,
+    )
+
+    if FLAGS.torch_devices == '':
+        torch_devices = None
+    else:
+        torch_devices = [f'cuda:{x}' for x in FLAGS.torch_devices.split(',')]
+
+    model = MultiProcessInferenceModel(
+        checkpoint=FLAGS.checkpoint, torch_devices=torch_devices,
+    )
+
+    save_img_pool = Pool(FLAGS.n_workers)
+
+
+
+    for batch, filenames in tqdm(dataloader, ncols=0):
+        batch = batch.numpy()
+
+        generated = model(
+            batch,
+            n_new_frames=FLAGS.n_new_frames,
+            n_candidates=FLAGS.n_candidates,
+            temperature=FLAGS.temperature,
+            top_p=FLAGS.top_p,
+        )
+
+        generated = np.array(generated)
+
+        output_batch = einops.repeat(
+            batch,
+            'b s h w c -> b n s h w c',
+            n=FLAGS.n_candidates,
+        )
+
+
+        combined = einops.rearrange(
+            np.concatenate([output_batch, generated], axis=2),
+            'b n s h w c -> b (n h) (s w) c'
+        )
+
+        
+        combined = (np.clip(combined, 0, 1) * 255).astype(np.uint8)
+        save_img_pool.imap(save_image, zip(combined, filenames))
+
+    # Wait for all queued image-saving tasks to finish before exiting.
+    save_img_pool.close()
+    save_img_pool.join()
+
+if __name__ == '__main__':
+    mlxu.run(main)
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..73b1b816a632da1e9fa2ec70b608c1d04fc54e1a
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,240 @@
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+import time
+import os
+from functools import partial
+from copy import deepcopy
+from multiprocessing import Pool
+from threading import Lock
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+import einops
+from transformers import LlamaForCausalLM
+import spaces
+
+from vqvae_muse import VQGANModel, get_tokenizer_muse
+from torch_vqvae_model import get_tokenizer
+
+
+def get_torch_float_dtype(dtype):
+    if dtype in (torch.float16, torch.bfloat16, torch.float32):
+        return dtype
+    return {
+        'float16': torch.float16,
+        'fp16': torch.float16,
+        'f16': torch.float16,
+        'bfloat16': torch.bfloat16,
+        'bf16': torch.bfloat16,
+        'float32': torch.float32,
+        'fp32': torch.float32,
+        'f32': torch.float32,
+    }[dtype]
+
+
+def get_pid():
+    time.sleep(1)
+    return os.getpid()
+
+
+class InferenceModel(ABC):
+
+    @abstractmethod
+    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
+        raise NotImplementedError()
+
+
+class LocalInferenceModel(InferenceModel):
+
+    def __init__(self, checkpoint, dtype='float16', torch_device='cuda',
+                 context_frames=16, use_lock=False):
+        self.checkpoint = checkpoint
+        self.dtype = dtype
+        self.torch_device = torch_device
+        self.context_frames = context_frames
+
+        # new tokenizer
+        self.tokenizer = get_tokenizer_muse()
+        self.tokenizer.to(self.torch_device)
+
+        self.model = LlamaForCausalLM.from_pretrained(
+            self.checkpoint, torch_dtype=get_torch_float_dtype(self.dtype)
+        ).to(self.torch_device)
+        print("torch device", self.torch_device)
+        print("init device", self.model.device)
+
+        if use_lock:
+            self.lock = Lock()
+        else:
+            self.lock = nullcontext()
+
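+    # Perplexity is computed as exp(-mean log-likelihood) of the target frame
+    # tokens, conditioned on the context frame tokens, for each example.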
+    @torch.no_grad()
+    def compute_perplexity(self, input_images, target_images):
+        input_images = np.array(input_images)
+        target_images = np.array(target_images)
+        assert len(input_images.shape) == 5 and len(target_images.shape) == 5  # [B, S, H, W, C]
+        assert input_images.shape[0] == target_images.shape[0]
+        batch_size = input_images.shape[0]
+        with self.lock:
+            input_images = torch.tensor(
+                einops.rearrange(input_images, 'b s h w c -> b s c h w')
+            ).to(self.torch_device)
+            target_images = torch.tensor(
+                einops.rearrange(target_images, 'b s h w c -> b s c h w')
+            ).to(self.torch_device)
+            input_ids = self.tokenizer.tokenize(input_images).view(batch_size, -1)
+            target_ids = self.tokenizer.tokenize(target_images).view(batch_size, -1)
+            all_ids = torch.cat([input_ids, target_ids], dim=1)
+            logits = self.model(all_ids).logits
+            log_probs = F.log_softmax(logits, dim=-1)
+            target_ids_onehot = F.one_hot(target_ids, num_classes=logits.shape[-1])
+            target_log_probs = log_probs[:, input_ids.shape[1] - 1 : -1]
+            perplexity = torch.exp(
+                -torch.mean(
+                    torch.sum(target_log_probs * target_ids_onehot, dim=-1),
+                    dim=-1
+                )
+            )
+            return perplexity.detach().cpu().numpy()
+
+    @torch.no_grad()
+    def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0):
+        assert type(input_images) == np.ndarray
+        with self.lock:
+            input_images = np.array(input_images, dtype=np.float32)
+            input_images = torch.tensor(
+                einops.rearrange(input_images, 'b h w c -> b c h w')
+            ).to(self.torch_device)
+
+            # Make sure the model and tokenizer are on the target device before encoding.
+            self.model.to(self.torch_device)
+            self.tokenizer.to(self.torch_device)
+
+            # new tokenizer
+            _, input_ids = self.tokenizer.encode(input_images)
+            input_ids = input_ids.view(1, -1)
+
+
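+            # Each frame is encoded as 256 VQ tokens. Keep at most
+            # (context_frames - 1) frames of context so at least one new frame
+            # fits inside the model's context window.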
+            input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
+
+            new_tokens = []
+            current_context_frames = input_ids.shape[1] // 256
+            first_generation_left = self.context_frames - current_context_frames
+            first_new_frames = min(first_generation_left, n_new_frames)
+            input_ids = self.model.generate(
+                input_ids=input_ids,
+                attention_mask=torch.ones_like(input_ids),
+                pad_token_id=8192,
+                max_new_tokens=256 * first_new_frames,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                suppress_tokens=list(range(8192, self.model.vocab_size)),
+            )
+            new_tokens.append(input_ids[:, -256 * first_new_frames:])
+            input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
+
+            for _ in range(max(0, n_new_frames - first_new_frames)):
+                input_ids = self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=torch.ones_like(input_ids),
+                    pad_token_id=8192,
+                    max_new_tokens=256,
+                    do_sample=True,
+                    top_p=top_p,
+                    temperature=temperature,
+                    suppress_tokens=list(range(8192, self.model.vocab_size)),
+                )
+                new_tokens.append(input_ids[:, -256:])
+                input_ids = input_ids[:, -(self.context_frames - 1) * 256:]
+
+            new_tokens = torch.cat(new_tokens, dim=1).view(-1, 256)
+            new_images = einops.rearrange(
+                torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0),
+                'b c h w -> b h w c'
+            ).detach().cpu().numpy()
+        return new_images
+
+    @spaces.GPU(duration=180)
+    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
+        output = []
+        for seq in input_images:
+            output.append(
+                [self.generate_once(seq, n_new_frames, temperature, top_p)
+                 for _ in range(n_candidates)]
+            )
+        return output
+
+
+class MultiProcessInferenceModel(InferenceModel):
+
+    def __init__(self, checkpoint, torch_devices=None, dtype='float16',
+                 context_frames=16, use_lock=False, perplexity_batch_size=2):
+        if torch_devices is None or torch_devices == '':
+            torch_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
+
+        self.torch_devices = torch_devices
+        self.n_processes = len(torch_devices)
+        print(f'Using {self.n_processes} processes for inference')
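+        # Map each worker process (looked up by pid) to one device, then have
+        # every worker load its own copy of the model on its assigned device.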
+        self.worker_pool = Pool(self.n_processes)
+        self.worker_pids = self.worker_pool.starmap(get_pid, [tuple() for _ in range(self.n_processes)])
+        self.device_map = {
+            pid: torch_device
+            for pid, torch_device in zip(self.worker_pids, self.torch_devices)
+        }
+        self.worker_pool.starmap(
+            self.initialize_worker,
+            [(self.device_map, checkpoint, dtype, context_frames) for _ in range(self.n_processes)]
+        )
+        self.perplexity_batch_size = perplexity_batch_size
+        if use_lock:
+            self.lock = Lock()
+        else:
+            self.lock = nullcontext()
+
+    @staticmethod
+    def initialize_worker(device_map, checkpoint, dtype, context_frames):
+        global _current_process_backend
+        torch_device = device_map[os.getpid()]
+        _current_process_backend = LocalInferenceModel(
+            checkpoint, dtype, torch_device, context_frames
+        )
+
+    @staticmethod
+    def generate_once(input_images, n_new_frames, temperature=1.0, top_p=1.0):
+        return _current_process_backend.generate_once(input_images, n_new_frames, temperature, top_p)
+
+    @staticmethod
+    def compute_perplexity_once(input_images, target_images):
+        return _current_process_backend.compute_perplexity(input_images, target_images)
+
+    def compute_perplexity(self, input_images, target_images):
+        with self.lock:
+            map_args = []
+            for i in range(0, len(input_images), self.perplexity_batch_size):
+                map_args.append((
+                    input_images[i : i + self.perplexity_batch_size],
+                    target_images[i : i + self.perplexity_batch_size]
+                ))
+            outputs = self.worker_pool.starmap(self.compute_perplexity_once, map_args)
+            return np.concatenate(outputs, axis=0)
+
+    def __call__(self, input_images, n_new_frames, n_candidates, temperature=1.0, top_p=1.0):
+        with self.lock:
+            map_args = []
+            for seq in input_images:
+                for _ in range(n_candidates):
+                    map_args.append((seq, n_new_frames, temperature, top_p))
+
+            outputs = self.worker_pool.starmap(self.generate_once, map_args)
+            reshaped_output = []
+            index = 0
+            for _ in range(len(input_images)):
+                candidates = []
+                for _ in range(n_candidates):
+                    candidates.append(outputs[index])
+                    index += 1
+                reshaped_output.append(candidates)
+        return reshaped_output
+
diff --git a/prompts/.DS_Store b/prompts/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..3b3b755588b8880f623b777ccfaf36048c71b851
Binary files /dev/null and b/prompts/.DS_Store differ
diff --git a/prompts/Composition/Slide1.png b/prompts/Composition/Slide1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c20dd2a265aef200d02fec3495af8cdb4fece30d
--- /dev/null
+++ b/prompts/Composition/Slide1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d926922e8e28f02c46e723b85d3d4969da271f892654ce492cf59cbf3f322a0
+size 194501
diff --git a/prompts/Composition/Slide10.png b/prompts/Composition/Slide10.png
new file mode 100644
index 0000000000000000000000000000000000000000..c8a0e88cbc36c13d575f223edd6681fe95f63a86
--- /dev/null
+++ b/prompts/Composition/Slide10.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4480a3a7905b703ca1802e5391ea47e90e84cdc7eacb5229ade606ce4f5b6bb
+size 443693
diff --git a/prompts/Composition/Slide11.png b/prompts/Composition/Slide11.png
new file mode 100644
index 0000000000000000000000000000000000000000..0549fe330577e0adc697dc03fb284aa15b14f441
--- /dev/null
+++ b/prompts/Composition/Slide11.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91cbe861bd47c4ec08e79bccdb64b993cc4b3b21549c346f834a985b1b0a1a6e
+size 464548
diff --git a/prompts/Composition/Slide12.png b/prompts/Composition/Slide12.png
new file mode 100644
index 0000000000000000000000000000000000000000..31116cc53413933baaaf93aa5c7a4373e713944d
--- /dev/null
+++ b/prompts/Composition/Slide12.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d05d2db2a5e7bc7e33795583e10cdc03ea53bacd250010680a161ab07b7ad65
+size 487835
diff --git a/prompts/Composition/Slide13.png b/prompts/Composition/Slide13.png
new file mode 100644
index 0000000000000000000000000000000000000000..f5c89f17f384cb855047917b3bdc589919cd4504
--- /dev/null
+++ b/prompts/Composition/Slide13.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d94cfad17df77fa90ab84bdd89d3ad09938a5fe768b4e211c2bac140b36c12cb
+size 489967
diff --git a/prompts/Composition/Slide14.png b/prompts/Composition/Slide14.png
new file mode 100644
index 0000000000000000000000000000000000000000..de90d3fa3d4b3af9d32fbce6803389b072d26322
--- /dev/null
+++ b/prompts/Composition/Slide14.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04b42409ec1ca2ddbde1114eb8426a34c5e0064159e224af808b766ae003d2fd
+size 492423
diff --git a/prompts/Composition/Slide15.png b/prompts/Composition/Slide15.png
new file mode 100644
index 0000000000000000000000000000000000000000..871bd579690ad66a0c714a1e5fdc33846aa9147c
--- /dev/null
+++ b/prompts/Composition/Slide15.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9919156ccdd9c2cbb30811529e94c83bb2afb277c90fed503ea4716be702cdde
+size 491891
diff --git a/prompts/Composition/Slide2.png b/prompts/Composition/Slide2.png
new file mode 100644
index 0000000000000000000000000000000000000000..dbe072dca6d9ef3fe4491d740af5df5c6b010c68
--- /dev/null
+++ b/prompts/Composition/Slide2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0c9f6467cc732b562c167770d38a164162e4454127a242a16e3bdae7e717d27
+size 193143
diff --git a/prompts/Composition/Slide3.png b/prompts/Composition/Slide3.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e8ac80de12ba6e1b5c4bf4e6bfa6ac7ebad7ad6
--- /dev/null
+++ b/prompts/Composition/Slide3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f702f10001fd9e7ad523753c884f8cef532da878d62656ffdbd566e104b67c7
+size 199394
diff --git a/prompts/Composition/Slide4.png b/prompts/Composition/Slide4.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e0643b567bf4a0181ff14b6b954356c30ad7b06
--- /dev/null
+++ b/prompts/Composition/Slide4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18c2d2384e4c97f35ae4cddc9bea4e600946eefefefff1f4fb683a51a54d4384
+size 202638
diff --git a/prompts/Composition/Slide5.png b/prompts/Composition/Slide5.png
new file mode 100644
index 0000000000000000000000000000000000000000..59c9ae2435567cdce73774b3ef5342d70a0f13da
--- /dev/null
+++ b/prompts/Composition/Slide5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4af292b97a2abe48d253fb2f1badd8d147402a3124fd12a2a0750307487c4f27
+size 190546
diff --git a/prompts/Composition/Slide6.png b/prompts/Composition/Slide6.png
new file mode 100644
index 0000000000000000000000000000000000000000..fb2d05758aadbd0ee184bc274ece6ede5714dd6c
--- /dev/null
+++ b/prompts/Composition/Slide6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8b5fe9521e4094950384fce57733496363750b6a7c816ebae3cf43e6bcdb626
+size 173097
diff --git a/prompts/Composition/Slide7.png b/prompts/Composition/Slide7.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa5e53aa0367d64009edede3023ab3ab1cdfa196
--- /dev/null
+++ b/prompts/Composition/Slide7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce9418614363adfcd1b96b6df3b990d8204ed0a0341c348f9f340d7c128b4900
+size 174070
diff --git a/prompts/Composition/Slide8.png b/prompts/Composition/Slide8.png
new file mode 100644
index 0000000000000000000000000000000000000000..d214f14a6149fd5fdf7a9a55b558b2e74b192359
--- /dev/null
+++ b/prompts/Composition/Slide8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66a589f649600c65b7e808d824322d5a7c36b39675704cd5857fc31ce4f5af7f
+size 180144
diff --git a/prompts/Composition/Slide9.png b/prompts/Composition/Slide9.png
new file mode 100644
index 0000000000000000000000000000000000000000..c2be5efa234c711a7da20cd11f717e797b1e9bf8
--- /dev/null
+++ b/prompts/Composition/Slide9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee514628d3da8c4525853c86d1e8d348de7a0641312a0e3c79fae6b5d73ae11f
+size 454702
diff --git a/prompts/Depth Estimation/1.png b/prompts/Depth Estimation/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62
--- /dev/null
+++ b/prompts/Depth Estimation/1.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
+size 48533
diff --git a/prompts/Depth Estimation/1_depth.png b/prompts/Depth Estimation/1_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..e4a416e3cf5ee9e7b46f95c5590905192d50f99f
--- /dev/null
+++ b/prompts/Depth Estimation/1_depth.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b22aa119576ab691bab3db3fdd7eacf53dadc9e4cb3a9bfe4f4cb9c6fc0f6c6
+size 13888
diff --git a/prompts/Depth Estimation/2.png b/prompts/Depth Estimation/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5
--- /dev/null
+++ b/prompts/Depth Estimation/2.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
+size 54286
diff --git a/prompts/Depth Estimation/2_depth.png b/prompts/Depth Estimation/2_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..70e9ccb6b44f0284cac2e39bc3e9e4981d5ac373
--- /dev/null
+++ b/prompts/Depth Estimation/2_depth.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37eacaf9208cf21693ae99802697e8894a9e8cf40cc221c704a50358f14dc954
+size 12257
diff --git a/prompts/Depth Estimation/3.png b/prompts/Depth Estimation/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18
--- /dev/null
+++ b/prompts/Depth Estimation/3.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
+size 52593
diff --git a/prompts/Depth Estimation/3_depth.png b/prompts/Depth Estimation/3_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..f39d22d9069072ee46affd295b14f8e02a70ceb6
--- /dev/null
+++ b/prompts/Depth Estimation/3_depth.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a91a47d21378bef0d535f7e33c0185e60cc23baa7ced20bc1ffb028a5d95b5c4
+size 13332
diff --git a/prompts/Depth Estimation/4.png b/prompts/Depth Estimation/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a
--- /dev/null
+++ b/prompts/Depth Estimation/4.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
+size 60589
diff --git a/prompts/Depth Estimation/4_depth.png b/prompts/Depth Estimation/4_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..d578e93b900f2e0d74ca15723de3c92b90fc66d2
--- /dev/null
+++ b/prompts/Depth Estimation/4_depth.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0685d0c4755206910cb1b1feea54a1e843cdee9dd140e414c0df56a885b68d85
+size 13447
diff --git a/prompts/Depth Estimation/5.png b/prompts/Depth Estimation/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd
--- /dev/null
+++ b/prompts/Depth Estimation/5.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
+size 21984
diff --git a/prompts/Depth Estimation/5_depth.png b/prompts/Depth Estimation/5_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..9790c23cff31b23afe755370fae8c4fbac6316f3
--- /dev/null
+++ b/prompts/Depth Estimation/5_depth.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e9874362d8a1c85b0030590399a5f6388fe69b9dd42ec762313b97d37817eb7
+size 12020
diff --git a/prompts/Depth Estimation/6.png b/prompts/Depth Estimation/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252
--- /dev/null
+++ b/prompts/Depth Estimation/6.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
+size 30704
diff --git a/prompts/Depth Estimation/6_depth.png b/prompts/Depth Estimation/6_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f84aeef983f2a70a1a26033942ffbc1247eeaa5
--- /dev/null
+++ b/prompts/Depth Estimation/6_depth.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:588df6be45864d2164b5e215429f268ba82731540adb07cd4ea47db0ca8f5319
+size 11946
diff --git a/prompts/Depth Estimation/7.png b/prompts/Depth Estimation/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6
--- /dev/null
+++ b/prompts/Depth Estimation/7.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
+size 49450
diff --git a/prompts/Depth Estimation/7_depth.png b/prompts/Depth Estimation/7_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..fe14a639f330ac40b0ec82a7c19e8f61befe9496
--- /dev/null
+++ b/prompts/Depth Estimation/7_depth.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c0edcab180411a6966d899de8f282870293a5d275e58d4a185e3cb31d9ca6b0d
+size 13252
diff --git a/prompts/Depth Estimation/8.png b/prompts/Depth Estimation/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d
--- /dev/null
+++ b/prompts/Depth Estimation/8.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
+size 50877
diff --git a/prompts/Eaten Apples/1.png b/prompts/Eaten Apples/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..d32aa001b22d9a408645e06cc02ad52505e230a2
--- /dev/null
+++ b/prompts/Eaten Apples/1.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a75364bb67ce5741004e2bb18178b362fd1d4dee12a76d9ae4be2124fb3452a0
+size 199368
diff --git a/prompts/Eaten Apples/10.png b/prompts/Eaten Apples/10.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa963c0a65d4e7605efbabc10244c171172835be
--- /dev/null
+++ b/prompts/Eaten Apples/10.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05f9235b7c283915d0d81b2423915f05f587b91d691ae0cae6f0bc5b68e84588
+size 142649
diff --git a/prompts/Eaten Apples/2.png b/prompts/Eaten Apples/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..dadf5d493d454a5ec31696919aa2be33eea1d6ab
--- /dev/null
+++ b/prompts/Eaten Apples/2.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25b08de2de0ac2bcc59be060bf19574931091c9dc6472f8122f7ac1243c59c6f
+size 214103
diff --git a/prompts/Eaten Apples/3.png b/prompts/Eaten Apples/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..17bb70919745f0f0401140da26a4aa67c9224875
--- /dev/null
+++ b/prompts/Eaten Apples/3.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eee49d97068ac9de19bf6aead9a4b0c88ba9108bc9eb9d19f43a3b5919c88367
+size 212059
diff --git a/prompts/Eaten Apples/4.png b/prompts/Eaten Apples/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..bc6979b7ca710ff50f9af0f26f6089d9a25a0b53
--- /dev/null
+++ b/prompts/Eaten Apples/4.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88d401e4c2c2b1b21119b953e230a276305af15d295d0035a221e498665af5b4
+size 212147
diff --git a/prompts/Eaten Apples/5.png b/prompts/Eaten Apples/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..60e4028c1752e1caf2581949529b7c76db2b788b
--- /dev/null
+++ b/prompts/Eaten Apples/5.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:116e8eb9ecc170c4a00f54a3b7b8996b67cd585932e34f4e2a25f8e589b7ae3d
+size 204197
diff --git a/prompts/Eaten Apples/6.png b/prompts/Eaten Apples/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..3011620226f9165f918dd3e1120a3569e4ef3bfd
--- /dev/null
+++ b/prompts/Eaten Apples/6.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c683df4c90c7da5bee98499fbf7233b6ac13fe2480aa9e1d4cb80a25ff9a500
+size 192756
diff --git a/prompts/Eaten Apples/7.png b/prompts/Eaten Apples/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..c890af0c7b3b95c297ff07db0da4134754de58b4
--- /dev/null
+++ b/prompts/Eaten Apples/7.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e3ad84f16c9326a9819da7e2c9485705b073a47097f742592ada91c10f706c0
+size 181082
diff --git a/prompts/Eaten Apples/8.png b/prompts/Eaten Apples/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..89545aa186079020196c335da0200c2a1c88c80a
--- /dev/null
+++ b/prompts/Eaten Apples/8.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:919a814cf9c604923a3384c26473750d353bac17f3422865b68c1d86e45552f7
+size 167449
diff --git a/prompts/Edge Detection/1.png b/prompts/Edge Detection/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62
--- /dev/null
+++ b/prompts/Edge Detection/1.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
+size 48533
diff --git a/prompts/Edge Detection/1_edge.png b/prompts/Edge Detection/1_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..48f7dee44046eb0e7853d81be37870da7f7578fb
--- /dev/null
+++ b/prompts/Edge Detection/1_edge.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ca5651653eb5ed8e3d08600c936985982f5019ce6f2f2489e82c112ea75686a
+size 30563
diff --git a/prompts/Edge Detection/2.png b/prompts/Edge Detection/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5
--- /dev/null
+++ b/prompts/Edge Detection/2.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
+size 54286
diff --git a/prompts/Edge Detection/2_edge.png b/prompts/Edge Detection/2_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..49db076ecece6c9db3bee69a720b8207fadfd8e5
--- /dev/null
+++ b/prompts/Edge Detection/2_edge.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8a5327db1e20457b8ab478a296e5176f2779c5e4bb2c034e5b6f0183854866b
+size 30437
diff --git a/prompts/Edge Detection/3.png b/prompts/Edge Detection/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18
--- /dev/null
+++ b/prompts/Edge Detection/3.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
+size 52593
diff --git a/prompts/Edge Detection/3_edge.png b/prompts/Edge Detection/3_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f234ff898f4fc9f390bc4d3e5fd114662411b51
--- /dev/null
+++ b/prompts/Edge Detection/3_edge.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d12107a3ead99860b853db28ddd7c8cc3fe65966cdb9221f6afd2154dadeb507
+size 32196
diff --git a/prompts/Edge Detection/4.png b/prompts/Edge Detection/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a
--- /dev/null
+++ b/prompts/Edge Detection/4.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
+size 60589
diff --git a/prompts/Edge Detection/4_edge.png b/prompts/Edge Detection/4_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..171d43defccfc3a9473faff804b31564483380d6
--- /dev/null
+++ b/prompts/Edge Detection/4_edge.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a4d3730f25b2b7305dfc56875b12135b99c0a93cba8ac1a1ff899b7d68eb8ef
+size 39602
diff --git a/prompts/Edge Detection/5.png b/prompts/Edge Detection/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd
--- /dev/null
+++ b/prompts/Edge Detection/5.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
+size 21984
diff --git a/prompts/Edge Detection/5_edge.png b/prompts/Edge Detection/5_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..c2b7f3b15d3e3b4aab01eabebccb8556531d7619
--- /dev/null
+++ b/prompts/Edge Detection/5_edge.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c958bf6a35c685e34ae07216dd112ac4fcedd9c4629d096615333bb69b45d45c
+size 16448
diff --git a/prompts/Edge Detection/6.png b/prompts/Edge Detection/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252
--- /dev/null
+++ b/prompts/Edge Detection/6.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
+size 30704
diff --git a/prompts/Edge Detection/6_edge.png b/prompts/Edge Detection/6_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..9d62a8aa115b5cde4f7a317a744f96673102a8c4
--- /dev/null
+++ b/prompts/Edge Detection/6_edge.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75383c3fb8acb0318fb95c48c43b56df8574ca119e0f9f577066dab8fdb8fca3
+size 36706
diff --git a/prompts/Edge Detection/7.png b/prompts/Edge Detection/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6
--- /dev/null
+++ b/prompts/Edge Detection/7.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
+size 49450
diff --git a/prompts/Edge Detection/7_edge.png b/prompts/Edge Detection/7_edge.png
new file mode 100644
index 0000000000000000000000000000000000000000..be7cf97bb9f5f33d331ece657a7f23817bccf235
--- /dev/null
+++ b/prompts/Edge Detection/7_edge.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93a646f6a4fcd48b949aba04dbdeb781b01e0425a32a03d960e7c9617375fe90
+size 29210
diff --git a/prompts/Edge Detection/8.png b/prompts/Edge Detection/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d
--- /dev/null
+++ b/prompts/Edge Detection/8.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
+size 50877
diff --git a/prompts/Emoji/smile1.png b/prompts/Emoji/smile1.png
new file mode 100644
index 0000000000000000000000000000000000000000..069348959a22a84e5993121e89cf401fa576c923
--- /dev/null
+++ b/prompts/Emoji/smile1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8923aa6be860685271dbef2fcc1fc87230dec0dca44d0815f9ccdf1f8d5aea26
+size 21247
diff --git a/prompts/Emoji/smile2.png b/prompts/Emoji/smile2.png
new file mode 100644
index 0000000000000000000000000000000000000000..01948d4b355fdccce546cd367554db3724231938
--- /dev/null
+++ b/prompts/Emoji/smile2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beadf44be4d267c8f88d684bfd6b02ca13eeff5fb007f2fbfc8e2f52b8459c64
+size 22703
diff --git a/prompts/Emoji/smile3.png b/prompts/Emoji/smile3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a527c61b51f40daa4ad8425342a3918f140143bf
--- /dev/null
+++ b/prompts/Emoji/smile3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:919155477aa8d2d17502cacf0dbdcec8b8feba5faa2065e86b832ac6020e2169
+size 24308
diff --git a/prompts/Emoji/smile4.png b/prompts/Emoji/smile4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc8f33154f294ae7c6141c6df333645f8e314f80
--- /dev/null
+++ b/prompts/Emoji/smile4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9571b1c938cd2ccb522bdd15cddb4ff93b84c6ac938923a34f1ee01fdb2a002a
+size 24451
diff --git a/prompts/Object Tracking/Picture1.png b/prompts/Object Tracking/Picture1.png
new file mode 100644
index 0000000000000000000000000000000000000000..9ee9959fc01d71202482304f5031bf3cb7fa04db
--- /dev/null
+++ b/prompts/Object Tracking/Picture1.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d311a18e7a9c934ce775fa6db6576f9ed27b9fcac82f224626af8b5b074b3c35
+size 733163
diff --git a/prompts/Object Tracking/Picture2.png b/prompts/Object Tracking/Picture2.png
new file mode 100644
index 0000000000000000000000000000000000000000..5f23477cb63d62c8b57fe69360905c617bde0155
--- /dev/null
+++ b/prompts/Object Tracking/Picture2.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a2437f5bd694d69a384c82ad2b163eb665b50f2814a4a30cc1719e4436393f
+size 730944
diff --git a/prompts/Object Tracking/Picture3.png b/prompts/Object Tracking/Picture3.png
new file mode 100644
index 0000000000000000000000000000000000000000..580685c71839bb687193ff48f07300fb423ed80b
--- /dev/null
+++ b/prompts/Object Tracking/Picture3.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d31cdf20c257094e3624cc032c975a5326d904d53664ed7acfd022c829961e96
+size 723189
diff --git a/prompts/Object Tracking/Picture4.png b/prompts/Object Tracking/Picture4.png
new file mode 100644
index 0000000000000000000000000000000000000000..85b9aa885f9187d6b7124a04356874280a099caf
--- /dev/null
+++ b/prompts/Object Tracking/Picture4.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb5e559a8421bd48bcdb2689142b52ad3384a49d8b59c8f00fb65ab583c62cb1
+size 709894
diff --git a/prompts/Object Tracking/Picture5.png b/prompts/Object Tracking/Picture5.png
new file mode 100644
index 0000000000000000000000000000000000000000..f8b0f91b49529756511fd094ced04288c2636137
--- /dev/null
+++ b/prompts/Object Tracking/Picture5.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e486056f0c6898326b54410878f1c0deeba2a5c5db7acec9ba5a14617ab59068
+size 690117
diff --git a/prompts/Object Tracking/Picture6.png b/prompts/Object Tracking/Picture6.png
new file mode 100644
index 0000000000000000000000000000000000000000..10e56c3f9db3a014eef054561236356c444d84c3
--- /dev/null
+++ b/prompts/Object Tracking/Picture6.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf16762c2dd316e52721618f96bcc0811118f812427e933fef50f937adaea3d4
+size 671387
diff --git a/prompts/Object Tracking/Picture7.png b/prompts/Object Tracking/Picture7.png
new file mode 100644
index 0000000000000000000000000000000000000000..c41b3438aa20a8a8f6460003206c945100658073
--- /dev/null
+++ b/prompts/Object Tracking/Picture7.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:227bf301dc7f06208e897ab3b6092e15e0b32f7ca0ef77a51f8fb81bd258e9a3
+size 654387
diff --git a/prompts/Object Tracking/Picture8.png b/prompts/Object Tracking/Picture8.png
new file mode 100644
index 0000000000000000000000000000000000000000..6de5ddbfcfa353d01d176426271f2fca2c770961
--- /dev/null
+++ b/prompts/Object Tracking/Picture8.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e38170d147a836c88782b37d7f2bd60ee051016c6daf53ea879320f0967c818c
+size 642514
diff --git a/prompts/Object Tracking/Picture9.png b/prompts/Object Tracking/Picture9.png
new file mode 100644
index 0000000000000000000000000000000000000000..117698ee19e804cc082023da4dff47427f68532f
--- /dev/null
+++ b/prompts/Object Tracking/Picture9.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14dc7b90ae871b5be240e5775e1bd0be2ae18e73ef5ac87e7a96a928613c03b5
+size 622249
diff --git a/prompts/Outpainting/2.png b/prompts/Outpainting/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..b77b519a6a263166ce81d41efc2a17e528c5dafc
--- /dev/null
+++ b/prompts/Outpainting/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d73f108e47033bf08b8f5b4558e72c73231f5cb06e06c6283b782fe260739f1
+size 440489
diff --git a/prompts/Outpainting/3.png b/prompts/Outpainting/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e2d1fdfecf4447a2475ed09e23035f11d161e95
--- /dev/null
+++ b/prompts/Outpainting/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fba02eef20fd3a32ff01fc43bae60811185ce34cac343699a4e2b7383a81e8ef
+size 829260
diff --git a/prompts/Outpainting/4.png b/prompts/Outpainting/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..0375a22e663f9855153362011d827ee04dda1011
--- /dev/null
+++ b/prompts/Outpainting/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e36507bd52fc628b911f8e3d2806549344fa51a41821bb8a31b772296c5c8207
+size 1295391
diff --git a/prompts/Outpainting/5.png b/prompts/Outpainting/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..48e5357ccdadd953ecbeb7e45c832f44f063f9d1
--- /dev/null
+++ b/prompts/Outpainting/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42b3a730e6ad1adf868aac31dfadbd2773df5850b3602dcb84e4c10b57e3f4a1
+size 1980696
diff --git a/prompts/Segmentation/1.png b/prompts/Segmentation/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62
--- /dev/null
+++ b/prompts/Segmentation/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
+size 48533
diff --git a/prompts/Segmentation/1_seg.png b/prompts/Segmentation/1_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..7ea7e500d9d2dc981deb196fb5693acf96459636
--- /dev/null
+++ b/prompts/Segmentation/1_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2642904a2acd1bbf4cc099ccf3bb0d0f7ee3370079eda0a21d04df75923a2dd
+size 1234
diff --git a/prompts/Segmentation/2.png b/prompts/Segmentation/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5
--- /dev/null
+++ b/prompts/Segmentation/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
+size 54286
diff --git a/prompts/Segmentation/2_seg.png b/prompts/Segmentation/2_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..a38cbb9400b37bcda2aa0d92c022a2883512728b
--- /dev/null
+++ b/prompts/Segmentation/2_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1070ec66a0b1295a20871f0ef04e2f150e708b8fedcd915733b62978cbfcd0d
+size 2243
diff --git a/prompts/Segmentation/3.png b/prompts/Segmentation/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18
--- /dev/null
+++ b/prompts/Segmentation/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
+size 52593
diff --git a/prompts/Segmentation/3_seg.png b/prompts/Segmentation/3_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f3645ca6724fce3c749631221a657ae42dba25c
--- /dev/null
+++ b/prompts/Segmentation/3_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf8dc0e197459bdef57d3246e3cd5355f7a8c6d5f4822eb05fe1a8813d3f71b5
+size 1869
diff --git a/prompts/Segmentation/4.png b/prompts/Segmentation/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a
--- /dev/null
+++ b/prompts/Segmentation/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
+size 60589
diff --git a/prompts/Segmentation/4_seg.png b/prompts/Segmentation/4_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..358cb0caa8e26115253f7e579484ec207ecf11bf
--- /dev/null
+++ b/prompts/Segmentation/4_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:101793eff1e1cb3613cff94fd1cbd3e65ce9407a29dd9efd8ffa78f90443fa72
+size 3129
diff --git a/prompts/Segmentation/5.png b/prompts/Segmentation/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd
--- /dev/null
+++ b/prompts/Segmentation/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
+size 21984
diff --git a/prompts/Segmentation/5_seg.png b/prompts/Segmentation/5_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..4356240ef23bfbdff9edbf1a5ea36d75bae5e93a
--- /dev/null
+++ b/prompts/Segmentation/5_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fda0b493538ed52658c31498c678ce098229bd9c57c3e76db9096298f2b2309d
+size 1814
diff --git a/prompts/Segmentation/6.png b/prompts/Segmentation/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252
--- /dev/null
+++ b/prompts/Segmentation/6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
+size 30704
diff --git a/prompts/Segmentation/6_seg.png b/prompts/Segmentation/6_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..d0aa3baf6aed286426f0dd6a6190ce0c094333df
--- /dev/null
+++ b/prompts/Segmentation/6_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10e8d222b3db095de09124c8b129ac63da360776d588b1fb5027b31fa9b5d1d0
+size 1684
diff --git a/prompts/Segmentation/7.png b/prompts/Segmentation/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6
--- /dev/null
+++ b/prompts/Segmentation/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
+size 49450
diff --git a/prompts/Segmentation/7_seg.png b/prompts/Segmentation/7_seg.png
new file mode 100644
index 0000000000000000000000000000000000000000..f768b93c59c7e66c0617d1160a5381080b8b44f4
--- /dev/null
+++ b/prompts/Segmentation/7_seg.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c11400a1bc6c696c2f479165e5ffe7d0be2212c252b311b29eeae2d111c927a
+size 1887
diff --git a/prompts/Segmentation/8.png b/prompts/Segmentation/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d
--- /dev/null
+++ b/prompts/Segmentation/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
+size 50877
diff --git a/prompts/Surface Normal/1.png b/prompts/Surface Normal/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b177a5cc844032ae63fd8fedc1c85a4cca33f62
--- /dev/null
+++ b/prompts/Surface Normal/1.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d38e1e29282fcd6b540a5679870b10e18e8e03bf056a6d1bacf6e2e8a1b8b2
+size 48533
diff --git a/prompts/Surface Normal/1_surfave_norm.png b/prompts/Surface Normal/1_surfave_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9865e90d1661296e48a27338eef101aa5b26ba3
--- /dev/null
+++ b/prompts/Surface Normal/1_surfave_norm.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f40b8f98c35392de51e72db501b26dd04f3fa057287839df4076f0402c2fe05b
+size 42875
diff --git a/prompts/Surface Normal/2.png b/prompts/Surface Normal/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df2f5bbf2bb13d6ccdb619a59f94269d86f2fe5
--- /dev/null
+++ b/prompts/Surface Normal/2.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd8e5b14e677c5832bf0a6d1f7be1be9c10b7797345a2edd97dd8f284032511b
+size 54286
diff --git a/prompts/Surface Normal/2_surface_norm.png b/prompts/Surface Normal/2_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ca2da625560053b50957186fc30d9c9249333b5
--- /dev/null
+++ b/prompts/Surface Normal/2_surface_norm.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e52e3619a44c896fbfcc4cd5a60435fb2187c907da2122c7a5584363e8b7109
+size 45167
diff --git a/prompts/Surface Normal/3.png b/prompts/Surface Normal/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0288717878e6ebadaf8fbb98edf3081fad14e18
--- /dev/null
+++ b/prompts/Surface Normal/3.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e983fd4a47ad0b66e428c23305ae7cf634bd01a563126fdd51792805e29f9c00
+size 52593
diff --git a/prompts/Surface Normal/3_surface_norm.png b/prompts/Surface Normal/3_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..a17702432fea0029d9ad22445dac6b9ba1a2b4af
--- /dev/null
+++ b/prompts/Surface Normal/3_surface_norm.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2bb1d3d41c0f72faf0c4fca3c0030021650d0f36707cdfdb381d3ef9a5edb6b
+size 49907
diff --git a/prompts/Surface Normal/4.png b/prompts/Surface Normal/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..cec7dc85ab07d7f1ddda7ccca63e7eb874e7688a
--- /dev/null
+++ b/prompts/Surface Normal/4.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52a7b6aed029ee2d4c3fa8d4027eb8dc2a4f12a2e3d97c0bf3676aa2ce04d50d
+size 60589
diff --git a/prompts/Surface Normal/4_surface_norm.png b/prompts/Surface Normal/4_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..b92d9199042136de71381a3a0d3463d070f96728
--- /dev/null
+++ b/prompts/Surface Normal/4_surface_norm.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95d833513daa9de65dc68f3eeb4e589ec43d2e42e7305c789f4235f6213c5338
+size 46127
diff --git a/prompts/Surface Normal/5.png b/prompts/Surface Normal/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..487324656c94ed5a6155c20d5fc19b57e3997ddd
--- /dev/null
+++ b/prompts/Surface Normal/5.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63b4ce2ffe207fa64f3dec7dc470a035ada964f8192ffdb11eca8f1f2522bd8b
+size 21984
diff --git a/prompts/Surface Normal/5_surface_norm.png b/prompts/Surface Normal/5_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..d30c9ee11141aceff4713cc580a857d876a7c9f8
--- /dev/null
+++ b/prompts/Surface Normal/5_surface_norm.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9898e670f6c5e5bdd4bb4836b7ad9632ea90e2e91d57a524f22214c9cf6ef2cc
+size 34050
diff --git a/prompts/Surface Normal/6.png b/prompts/Surface Normal/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..07a8a60c24d4097ab5258231c053dfbbd840f252
--- /dev/null
+++ b/prompts/Surface Normal/6.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3bdba39d41de867cc1ec1441aea356e8d838ee20fb434b05f2021ae2abf04547
+size 30704
diff --git a/prompts/Surface Normal/6_surface_norm.png b/prompts/Surface Normal/6_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..8188e4041605d668f198a92ba661882417302e51
--- /dev/null
+++ b/prompts/Surface Normal/6_surface_norm.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b87d3709a62bfced6176861c06102bb1fe45583e92fadacb5a38562bba34339
+size 45259
diff --git a/prompts/Surface Normal/7.png b/prompts/Surface Normal/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5be190943366773f6ec479e94f1935257a60ccc6
--- /dev/null
+++ b/prompts/Surface Normal/7.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbb21133156699e366a5af8333146fcbe67c0692e26672d11baffce94c5938f7
+size 49450
diff --git a/prompts/Surface Normal/7_surface_norm.png b/prompts/Surface Normal/7_surface_norm.png
new file mode 100644
index 0000000000000000000000000000000000000000..d8955586a49136487bbd689017efa9d73ad2544a
--- /dev/null
+++ b/prompts/Surface Normal/7_surface_norm.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84b9147bdcfa2c56c05bafc73d025efc251ba9b687dc9a83a00e2f9633c7a172
+size 44708
diff --git a/prompts/Surface Normal/8.png b/prompts/Surface Normal/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6c4740401f12c840428994c0624ca8b29e3269d
--- /dev/null
+++ b/prompts/Surface Normal/8.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b2bfe546c36f3110d80f7ff58f133df77b02db94b9ac3b5a7fea30e97edba38
+size 50877
diff --git a/prompts/Synthetic Object Replication/Slide1.png b/prompts/Synthetic Object Replication/Slide1.png
new file mode 100644
index 0000000000000000000000000000000000000000..aaa3ff562f23d81a718d76104668bd190ea6ebc7
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide1.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d00ed3bc13809d3c1d54fbf851c334e3a0e6f088a4c7420c73af078e5517e05
+size 229453
diff --git a/prompts/Synthetic Object Replication/Slide2.png b/prompts/Synthetic Object Replication/Slide2.png
new file mode 100644
index 0000000000000000000000000000000000000000..3de22bebc4cd4cc363f223ef323d48b0c75b3719
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide2.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29aeaa8d76977173a58a957c463cb794a5a8d780b092706836fea4832b787f35
+size 235912
diff --git a/prompts/Synthetic Object Replication/Slide3.png b/prompts/Synthetic Object Replication/Slide3.png
new file mode 100644
index 0000000000000000000000000000000000000000..e45d49883dc9528572525fc3b7e3ffff941eb68f
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide3.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0093baa4b93b0f75872637787cda5412fc371195b3f91188c1bf6e03dfdd49d2
+size 233967
diff --git a/prompts/Synthetic Object Replication/Slide4.png b/prompts/Synthetic Object Replication/Slide4.png
new file mode 100644
index 0000000000000000000000000000000000000000..97991be222d6e76c92a53ffb76a3e307fff53641
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide4.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcce41ad98b8a7a3b7f1153b48813e936a06590bb716c184fbe012045063f815
+size 238995
diff --git a/prompts/Synthetic Object Replication/Slide5.png b/prompts/Synthetic Object Replication/Slide5.png
new file mode 100644
index 0000000000000000000000000000000000000000..22565ba0e7b64c76e00fbb0519570f6ac35980b9
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide5.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a686955f717be1a514f68786268394eea86485b3aba333913d62c6548c8417d
+size 236567
diff --git a/prompts/Synthetic Object Replication/Slide6.png b/prompts/Synthetic Object Replication/Slide6.png
new file mode 100644
index 0000000000000000000000000000000000000000..6bcb3367c8db69f98a63a9995e5147458d66ca11
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide6.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f1b4e759376a1c26a837ff59a918ffd85a1689cb25064b39e8f7976aacb3ba7
+size 240757
diff --git a/prompts/Synthetic Object Replication/Slide7.png b/prompts/Synthetic Object Replication/Slide7.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5d7852cadabb41ae06e051d81d864916ed35abd
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide7.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1132a1a1f41a406069b628dcf2d81995ef48bc0a493d3e501c266c617d722499
+size 238630
diff --git a/prompts/Synthetic Object Replication/Slide8.png b/prompts/Synthetic Object Replication/Slide8.png
new file mode 100644
index 0000000000000000000000000000000000000000..953743749cbd61eaa0d6371dadee67cccf4156a4
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide8.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de4e4c3b27e1d7691b697704c6506081e2cd5903929dca66deab53c4be08d886
+size 235972
diff --git a/prompts/Synthetic Object Replication/Slide9.png b/prompts/Synthetic Object Replication/Slide9.png
new file mode 100644
index 0000000000000000000000000000000000000000..7430d331a7cb500dbe9852d0e5ad4995b3d132d8
--- /dev/null
+++ b/prompts/Synthetic Object Replication/Slide9.png	
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2dbf3f768e24d4e46c04bad22f41798d717a15c566ad3d6b0d0813e23feb995f
+size 233023
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..18ebbfe08f54d5ba0a5ca1c4f58220ce4da99c4d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+numpy
+scipy
+matplotlib
+seaborn
+jupyter
+tqdm
+pillow
+--extra-index-url https://download.pytorch.org/whl/cu118
+transformers==4.34.1
+torch==2.0.1
+einops
+absl-py
+ml_collections
+requests
+mlxu==0.1.11
+pydantic
+fastapi
+uvicorn
+gradio
+opencv-python-headless
+scikit-video
+scikit-image
+natsort
+accelerate
diff --git a/torch_vqvae_model.py b/torch_vqvae_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe6053b36c5410da5f3a914e1a6644f840734311
--- /dev/null
+++ b/torch_vqvae_model.py
@@ -0,0 +1,260 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+from einops.layers.torch import Rearrange
+
+
+def normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+def swish(x):
+    return x*torch.sigmoid(x)
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels=None, activation_fn="relu"):
+        super(ResBlock, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.norm1 = normalize(in_channels)
+        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.norm2 = normalize(self.out_channels)
+        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        if self.in_channels != self.out_channels:
+            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0, bias=False)
+        self.activation_fn = activation_fn
+        if activation_fn=="relu":
+            self.actn = nn.ReLU()
+
+
+    def forward(self, x_in):
+        x = x_in
+        x = self.norm1(x)
+        if self.activation_fn=="relu":
+            x = self.actn(x)
+        elif self.activation_fn=="swish":
+            x = swish(x)
+        x = self.conv1(x)
+        x = self.norm2(x)
+        if self.activation_fn=="relu":
+            x = self.actn(x)
+        elif self.activation_fn=="swish":
+            x = swish(x)
+        x = self.conv2(x)
+        if self.in_channels != self.out_channels:
+            x_in = self.conv_out(x_in)
+
+        return x + x_in
+
+class Encoder(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+
+        self.filters = 128
+        self.num_res_blocks = 2
+        self.ch_mult = [1, 1, 2, 2, 4]
+        self.in_ch_mult = (1,) + tuple(self.ch_mult)
+        self.embedding_dim = 32
+        self.conv_downsample = False
+
+        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False)
+        blocks = []
+        for i in range(len(self.ch_mult)):
+            block_in_ch = self.filters  * self.in_ch_mult[i]
+            block_out_ch = self.filters  * self.ch_mult[i]
+            for _ in range(self.num_res_blocks):
+                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
+                block_in_ch = block_out_ch
+        for _ in range(self.num_res_blocks):
+            blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
+        self.norm1 = normalize(block_in_ch)
+        self.conv2 = nn.Conv2d(block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0)
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        for i in range(len(self.ch_mult)):
+            for j in range(self.num_res_blocks):
+                x = self.blocks[i*2+j](x)
+
+            if i < len(self.ch_mult) -1:
+                x = torch.nn.functional.avg_pool2d(x, (2,2),(2,2))
+
+        x = self.blocks[-2](x)
+        x = self.blocks[-1](x)
+
+        x = self.norm1(x)
+        x = swish(x)
+        x = self.conv2(x)
+        return x
+
+class VectorQuantizer(nn.Module):
+    def __init__(self, codebook_size=8192, emb_dim=32, beta=None):
+        super(VectorQuantizer, self).__init__()
+        self.codebook_size = codebook_size  # number of embeddings
+        self.emb_dim = emb_dim  # dimension of embedding
+        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+        self.beta=0.0
+        self.z_dim = emb_dim
+
+    def forward(self, z):
+        # preprocess
+
+        b, c, h, w = z.size()
+        flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
+        codebook = self.embedding.weight
+        with torch.no_grad():
+            tokens = torch.cdist(flatten, codebook).argmin(dim=1)
+        quantized = F.embedding(tokens,
+                                codebook).view(b, h, w, c).permute(0, 3, 1, 2)
+
+        # compute loss
+        codebook_loss = F.mse_loss(quantized, z.detach())
+        commitment_loss = F.mse_loss(quantized.detach(), z)
+        loss = codebook_loss + self.beta * commitment_loss
+
+        # perplexity
+        counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype)
+        # dist.all_reduce(counts)
+        p = counts / counts.sum()
+        perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))
+
+        # postprocess
+        tokens = tokens.view(b, h, w)
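+        # straight-through estimator: forward pass uses the quantized values, gradients flow back to z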
+        quantized = z + (quantized - z).detach()
+
+        # quantized_2 = self.get_codebook_feat(tokens, (b, h, w, c))
+
+        return quantized, tokens, loss, perplexity
+
+
+    def get_codebook_feat(self, indices, shape=None):
+        # input indices: batch*token_num -> (batch*token_num)*1
+        # shape: batch, height, width, channel
+        indices = indices.view(-1,1)
+        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+        min_encodings.scatter_(1, indices, 1)
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+        if shape is not None:  # reshape back to match original input shape
+            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+        return z_q
+
+
+class Decoder(nn.Module):
+    def __init__(self,):
+        super().__init__()
+        self.filters = 128
+        self.num_res_blocks = 2
+        self.ch_mult = [1, 1, 2, 2, 4]
+        self.in_ch_mult = (1,) + tuple(self.ch_mult)
+        self.embedding_dim = 32
+        self.out_channels = 3
+        self.in_channels = self.embedding_dim
+        self.conv_downsample = False
+
+        self.conv1 = nn.Conv2d(32, 512, kernel_size=3, stride=1, padding=1)
+        blocks = []
+        block_in_ch = self.filters * self.ch_mult[-1]
+        block_out_ch = self.filters * self.ch_mult[-1]
+        #blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+        for _ in range(self.num_res_blocks):
+            blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
+        upsample_conv_layers = []
+        for i in reversed(range(len(self.ch_mult))):
+            block_out_ch = self.filters * self.ch_mult[i]
+            for _ in range(self.num_res_blocks):
+                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
+                block_in_ch = block_out_ch
+            if i > 0:
+                upsample_conv_layers.append(nn.Conv2d(block_in_ch, block_out_ch*4, kernel_size=3, stride=1, padding=1))
+
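+        # depth-to-space upsampling: each up-conv expands channels 4x, then the rearrange folds them into a 2x2 spatial block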
+        self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2)
+        self.norm1 = normalize(block_in_ch)
+        # self.act_fn
+        self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)
+        self.blocks = nn.ModuleList(blocks)
+        self.up_convs = nn.ModuleList(upsample_conv_layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.blocks[0](x)
+        x = self.blocks[1](x)
+        for i in range(len(self.ch_mult)):
+            for j in range(self.num_res_blocks):
+                x = self.blocks[2+i*2+j](x)
+            if i < len(self.ch_mult)-1:
+                x = self.up_convs[i](x)
+                #print("pre: x.size()",x.size())
+                x = x.permute(0,2,3,1)
+                x = self.upsample(x)
+                x = x.permute(0,3,1,2)
+                #print("post: x.size()", x.size())
+        x = self.norm1(x)
+        x = swish(x)
+        x = self.conv6(x)
+        return x
+
+
+class VQVAE(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        self.encoder = Encoder()
+        self.quantizer = VectorQuantizer()
+        self.decoder = Decoder()
+
+    def forward(self, x):
+        x = self.encoder(x)
+        quant,tokens, loss, perplexity = self.quantizer(x)
+        x = self.decoder(quant)
+        return x
+
+    def tokenize(self, x):
+        batch_shape = x.shape[:-3]
+        x = x.reshape(-1, *x.shape[-3:])
+        x = self.encoder(x)
+        quant,tokens, loss, perplexity = self.quantizer(x)
+        return tokens.reshape(*batch_shape, *tokens.shape[1:])
+
+    def decode(self, tokens):
+        tokens = einops.rearrange(tokens, 'b ... -> b (...)')
+        b = tokens.shape[0]
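+        # the flattened token count fixes the latent grid: 256 tokens -> 16x16 latents, which the decoder upsamples back to 256x256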
+        if tokens.shape[-1] == 256:
+            hw = 16
+        elif tokens.shape[-1] == 224:
+            hw = 14
+        else:
+            raise ValueError("Invalid tokens shape")
+        quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, 32))
+        x = self.decoder(quant)
+        return x
+
+
+class VAEDecoder(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        self.quantizer = VectorQuantizer()
+        self.decoder = Decoder()
+
+    def forward(self, x):
+        quant = self.quantizer.get_codebook_feat(x,(1,14,14,32))
+        x = self.decoder(quant)
+        return x
+
+
+def get_tokenizer():
+    checkpoint_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth"
+    )
+    torch_state_dict = torch.load(checkpoint_path)
+    net = VQVAE()
+    net.load_state_dict(torch_state_dict)
+    return net
+
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b059a22e2fa8786ebba48dab800318b362507715
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,296 @@
+import os
+from multiprocessing import Pool
+import numpy as np
+import random
+from PIL import Image
+import re
+import cv2
+import glob
+from natsort import natsorted
+
+
+class MultiProcessImageSaver(object):
+
+    def __init__(self, n_workers=1):
+        self.pool = Pool(n_workers)
+
+    def __call__(self, images, output_files, resizes=None):
+        if resizes is None:
+            resizes = [None for _ in range(len(images))]
+        return self.pool.imap(
+            self.save_image,
+            zip(images, output_files, resizes),
+        )
+
+    def close(self):
+        self.pool.close()
+        self.pool.join()
+
+    @staticmethod
+    def save_image(args):
+        image, filename, resize = args
+        image = Image.fromarray(image)
+        if resize is not None:
+            image = image.resize(tuple(resize))
+        image.save(filename)
+
+
+def list_dir_with_full_path(path):
+    return [os.path.join(path, f) for f in os.listdir(path)]
+
+
+def find_all_files_in_dir(path):
+    all_files = []
+    for root, _, filenames in os.walk(path):
+        for file in filenames:
+            all_files.append(os.path.join(root, file))
+    return all_files
+
+
+def is_image(path):
+    return (
+        path.endswith('.jpg')
+        or path.endswith('.png')
+        or path.endswith('.jpeg')
+        or path.endswith('.JPG')
+        or path.endswith('.PNG')
+        or path.endswith('.JPEG')
+    )
+
+
+def is_video(path):
+    return (
+        path.endswith('.mp4')
+        or path.endswith('.avi')
+        or path.endswith('.MP4')
+        or path.endswith('.AVI')
+        or path.endswith('.webm')
+        or path.endswith('.WEBM')
+        or path.endswith('.mkv')
+        or path.endswith('.MKV')
+    )
+
+
+def random_square_crop(img, random_generator=None):
+    # If no random generator is provided, use numpy's default
+    if random_generator is None:
+        random_generator = np.random.default_rng()
+
+    # Get the width and height of the image
+    width, height = img.size
+
+    # Determine the shorter side
+    min_size = min(width, height)
+
+    # Randomly determine the starting x and y coordinates for the crop
+    if width > height:
+        left = random_generator.integers(0, width - min_size)
+        upper = 0
+    else:
+        left = 0
+        upper = random_generator.integers(0, height - min_size)
+
+    # Calculate the ending x and y coordinates for the crop
+    right = left + min_size
+    lower = upper + min_size
+
+    # Crop the image
+    return img.crop((left, upper, right, lower))
+
+
+def read_image_to_tensor(path, center_crop=1.0):
+    pil_im = Image.open(path).convert('RGB')
+    if center_crop < 1.0:
+        width, height = pil_im.size
+        pil_im = pil_im.crop((
+            int((1 - center_crop) * width / 2), int((1 - center_crop) * height / 2),
+            int((1 + center_crop) * width / 2), int((1 + center_crop) * height / 2),
+        ))
+    input_img = pil_im.resize((256, 256))
+    input_img = np.array(input_img) / 255.0
+    input_img = input_img.astype(np.float32)
+    return input_img
+
+
+def match_mulitple_path(root_dir, regex):
+    videos = []
+    for root, _, files in os.walk(root_dir):
+        for file in files:
+            videos.append(os.path.join(root, file))
+
+    videos = [v for v in videos if not v.split('/')[-1].startswith('.')]
+
+    grouped_path = {}
+    for r in regex:
+        r = re.compile(r)
+        for v in videos:
+            matched = r.findall(v)
+            if len(matched) > 0:
+                groups = matched[0]
+                if groups not in grouped_path:
+                    grouped_path[groups] = []
+                grouped_path[groups].append(v)
+
+    grouped_path = {
+        k: tuple(v) for k, v in grouped_path.items()
+        if len(v) == len(regex)
+    }
+    return list(grouped_path.values())
+
+
+def randomly_subsample_frame_indices(length, n_frames, max_stride=30, random_start=True):
+    assert length >= n_frames
+    max_stride = min(
+        (length - 1) // (n_frames - 1),
+        max_stride
+    )
+    stride = np.random.randint(1, max_stride + 1)
+    if random_start:
+        start = np.random.randint(0, length - (n_frames - 1) * stride)
+    else:
+        start = 0
+    return np.arange(n_frames) * stride + start
+
+
+def read_frames_from_dir(dir_path, n_frames, stride, random_start=True, center_crop=1.0):
+    files = [os.path.join(dir_path, x) for x in os.listdir(dir_path)]
+    files = natsorted([x for x in files if is_image(x)])
+
+    total_frames = len(files)
+
+    if total_frames < n_frames:
+        return None
+
+    max_stride = (total_frames - 1) // (n_frames - 1)
+    stride = min(max_stride, stride)
+
+    if random_start:
+        start = np.random.randint(0, total_frames - (n_frames - 1) * stride)
+    else:
+        start = 0
+    frame_indices = np.arange(n_frames) * stride + start
+
+    frames = []
+    for frame_index in sorted(frame_indices):
+        # indices are within range by construction (stride is clamped above)
+        frames.append(read_image_to_tensor(files[frame_index], center_crop=center_crop))
+    if len(frames) < n_frames:
+        return None
+    frames = np.stack(frames, axis=0)
+    return frames
+
+
+def read_frames_from_video(video_path, n_frames, stride, random_start=True, center_crop=1.0):
+
+    frames = []
+    cap = cv2.VideoCapture(video_path)
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    if total_frames < n_frames:
+        cap.release()
+        return None
+
+    max_stride = (total_frames - 1) // (n_frames - 1)
+    stride = min(max_stride, stride)
+
+    if random_start:
+        start = np.random.randint(0, total_frames - (n_frames - 1) * stride)
+    else:
+        start = 0
+    frame_indices = np.arange(n_frames) * stride + start
+
+    for frame_index in sorted(frame_indices):
+        # Check if the frame_index is valid
+        if 0 <= frame_index < total_frames:
+            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
+            ret, frame = cap.read()
+            if ret:
+                if center_crop < 1.0:
+                    height, width, _ = frame.shape
+                    frame = frame[
+                        int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2),
+                        int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2),
+                        :
+                    ]
+                frame = cv2.resize(frame, (256, 256))
+
+                frames.append(frame)
+
+        else:
+            print(f"Frame index {frame_index} is out of bounds. Skipping...")
+
+    cap.release()
+    if len(frames) < n_frames:
+        return None
+    frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
+
+    # From BGR to RGB
+    return np.stack(
+        [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
+    )
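+
+# Note: OpenCV decodes frames as BGR; the channel reordering above yields an RGB array of
+# shape (n_frames, 256, 256, 3) scaled to [0, 1], or None if the video is too short.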
+
+
+def read_all_frames_from_video(video_path, center_crop=1.0):
+
+    frames = []
+    cap = cv2.VideoCapture(video_path)
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # range(total_frames) only yields valid indices, so no bounds check is needed here.
+    for frame_index in range(total_frames):
+        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
+        ret, frame = cap.read()
+        if ret:
+            if center_crop < 1.0:
+                height, width, _ = frame.shape
+                frame = frame[
+                    int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2),
+                    int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2),
+                    :
+                ]
+            frames.append(cv2.resize(frame, (256, 256)))
+
+    cap.release()
+    if len(frames) == 0:
+        return None
+    frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
+    # From BGR to RGB
+    return np.stack(
+        [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
+    )
+
+
+def read_max_span_frames_from_video(video_path, n_frames):
+    frames = []
+    cap = cv2.VideoCapture(video_path)
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    if total_frames < n_frames:
+        cap.release()
+        return None
+    stride = (total_frames - 1) // (n_frames - 1)
+    frame_indices = np.arange(n_frames) * stride
+
+    for frame_index in frame_indices:
+        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
+        ret, frame = cap.read()
+        if ret:
+            frames.append(cv2.resize(frame, (256, 256)))
+
+    cap.release()
+    if len(frames) < n_frames:
+        return None
+
+    frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
+    # From BGR to RGB
+    return np.stack(
+        [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
+    )
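+
+# Usage sketch (illustrative; the file name is an assumption):
+#   read_max_span_frames_from_video("clip.mp4", 16)
+# picks 16 frames spanning the whole clip with a fixed stride of (total_frames - 1) // 15.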
+
diff --git a/vqvae/__init__.py b/vqvae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9b32aee9a68b1192ae0be7214ca92f35defd717
--- /dev/null
+++ b/vqvae/__init__.py
@@ -0,0 +1,25 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "0.0.1"
+
+# from .modeling_ema import EMAModel
+# from .modeling_maskgit_vqgan import MaskGitVQGAN
+# from .modeling_movq import MOVQ
+# from .modeling_paella_vq import PaellaVQModel
+# from .modeling_utils import VQGANModel
+# from .modeling_transformer import MaskGitTransformer, MaskGiTUViT
+# from .pipeline_muse import PipelineMuse, PipelineMuseInpainting
+# from .sampling import get_mask_chedule
diff --git a/vqvae/logging.py b/vqvae/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..65814a82380e47e54434c4be97026141772f7298
--- /dev/null
+++ b/vqvae/logging.py
@@ -0,0 +1,338 @@
+# coding=utf-8
+# Copyright 2023 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities."""
+
+import logging
+import os
+import sys
+import threading
+from logging import CRITICAL  # NOQA
+from logging import DEBUG  # NOQA
+from logging import ERROR  # NOQA
+from logging import FATAL  # NOQA
+from logging import INFO  # NOQA
+from logging import NOTSET  # NOQA
+from logging import WARN  # NOQA
+from logging import WARNING  # NOQA
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+    "debug": logging.DEBUG,
+    "info": logging.INFO,
+    "warning": logging.WARNING,
+    "error": logging.ERROR,
+    "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+    """
+    If the muse_VERBOSITY env var is set to one of the valid choices, return that as the new default level;
+    otherwise, fall back to `_default_log_level`.
+    """
+    env_level_str = os.getenv("muse_VERBOSITY", None)
+    if env_level_str:
+        if env_level_str in log_levels:
+            return log_levels[env_level_str]
+        else:
+            logging.getLogger().warning(
+                f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
+            )
+    return _default_log_level
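+
+# Example (illustrative): launching the process with the environment variable muse_VERBOSITY=debug
+# makes the library root logger emit DEBUG records; unknown values fall back to WARNING.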
+
+
+def _get_library_name() -> str:
+    return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+    return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+    global _default_handler
+
+    with _lock:
+        if _default_handler:
+            # This library has already configured the library root logger.
+            return
+        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
+        _default_handler.flush = sys.stderr.flush
+
+        # Apply our default configuration to the library root logger.
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.addHandler(_default_handler)
+        library_root_logger.setLevel(_get_default_logging_level())
+        library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+    global _default_handler
+
+    with _lock:
+        if not _default_handler:
+            return
+
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.removeHandler(_default_handler)
+        library_root_logger.setLevel(logging.NOTSET)
+        _default_handler = None
+
+
+def get_log_levels_dict():
+    return log_levels
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Return a logger with the specified name.
+
+    This function is not supposed to be directly accessed unless you are writing a custom muse module.
+    """
+
+    if name is None:
+        name = _get_library_name()
+
+    _configure_library_root_logger()
+    return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+    """
+    Return the current level for the 🤗 muse root logger as an int.
+
+    Returns:
+        `int`: The logging level.
+
+    <Tip>
+
+    🤗 muse has the following logging levels:
+
+    - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL`
+    - 40: `muse.logging.ERROR`
+    - 30: `muse.logging.WARNING` or `muse.logging.WARN`
+    - 20: `muse.logging.INFO`
+    - 10: `muse.logging.DEBUG`
+
+    </Tip>"""
+
+    _configure_library_root_logger()
+    return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+    """
+    Set the verbosity level for the 🤗 muse root logger.
+
+    Args:
+        verbosity (`int`):
+            Logging level, e.g., one of:
+
+            - `muse.logging.CRITICAL` or `muse.logging.FATAL`
+            - `muse.logging.ERROR`
+            - `muse.logging.WARNING` or `muse.logging.WARN`
+            - `muse.logging.INFO`
+            - `muse.logging.DEBUG`
+    """
+
+    _configure_library_root_logger()
+    _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+    """Set the verbosity to the `INFO` level."""
+    return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+    """Set the verbosity to the `WARNING` level."""
+    return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+    """Set the verbosity to the `DEBUG` level."""
+    return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+    """Set the verbosity to the `ERROR` level."""
+    return set_verbosity(ERROR)
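+
+# Usage sketch (illustrative):
+#   from vqvae import logging
+#   logging.set_verbosity_info()
+#   logger = logging.get_logger(__name__)
+#   logger.info("loading VQ-VAE weights")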
+
+
+def disable_default_handler() -> None:
+    """Disable the default handler of the HuggingFace muse' root logger."""
+
+    _configure_library_root_logger()
+
+    assert _default_handler is not None
+    _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+    """Enable the default handler of the HuggingFace muse' root logger."""
+
+    _configure_library_root_logger()
+
+    assert _default_handler is not None
+    _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+    """adds a handler to the HuggingFace muse' root logger."""
+
+    _configure_library_root_logger()
+
+    assert handler is not None
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    """removes given handler from the HuggingFace muse' root logger."""
+
+    _configure_library_root_logger()
+
+    assert handler is not None and handler not in _get_library_root_logger().handlers
+    _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+    """
+    Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+    """
+
+    _configure_library_root_logger()
+    _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+    """
+    Enable propagation of the library log outputs. Please disable the HuggingFace muse default handler to prevent
+    double logging if the root logger has been configured.
+    """
+
+    _configure_library_root_logger()
+    _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+    """
+    Enable explicit formatting for every HuggingFace muse logger. The explicit formatter is as follows:
+    ```
+        [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+    ```
+    All handlers currently bound to the root logger are affected by this method.
+    """
+    handlers = _get_library_root_logger().handlers
+
+    for handler in handlers:
+        formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+        handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+    """
+    Resets the formatting for HuggingFace muse loggers.
+
+    All handlers currently bound to the root logger are affected by this method.
+    """
+    handlers = _get_library_root_logger().handlers
+
+    for handler in handlers:
+        handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+    """
+    This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this
+    warning will not be printed
+    """
+    no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False)
+    if no_advisory_warnings:
+        return
+    self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+class EmptyTqdm:
+    """Dummy tqdm which doesn't do anything."""
+
+    def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
+        self._iterator = args[0] if args else None
+
+    def __iter__(self):
+        return iter(self._iterator)
+
+    def __getattr__(self, _):
+        """Return empty function."""
+
+        def empty_fn(*args, **kwargs):  # pylint: disable=unused-argument
+            return
+
+        return empty_fn
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type_, value, traceback):
+        return
+
+
+class _tqdm_cls:
+    def __call__(self, *args, **kwargs):
+        if _tqdm_active:
+            return tqdm_lib.tqdm(*args, **kwargs)
+        else:
+            return EmptyTqdm(*args, **kwargs)
+
+    def set_lock(self, *args, **kwargs):
+        self._lock = None
+        if _tqdm_active:
+            return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+    def get_lock(self):
+        if _tqdm_active:
+            return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
+
+
+def is_progress_bar_enabled() -> bool:
+    """Return a boolean indicating whether tqdm progress bars are enabled."""
+    global _tqdm_active
+    return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+    """Enable tqdm progress bar."""
+    global _tqdm_active
+    _tqdm_active = True
+
+
+def disable_progress_bar():
+    """Disable tqdm progress bar."""
+    global _tqdm_active
+    _tqdm_active = False
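+
+# Usage sketch (illustrative):
+#   from vqvae.logging import disable_progress_bar, enable_progress_bar, tqdm
+#   disable_progress_bar()            # the wrapped tqdm below becomes a no-op
+#   for _ in tqdm(range(10)):
+#       pass
+#   enable_progress_bar()             # restores the real tqdm bars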
diff --git a/vqvae/modeling_utils.py b/vqvae/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdbdb68a7c2f154c670fe40950c035fad06e4691
--- /dev/null
+++ b/vqvae/modeling_utils.py
@@ -0,0 +1,1171 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import inspect
+import json
+import os
+from collections import OrderedDict
+from functools import partial
+from pathlib import PosixPath
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import accelerate
+import numpy as np
+import torch
+from accelerate.utils import set_module_tensor_to_device
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import (
+    EntryNotFoundError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+)
+from requests import HTTPError
+from torch import Tensor, device
+
+from . import __version__, logging
+
+logger = logging.get_logger(__name__)
+
+
+hf_cache_home = os.path.expanduser(
+    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "muse")
+
+
+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "pytorch_model.bin"
+SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+MUSE_CACHE = default_cache_path
+MUSE_DYNAMIC_MODULE_NAME = "muse_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+
+
+_LOW_CPU_MEM_USAGE_DEFAULT = True
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+    try:
+        return next(parameter.parameters()).device
+    except StopIteration:
+        # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+    try:
+        return next(parameter.parameters()).dtype
+    except StopIteration:
+        # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+    """
+    Reads a checkpoint file, returning properly formatted errors if they arise.
+    """
+    try:
+        if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
+            return torch.load(checkpoint_file, map_location="cpu")
+    except Exception as e:
+        try:
+            with open(checkpoint_file) as f:
+                if f.read().startswith("version"):
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                        "model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
+                f"at '{checkpoint_file}'. "
+                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+            )
+
+
+def _load_state_dict_into_model(model_to_load, state_dict):
+    # Convert old format to new format if needed from a PyTorch state_dict
+    # copy state_dict so _load_from_state_dict can modify it
+    state_dict = state_dict.copy()
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: torch.nn.Module, prefix=""):
+        args = (state_dict, prefix, {}, True, [], [], error_msgs)
+        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + ".")
+
+    load(model_to_load)
+
+    return error_msgs
+
+
+def _get_model_file(
+    pretrained_model_name_or_path,
+    *,
+    weights_name,
+    subfolder,
+    cache_dir,
+    force_download,
+    proxies,
+    resume_download,
+    local_files_only,
+    use_auth_token,
+    user_agent,
+    revision,
+):
+    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+    if os.path.isfile(pretrained_model_name_or_path):
+        return pretrained_model_name_or_path
+    elif os.path.isdir(pretrained_model_name_or_path):
+        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+            # Load from a PyTorch checkpoint
+            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
+            return model_file
+        elif subfolder is not None and os.path.isfile(
+            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+        ):
+            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+            return model_file
+        else:
+            raise EnvironmentError(
+                f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
+            )
+    else:
+        try:
+            # Load from URL or cache if already cached
+            model_file = hf_hub_download(
+                pretrained_model_name_or_path,
+                filename=weights_name,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                user_agent=user_agent,
+                subfolder=subfolder,
+                revision=revision,
+            )
+            return model_file
+
+        except RepositoryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                "login`."
+            )
+        except RevisionNotFoundError:
+            raise EnvironmentError(
+                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                "this model name. Check the model page at "
+                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+            )
+        except EntryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
+            )
+        except HTTPError as err:
+            raise EnvironmentError(
+                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+            )
+        except ValueError:
+            raise EnvironmentError(
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                f" directory containing a file named {weights_name} or"
+                " \nCheckout your internet connection or see how to run the library in"
+                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+            )
+        except EnvironmentError:
+            raise EnvironmentError(
+                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                f"containing a file named {weights_name}"
+            )
+
+
+class ModelMixin(torch.nn.Module):
+    r"""
+    Base class for all models.
+
+    [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+    and saving models.
+
+        - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+          [`~models.ModelMixin.save_pretrained`].
+    """
+    config_name = CONFIG_NAME
+    _automatically_saved_args = ["_version", "_class_name", "_name_or_path"]
+    _supports_gradient_checkpointing = False
+
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+    def enable_gradient_checkpointing(self):
+        """
+        Activates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if not self._supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+    def disable_gradient_checkpointing(self):
+        """
+        Deactivates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if self._supports_gradient_checkpointing:
+            self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
+
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+
+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import UNet2DConditionModel
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> model = UNet2DConditionModel.from_pretrained(
+        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+        ... )
+        >>> model = model.to("cuda")
+        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        ```
+        """
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        save_function: Callable = None,
+        state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~models.ModelMixin.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training (e.g.
+                on TPUs), when this function needs to be called on all processes. In that case, set
+                `is_main_process=True` only on the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs)
+                when one needs to replace `torch.save` with another method. Can be configured with the environment
+                variable `DIFFUSERS_SAVE_MODE`.
+            state_dict (`Dict[str, torch.Tensor]`, *optional*):
+                The state dictionary to save. If `None`, the model's state dictionary will be saved.
+        """
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if save_function is None:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        model_to_save = self
+
+        # Attach architecture to the config
+        # Save the config
+        if is_main_process:
+            model_to_save.save_config(save_directory)
+
+        # Save the model
+        if state_dict is None:
+            state_dict = model_to_save.state_dict()
+
+        weights_name = WEIGHTS_NAME
+
+        # Save the model
+        save_function(state_dict, os.path.join(save_directory, weights_name))
+
+        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+                      `./my_model_directory/`.
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `diffusers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
+
+        <Tip>
+
+         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+         models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+        </Tip>
+
+        """
+        cache_dir = kwargs.pop("cache_dir", MUSE_CACHE)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        local_files_only = kwargs.pop("local_files_only", False)  # TODO
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        subfolder = kwargs.pop("subfolder", None)
+        device_map = kwargs.pop("device_map", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
+        # Load model
+
+        model_file = _get_model_file(
+            pretrained_model_name_or_path,
+            weights_name=WEIGHTS_NAME,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+        )
+
+        if low_cpu_mem_usage:
+            # Instantiate model with empty weights
+            with accelerate.init_empty_weights():
+                config, unused_kwargs = cls.load_config(
+                    config_path,
+                    cache_dir=cache_dir,
+                    return_unused_kwargs=True,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    device_map=device_map,
+                    **kwargs,
+                )
+                model = cls.from_config(config, **unused_kwargs)
+
+            # if device_map is None, load the state dict and move the params from meta device to the cpu
+            if device_map is None:
+                param_device = "cpu"
+                state_dict = load_state_dict(model_file)
+                # move the params from meta device to cpu
+                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if len(missing_keys) > 0:
+                    raise ValueError(
+                        f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
+                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
+                        " those weights or else make sure your checkpoint file is correct."
+                    )
+
+                for param_name, param in state_dict.items():
+                    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+                    if accepts_dtype:
+                        set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
+                    else:
+                        set_module_tensor_to_device(model, param_name, param_device, value=param)
+            else:  # else let accelerate handle loading and dispatching.
+                # Load weights and dispatch according to the device_map
+                # by default the device_map is None and the weights are loaded on the CPU
+                accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
+
+            loading_info = {
+                "missing_keys": [],
+                "unexpected_keys": [],
+                "mismatched_keys": [],
+                "error_msgs": [],
+            }
+        else:
+            config, unused_kwargs = cls.load_config(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                device_map=device_map,
+                **kwargs,
+            )
+            model = cls.from_config(config, **unused_kwargs)
+
+            state_dict = load_state_dict(model_file)
+
+            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+                model,
+                state_dict,
+                model_file,
+                pretrained_model_name_or_path,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+            )
+
+            loading_info = {
+                "missing_keys": missing_keys,
+                "unexpected_keys": unexpected_keys,
+                "mismatched_keys": mismatched_keys,
+                "error_msgs": error_msgs,
+            }
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            raise ValueError(
+                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+            )
+        elif torch_dtype is not None:
+            model = model.to(torch_dtype)
+
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.eval()
+        if output_loading_info:
+            return model, loading_info
+
+        return model
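+
+    # Usage sketch (illustrative; the repo id, subfolder and concrete subclass are assumptions):
+    #   vq_model = VQGANModel.from_pretrained("user/repo", subfolder="vqvae", torch_dtype=torch.float16)
+    # from_pretrained already calls eval(), so the returned model is ready for inference.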
+
+    @classmethod
+    def _load_pretrained_model(
+        cls,
+        model,
+        state_dict,
+        resolved_archive_file,
+        pretrained_model_name_or_path,
+        ignore_mismatched_sizes=False,
+    ):
+        # Retrieve missing & unexpected_keys
+        model_state_dict = model.state_dict()
+        loaded_keys = [k for k in state_dict.keys()]
+
+        expected_keys = list(model_state_dict.keys())
+
+        original_loaded_keys = loaded_keys
+
+        missing_keys = list(set(expected_keys) - set(loaded_keys))
+        unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+        # Make sure we are able to load base models as well as derived models (with heads)
+        model_to_load = model
+
+        def _find_mismatched_keys(
+            state_dict,
+            model_state_dict,
+            loaded_keys,
+            ignore_mismatched_sizes,
+        ):
+            mismatched_keys = []
+            if ignore_mismatched_sizes:
+                for checkpoint_key in loaded_keys:
+                    model_key = checkpoint_key
+
+                    if (
+                        model_key in model_state_dict
+                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                    ):
+                        mismatched_keys.append(
+                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                        )
+                        del state_dict[checkpoint_key]
+            return mismatched_keys
+
+        if state_dict is not None:
+            # Whole checkpoint
+            mismatched_keys = _find_mismatched_keys(
+                state_dict,
+                model_state_dict,
+                original_loaded_keys,
+                ignore_mismatched_sizes,
+            )
+            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+        if len(error_msgs) > 0:
+            error_msg = "\n\t".join(error_msgs)
+            if "size mismatch" in error_msg:
+                error_msg += (
+                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+                )
+            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+                " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+                " identical (initializing a BertForSequenceClassification model from a"
+                " BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+                f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+                " without further training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+                " able to use it for predictions and inference."
+            )
+
+        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+    @property
+    def device(self) -> device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+        return get_parameter_device(self)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        return get_parameter_dtype(self)
+
+    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+        """
+        Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+        Args:
+            only_trainable (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of trainable parameters
+
+            exclude_embeddings (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of non-embeddings parameters
+
+        Returns:
+            `int`: The number of parameters.
+        """
+
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight"
+                for name, module_type in self.named_modules()
+                if isinstance(module_type, torch.nn.Embedding)
+            ]
+            non_embedding_parameters = [
+                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+            ]
+            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+        else:
+            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+""" ConfigMixin base class and utilities."""
+
+
+class FrozenDict(OrderedDict):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        for key, value in self.items():
+            setattr(self, key, value)
+
+        self.__frozen = True
+
+    def __delitem__(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+    def setdefault(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+    def pop(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+    def update(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+    def __setattr__(self, name, value):
+        if hasattr(self, "__frozen") and self.__frozen:
+            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+        super().__setattr__(name, value)
+
+    def __setitem__(self, name, value):
+        if hasattr(self, "__frozen") and self.__frozen:
+            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+        super().__setitem__(name, value)
+
+
+class ConfigMixin:
+    r"""
+    Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
+    methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+        - [`~ConfigMixin.from_config`]
+        - [`~ConfigMixin.save_config`]
+
+    Class attributes:
+        - **config_name** (`str`) -- A filename under which the config should be stored when calling
+          [`~ConfigMixin.save_config`] (should be overridden by parent class).
+        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+          overridden by subclass).
+        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
+          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+          subclass).
+    """
+    config_name = None
+    ignore_for_config = []
+    has_compatibles = False
+
+    _deprecated_kwargs = []
+
+    def register_to_config(self, **kwargs):
+        if self.config_name is None:
+            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+        # Special case for `kwargs` used in deprecation warning added to schedulers
+        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+        # or solve in a more general way.
+        kwargs.pop("kwargs", None)
+        for key, value in kwargs.items():
+            try:
+                setattr(self, key, value)
+            except AttributeError as err:
+                logger.error(f"Can't set {key} with value {value} for {self}")
+                raise err
+
+        if not hasattr(self, "_internal_dict"):
+            internal_dict = kwargs
+        else:
+            previous_dict = dict(self._internal_dict)
+            internal_dict = {**self._internal_dict, **kwargs}
+            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+        self._internal_dict = FrozenDict(internal_dict)
+
+    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+        """
+        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+        [`~ConfigMixin.from_config`] class method.
+
+        Args:
+            save_directory (`str` or `os.PathLike`):
+                Directory where the configuration JSON file will be saved (will be created if it does not exist).
+        """
+        if os.path.isfile(save_directory):
+            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # If we save using the predefined names, we can load using `from_config`
+        output_config_file = os.path.join(save_directory, self.config_name)
+
+        self.to_json_file(output_config_file)
+        logger.info(f"Configuration saved in {output_config_file}")
+
+    @classmethod
+    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, **kwargs):
+        r"""
+        Instantiate a Python class from a config dictionary
+
+        Parameters:
+            config (`Dict[str, Any]`):
+                A config dictionary from which the Python class will be instantiated. Make sure to only load
+                configuration files of compatible classes.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether kwargs that are not consumed by the Python class should be returned or not.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it being loaded) and initiate the Python class.
+                `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
+                overwrite same named arguments of `config`.
+
+        Examples:
+
+        ```python
+        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+        >>> # Download scheduler from huggingface.co and cache.
+        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+        >>> # Instantiate DDIM scheduler class with same config as DDPM
+        >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+        >>> # Instantiate PNDM scheduler class with same config as DDPM
+        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+        ```
+        """
+        # <===== TO BE REMOVED WITH DEPRECATION
+        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+        if "pretrained_model_name_or_path" in kwargs:
+            config = kwargs.pop("pretrained_model_name_or_path")
+
+        if config is None:
+            raise ValueError("Please make sure to provide a config as the first positional argument.")
+        # ======>
+
+        # Return model and optionally state and/or unused_kwargs
+        model = cls(**config)
+        return model
+
+    @classmethod
+    def load_config(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        r"""
+        Instantiate a Python class from a config dictionary
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+                      organization name, like `google/ddpm-celebahq-256`.
+                    - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+                      `./my_model_directory/`.
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `transformers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo (either remote in
+                huggingface.co or downloaded locally), you can specify the folder name here.
+
+        <Tip>
+
+         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+         models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+        </Tip>
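+
+        Examples:
+
+        ```python
+        >>> # Minimal sketch: read the raw config dict of the VQ-GAN tokenizer checkpoint used elsewhere in this repo.
+        >>> from vqvae_muse import VQGANModel
+
+        >>> config = VQGANModel.load_config("Emma02/vqvae_ckpts")
+        >>> config, unused_kwargs = VQGANModel.load_config("Emma02/vqvae_ckpts", return_unused_kwargs=True)
+        ```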
+        """
+        cache_dir = kwargs.pop("cache_dir", MUSE_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+        _ = kwargs.pop("mirror", None)
+        subfolder = kwargs.pop("subfolder", None)
+
+        user_agent = {"file_type": "config"}
+
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+        if cls.config_name is None:
+            raise ValueError(
+                "`self.config_name` is not defined. Note that one should not load a config from "
+                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+            )
+
+        if os.path.isfile(pretrained_model_name_or_path):
+            config_file = pretrained_model_name_or_path
+        elif os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                # Load the config file from a local directory
+                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+            elif subfolder is not None and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+            ):
+                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+            else:
+                raise EnvironmentError(
+                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+                )
+        else:
+            try:
+                # Load from URL or cache if already cached
+                config_file = hf_hub_download(
+                    pretrained_model_name_or_path,
+                    filename=cls.config_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                    subfolder=subfolder,
+                    revision=revision,
+                )
+
+            except RepositoryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+                    " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+                    " login`."
+                )
+            except RevisionNotFoundError:
+                raise EnvironmentError(
+                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+                    " this model name. Check the model page at"
+                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                )
+            except EntryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+                )
+            except HTTPError as err:
+                raise EnvironmentError(
+                    "There was a specific connection error when trying to load"
+                    f" {pretrained_model_name_or_path}:\n{err}"
+                )
+            except ValueError:
+                raise EnvironmentError(
+                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                    f" directory containing a {cls.config_name} file.\nCheck your internet connection or see how to"
+                    " run the library in offline mode at"
+                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                )
+            except EnvironmentError:
+                raise EnvironmentError(
+                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                    f"containing a {cls.config_name} file"
+                )
+
+        try:
+            # Load config dict
+            config_dict = cls._dict_from_json_file(config_file)
+        except (json.JSONDecodeError, UnicodeDecodeError):
+            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+        if return_unused_kwargs:
+            return config_dict, kwargs
+
+        return config_dict
+
+    @staticmethod
+    def _get_init_keys(cls):
+        return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+    @classmethod
+    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+        with open(json_file, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        return json.loads(text)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} {self.to_json_string()}"
+
+    @property
+    def config(self) -> Dict[str, Any]:
+        """
+        Returns the config of the class as a frozen dictionary
+
+        Returns:
+            `Dict[str, Any]`: Config of the class.
+        """
+        return self._internal_dict
+
+    def to_json_string(self) -> str:
+        """
+        Serializes this instance to a JSON string.
+
+        Returns:
+            `str`: String containing all the attributes that make up this configuration instance in JSON format.
+        """
+        # work on a copy so that serialization does not mutate the registered config
+        config_dict = dict(self._internal_dict) if hasattr(self, "_internal_dict") else {}
+        config_dict["_class_name"] = self.__class__.__name__
+        config_dict["_version"] = __version__
+
+        def to_json_saveable(value):
+            if isinstance(value, np.ndarray):
+                value = value.tolist()
+            elif isinstance(value, PosixPath):
+                value = str(value)
+            return value
+
+        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this configuration instance's parameters will be saved.
+        """
+        with open(json_file_path, "w", encoding="utf-8") as writer:
+            writer.write(self.to_json_string())
+
+
+def register_to_config(init):
+    r"""
+    Decorator to apply to the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+    should not be registered in the config, use the `ignore_for_config` class variable.
+
+    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
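+
+    Example (a minimal sketch; `MyLayer` is a hypothetical class used only for illustration):
+
+    ```python
+    >>> from vqvae.modeling_utils import ConfigMixin, register_to_config
+
+    >>> class MyLayer(ConfigMixin):
+    ...     config_name = "config.json"
+    ...
+    ...     @register_to_config
+    ...     def __init__(self, hidden_size: int = 64, num_layers: int = 2):
+    ...         super().__init__()
+
+    >>> layer = MyLayer(hidden_size=128)
+    >>> hidden = layer.config["hidden_size"]  # 128, captured automatically by the decorator
+    ```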
+    """
+
+    @functools.wraps(init)
+    def inner_init(self, *args, **kwargs):
+        # Ignore private kwargs in the init.
+        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+
+        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+        if not isinstance(self, ConfigMixin):
+            raise RuntimeError(
+                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
+                "not inherit from `ConfigMixin`."
+            )
+
+        ignore = getattr(self, "ignore_for_config", [])
+        # Get positional arguments aligned with kwargs
+        new_kwargs = {}
+        signature = inspect.signature(init)
+        parameters = {
+            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+        }
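+        # (`parameters` now maps every init argument except `self` and the ignored names to its default value)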
+        for arg, name in zip(args, parameters.keys()):
+            new_kwargs[name] = arg
+
+        # Then add all kwargs
+        new_kwargs.update(
+            {
+                k: init_kwargs.get(k, default)
+                for k, default in parameters.items()
+                if k not in ignore and k not in new_kwargs
+            }
+        )
+        new_kwargs = {**config_init_kwargs, **new_kwargs}
+        getattr(self, "register_to_config")(**new_kwargs)
+        init(self, *args, **init_kwargs)
+
+    return inner_init
diff --git a/vqvae_muse.py b/vqvae_muse.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2d57f7e54c46b71c75bff0124f02d6df146d7cb
--- /dev/null
+++ b/vqvae_muse.py
@@ -0,0 +1,594 @@
+# coding=utf-8
+# Copyright 2023 The Taming Transformers Authors and The HuggingFace Inc. team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import partial
+from typing import Tuple
+import os
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+from vqvae.modeling_utils import ConfigMixin, ModelMixin, register_to_config
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels: int, with_conv: bool):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = nn.Conv2d(
+                in_channels,
+                in_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+
+    def forward(self, hidden_states):
+        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels: int, with_conv: bool):
+        super().__init__()
+
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, hidden_states):
+        if self.with_conv:
+            pad = (0, 1, 0, 1)  # pad one pixel on the right and bottom so the stride-2 conv halves H and W
+            hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
+            hidden_states = self.conv(hidden_states)
+        else:
+            hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)
+        return hidden_states
+
+
+class ResnetBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int = None,
+        use_conv_shortcut: bool = False,
+        dropout_prob: float = 0.0,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
+        self.use_conv_shortcut = use_conv_shortcut
+
+        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        self.conv1 = nn.Conv2d(
+            self.in_channels,
+            self.out_channels_,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)
+        self.dropout = nn.Dropout(dropout_prob)
+        self.conv2 = nn.Conv2d(
+            self.out_channels_,
+            self.out_channels_,
+            kernel_size=3,
+            stride=(1, 1),
+            padding=1,
+        )
+
+        if self.in_channels != self.out_channels_:
+            if use_conv_shortcut:
+                self.conv_shortcut = nn.Conv2d(
+                    self.in_channels,
+                    self.out_channels_,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                )
+            else:
+                self.nin_shortcut = nn.Conv2d(
+                    self.in_channels,
+                    self.out_channels_,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = F.silu(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = F.silu(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.in_channels != self.out_channels_:
+            if self.use_conv_shortcut:
+                residual = self.conv_shortcut(residual)
+            else:
+                residual = self.nin_shortcut(residual)
+
+        return hidden_states + residual
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+
+        self.in_channels = in_channels
+        conv = partial(nn.Conv2d, self.in_channels, self.in_channels, kernel_size=1, stride=1, padding=0)
+
+        self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True)
+        self.q, self.k, self.v = conv(), conv(), conv()
+        self.proj_out = conv()
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+
+        query = self.q(hidden_states)
+        key = self.k(hidden_states)
+        value = self.v(hidden_states)
+
+        # compute attentions
+        batch, channels, height, width = query.shape
+        query = query.reshape((batch, channels, height * width))
+        query = query.permute(0, 2, 1)  # (b, hw, c)
+        key = key.reshape((batch, channels, height * width))
+
+        attn_weights = torch.bmm(query, key)  # b,hw,hw
+        attn_weights = attn_weights * (int(channels) ** -0.5)
+        attn_weights = nn.functional.softmax(attn_weights, dim=2)
+
+        # attend to values
+        value = value.reshape((batch, channels, height * width))
+        attn_weights = attn_weights.permute(0, 2, 1)
+        hidden_states = torch.bmm(value, attn_weights)
+        hidden_states = hidden_states.reshape((batch, channels, height, width))
+
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+class UpsamplingBlock(nn.Module):
+    def __init__(self, config, curr_res: int, block_idx: int):
+        super().__init__()
+
+        self.config = config
+        self.block_idx = block_idx
+        self.curr_res = curr_res
+
+        if self.block_idx == self.config.num_resolutions - 1:
+            block_in = self.config.hidden_channels * self.config.channel_mult[-1]
+        else:
+            block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]
+
+        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
+
+        res_blocks = []
+        attn_blocks = []
+        for _ in range(self.config.num_res_blocks + 1):
+            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
+            block_in = block_out
+            if self.curr_res in self.config.attn_resolutions:
+                attn_blocks.append(AttnBlock(block_in))
+
+        self.block = nn.ModuleList(res_blocks)
+        self.attn = nn.ModuleList(attn_blocks)
+
+        self.upsample = None
+        if self.block_idx != 0:
+            self.upsample = Upsample(block_in, self.config.resample_with_conv)
+
+    def forward(self, hidden_states):
+        for i, res_block in enumerate(self.block):
+            hidden_states = res_block(hidden_states)
+            if len(self.attn) > 1:
+                hidden_states = self.attn[i](hidden_states)
+
+        if self.upsample is not None:
+            hidden_states = self.upsample(hidden_states)
+
+        return hidden_states
+
+
+class DownsamplingBlock(nn.Module):
+    def __init__(self, config, curr_res: int, block_idx: int):
+        super().__init__()
+
+        self.config = config
+        self.curr_res = curr_res
+        self.block_idx = block_idx
+
+        in_channel_mult = (1,) + tuple(self.config.channel_mult)
+        block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
+        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
+
+        res_blocks = nn.ModuleList()
+        attn_blocks = nn.ModuleList()
+        for _ in range(self.config.num_res_blocks):
+            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
+            block_in = block_out
+            if self.curr_res in self.config.attn_resolutions:
+                attn_blocks.append(AttnBlock(block_in))
+
+        self.block = res_blocks
+        self.attn = attn_blocks
+
+        self.downsample = None
+        if self.block_idx != self.config.num_resolutions - 1:
+            self.downsample = Downsample(block_in, self.config.resample_with_conv)
+
+    def forward(self, hidden_states):
+        for i, res_block in enumerate(self.block):
+            hidden_states = res_block(hidden_states)
+            if len(self.attn) > 1:
+                hidden_states = self.attn[i](hidden_states)
+
+        if self.downsample is not None:
+            hidden_states = self.downsample(hidden_states)
+
+        return hidden_states
+
+
+class MidBlock(nn.Module):
+    def __init__(self, config, in_channels: int, no_attn: bool, dropout: float):
+        super().__init__()
+
+        self.config = config
+        self.in_channels = in_channels
+        self.no_attn = no_attn
+        self.dropout = dropout
+
+        self.block_1 = ResnetBlock(
+            self.in_channels,
+            self.in_channels,
+            dropout_prob=self.dropout,
+        )
+        if not no_attn:
+            self.attn_1 = AttnBlock(self.in_channels)
+        self.block_2 = ResnetBlock(
+            self.in_channels,
+            self.in_channels,
+            dropout_prob=self.dropout,
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = self.block_1(hidden_states)
+        if not self.no_attn:
+            hidden_states = self.attn_1(hidden_states)
+        hidden_states = self.block_2(hidden_states)
+        return hidden_states
+
+
+class Encoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.config = config
+
+        # downsampling
+        self.conv_in = nn.Conv2d(
+            self.config.num_channels,
+            self.config.hidden_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        curr_res = self.config.resolution
+        downsample_blocks = []
+        for i_level in range(self.config.num_resolutions):
+            downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level))
+
+            if i_level != self.config.num_resolutions - 1:
+                curr_res = curr_res // 2
+        self.down = nn.ModuleList(downsample_blocks)
+
+        # middle
+        mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
+        self.mid = MidBlock(config, mid_channels, self.config.no_attn_mid_block, self.config.dropout)
+
+        # end
+        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
+        self.conv_out = nn.Conv2d(
+            mid_channels,
+            self.config.z_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+    def forward(self, pixel_values):
+        # downsampling
+        hidden_states = self.conv_in(pixel_values)
+        for block in self.down:
+            hidden_states = block(hidden_states)
+
+        # middle
+        hidden_states = self.mid(hidden_states)
+
+        # end
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = F.silu(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+
+class Decoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.config = config
+
+        # compute in_channel_mult, block_in and curr_res at lowest res
+        block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
+        curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
+        self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
+
+        # z to block_in
+        self.conv_in = nn.Conv2d(
+            self.config.z_channels,
+            block_in,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        # middle
+        self.mid = MidBlock(config, block_in, self.config.no_attn_mid_block, self.config.dropout)
+
+        # upsampling
+        upsample_blocks = []
+        for i_level in reversed(range(self.config.num_resolutions)):
+            upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level))
+            if i_level != 0:
+                curr_res = curr_res * 2
+        self.up = nn.ModuleList(list(reversed(upsample_blocks)))  # reverse to get consistent order
+
+        # end
+        block_out = self.config.hidden_channels * self.config.channel_mult[0]
+        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)
+        self.conv_out = nn.Conv2d(
+            block_out,
+            self.config.num_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+    def forward(self, hidden_states):
+        # z to block_in
+        hidden_states = self.conv_in(hidden_states)
+
+        # middle
+        hidden_states = self.mid(hidden_states)
+
+        # upsampling
+        for block in reversed(self.up):
+            hidden_states = block(hidden_states)
+
+        # end
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = F.silu(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+
+class VectorQuantizer(nn.Module):
+    """
+    Discretization bottleneck of the VQ-VAE; see
+    https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+    """
+
+    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
+        r"""
+        Args:
+            num_embeddings: number of vectors in the quantized space.
+            embedding_dim: dimensionality of the tensors in the quantized space.
+                Inputs to the modules must be in this format as well.
+            commitment_cost: scalar which controls the weighting of the loss terms
+                (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
+        """
+        super().__init__()
+
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.commitment_cost = commitment_cost
+
+        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
+        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
+
+    def forward(self, hidden_states, return_loss=False):
+        """
+        Takes the output z of the encoder network and maps each spatial position to the index of its closest
+        embedding vector: z (continuous) -> z_q (discrete), with z.shape = (batch, channel, height, width).
+
+        Quantization pipeline:
+            1. take the encoder output (B, C, H, W)
+            2. flatten it to (B*H*W, C)
+            3. find the nearest codebook entry for every flattened vector
+            4. reshape the quantized vectors and token indices back to the input layout
+        """
+        # reshape z -> (batch, height, width, channel) and flatten
+        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
+
+        distances = self.compute_distances(hidden_states)
+        min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1)
+        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
+        min_encodings.scatter_(1, min_encoding_indices, 1)
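+        # min_encodings is now a (B*H*W, num_embeddings) one-hot matrix selecting the nearest codebook entry per position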
+
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)
+
+        # reshape to (batch, num_tokens)
+        min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
+
+        # compute loss for embedding
+        loss = None
+        if return_loss:
+            loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
+                (z_q - hidden_states.detach()) ** 2
+            )
+            # preserve gradients
+            z_q = hidden_states + (z_q - hidden_states).detach()
+
+        # reshape back to match original input shape
+        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        return z_q, min_encoding_indices, loss
+
+    def compute_distances(self, hidden_states):
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim))
+        emb_weights = self.embedding.weight.t()
+
+        inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True)
+        codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
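+        # addmm(input, mat1, mat2, alpha=-2.0) computes input - 2 * (mat1 @ mat2), i.e. the full
+        # ||z||^2 + ||e||^2 - 2 * z.e distance matrix between input vectors and codebook entries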
+        distances = torch.addmm(
+            inputs_norm_sq + codebook_t_norm_sq,
+            hidden_states_flattended,
+            emb_weights,
+            alpha=-2.0,
+        )
+        return distances
+
+    def get_codebook_entry(self, indices):
+        # indices are expected to be of shape (batch, num_tokens)
+        # get quantized latent vectors
+        batch, num_tokens = indices.shape
+        z_q = self.embedding(indices)
+        z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
+        return z_q
+
+    # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
+    def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
+        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()  # (batch, height, width, channel)
+        distances = self.compute_distances(hidden_states)  # (batch * height * width, num_embeddings)
+
+        soft_code = F.softmax(-distances / temp, dim=-1)  # (batch * height * width, num_embeddings)
+        if stochastic:
+            code = torch.multinomial(soft_code, 1)  # (batch * height * width, 1)
+        else:
+            code = distances.argmin(dim=-1)  # (batch * height * width)
+
+        code = code.reshape(hidden_states.shape[0], -1)  # (batch, height * width)
+        batch, num_tokens = code.shape
+        soft_code = soft_code.reshape(batch, num_tokens, -1)  # (batch, height * width, num_embeddings)
+        return soft_code, code
+
+    def get_code(self, hidden_states):
+        # reshape z -> (batch, height, width, channel)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
+        distances = self.compute_distances(hidden_states)
+        indices = torch.argmin(distances, axis=1).unsqueeze(1)
+        indices = indices.reshape(hidden_states.shape[0], -1)
+        return indices
+
+
+class VQGANModel(ModelMixin, ConfigMixin):
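+    """
+    VQ-GAN image tokenizer in the style of Taming Transformers: a convolutional encoder, a vector-quantization
+    bottleneck and a convolutional decoder that map images to discrete codebook indices and back.
+    """
+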
+    @register_to_config
+    def __init__(
+        self,
+        resolution: int = 256,
+        num_channels: int = 3,
+        hidden_channels: int = 128,
+        channel_mult: Tuple = (1, 1, 2, 2, 4),
+        num_res_blocks: int = 2,
+        attn_resolutions: Tuple = (16,),
+        no_attn_mid_block: bool = False,
+        z_channels: int = 256,
+        num_embeddings: int = 1024,
+        quantized_embed_dim: int = 256,
+        dropout: float = 0.0,
+        resample_with_conv: bool = True,
+        commitment_cost: float = 0.25,
+    ):
+        super().__init__()
+
+        self.config.num_resolutions = len(channel_mult)
+        self.config.reduction_factor = 2 ** (self.config.num_resolutions - 1)
+        self.config.latent_size = resolution // self.config.reduction_factor
+
+        self.encoder = Encoder(self.config)
+        self.decoder = Decoder(self.config)
+        self.quantize = VectorQuantizer(
+            self.config.num_embeddings, self.config.quantized_embed_dim, self.config.commitment_cost
+        )
+        self.quant_conv = nn.Conv2d(
+            self.config.z_channels,
+            self.config.quantized_embed_dim,
+            kernel_size=1,
+        )
+        self.post_quant_conv = nn.Conv2d(
+            self.config.quantized_embed_dim,
+            self.config.z_channels,
+            kernel_size=1,
+        )
+
+    def encode(self, pixel_values, return_loss=False):
+        hidden_states = self.encoder(pixel_values)
+        hidden_states = self.quant_conv(hidden_states)
+        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
+        output = (quantized_states, codebook_indices)
+        if return_loss:
+            output = output + (codebook_loss,)
+        return output
+
+    def decode(self, quantized_states):
+        hidden_states = self.post_quant_conv(quantized_states)
+        reconstructed_pixel_values = self.decoder(hidden_states)
+        return reconstructed_pixel_values
+
+    def decode_code(self, codebook_indices):
+        quantized_states = self.quantize.get_codebook_entry(codebook_indices)
+        reconstructed_pixel_values = self.decode(quantized_states)
+        return reconstructed_pixel_values
+
+    def get_code(self, pixel_values):
+        hidden_states = self.encoder(pixel_values)
+        hidden_states = self.quant_conv(hidden_states)
+        codebook_indices = self.quantize.get_code(hidden_states)
+        return codebook_indices
+
+    def forward(self, pixel_values, return_loss=False):
+        hidden_states = self.encoder(pixel_values)
+        hidden_states = self.quant_conv(hidden_states)
+        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
+        reconstructed_pixel_values = self.decode(quantized_states)
+        outputs = (reconstructed_pixel_values, quantized_states, codebook_indices)
+        if return_loss:
+            outputs = outputs + (codebook_loss,)
+        return outputs
+
+
+
+def get_tokenizer_muse():
+    ckpts_path = "Emma02/vqvae_ckpts"
+    net = VQGANModel.from_pretrained(ckpts_path)
+    return net
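+
+
+# Minimal usage sketch, assuming the default 256x256 configuration (a 16x16 latent grid, i.e. 256 tokens
+# per image); the random tensor below stands in for a real preprocessed image batch.
+if __name__ == "__main__":
+    tokenizer = get_tokenizer_muse()
+    tokenizer.eval()
+    with torch.no_grad():
+        pixel_values = torch.rand(1, 3, 256, 256)      # (batch, channels, height, width)
+        codes = tokenizer.get_code(pixel_values)       # (1, 256) codebook indices
+        recon = tokenizer.decode_code(codes)           # (1, 3, 256, 256) reconstruction
+        print(codes.shape, recon.shape)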