diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..5f2d567169f2cfa9f214b4ac7658b44d9b1cc378 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ttf filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2da852da8361b3732166878065f1ddab3f124d5
--- /dev/null
+++ b/app.py
@@ -0,0 +1,401 @@
+'''
+AnyText: Multilingual Visual Text Generation And Editing
+Paper: https://arxiv.org/abs/2311.03054
+Code: https://github.com/tyxsspa/AnyText
+Copyright (c) Alibaba, Inc. and its affiliates.
+'''
+import os
+from modelscope.pipelines import pipeline
+import cv2
+import gradio as gr
+import numpy as np
+import re
+from gradio.components import Component
+from util import check_channels, resize_image, save_images
+import json
+
+BBOX_MAX_NUM = 8
+img_save_folder = 'SaveImages'
+load_model = True
+if load_model:
+    inference = pipeline('my-anytext-task', model='damo/cv_anytext_text_generation_editing', model_revision='v1.1.0')
+
+
+def count_lines(prompt):
+    prompt = prompt.replace('“', '"')
+    prompt = prompt.replace('”', '"')
+    p = '"(.*?)"'
+    strs = re.findall(p, prompt)
+    if len(strs) == 0:
+        strs = [' ']
+    return len(strs)
+
+
+def generate_rectangles(w, h, n, max_trys=200):
+    img = np.zeros((h, w, 1), dtype=np.uint8)
+    rectangles = []
+    attempts = 0
+    n_pass = 0
+    low_edge = int(max(w, h)*0.3 if n <= 3 else max(w, h)*0.2)  # ~150, ~100
+    while attempts < max_trys:
+        rect_w = min(np.random.randint(max((w*0.5)//n, low_edge), w), int(w*0.8))
+        ratio = np.random.uniform(4, 10)
+        rect_h = max(low_edge, int(rect_w/ratio))
+        rect_h = min(rect_h, int(h*0.8))
+        # gen rotate angle
+        rotation_angle = 0
+        rand_value = np.random.rand()
+        if rand_value < 0.7:
+            pass
+        elif rand_value < 0.8:
+            rotation_angle = np.random.randint(0, 40)
+        elif rand_value < 0.9:
+            rotation_angle = np.random.randint(140, 180)
+        else:
+            rotation_angle = np.random.randint(85, 95)
+        # rand position
+        x = np.random.randint(0, w - rect_w)
+        y = np.random.randint(0, h - rect_h)
+        # get vertex
+        rect_pts = cv2.boxPoints(((rect_w/2, rect_h/2), (rect_w, rect_h), rotation_angle))
+        rect_pts = np.int32(rect_pts)
+        # move
+        rect_pts += (x, y)
+        # check boarder
+        if np.any(rect_pts < 0) or np.any(rect_pts[:, 0] >= w) or np.any(rect_pts[:, 1] >= h):
+            attempts += 1
+            continue
+        # check overlap
+        if any(check_overlap_polygon(rect_pts, rp) for rp in rectangles):
+            attempts += 1
+            continue
+        n_pass += 1
+        cv2.fillPoly(img, [rect_pts], 255)
+        rectangles.append(rect_pts)
+        if n_pass == n:
+            break
+    print("attempts:", attempts)
+    if len(rectangles) != n:
+        raise gr.Error(f'Failed in auto generate positions after {attempts} attempts, try again!')
+    return img
+
+
+def check_overlap_polygon(rect_pts1, rect_pts2):
+    poly1 = cv2.convexHull(rect_pts1)
+    poly2 = cv2.convexHull(rect_pts2)
+    rect1 = cv2.boundingRect(poly1)
+    rect2 = cv2.boundingRect(poly2)
+    if rect1[0] + rect1[2] >= rect2[0] and rect2[0] + rect2[2] >= rect1[0] and rect1[1] + rect1[3] >= rect2[1] and rect2[1] + rect2[3] >= rect1[1]:
+        return True
+    return False
+
+
+def draw_rects(width, height, rects):
+    img = np.zeros((height, width, 1), dtype=np.uint8)
+    for rect in rects:
+        x1 = int(rect[0] * width)
+        y1 = int(rect[1] * height)
+        w = int(rect[2] * width)
+        h = int(rect[3] * height)
+        x2 = x1 + w
+        y2 = y1 + h
+        cv2.rectangle(img, (x1, y1), (x2, y2), 255, -1)
+    return img
+
+
+def process(mode, prompt, pos_radio, sort_radio, revise_pos, show_debug, draw_img, rect_img, ref_img, ori_img, img_count, ddim_steps, w, h, strength, cfg_scale, seed, eta, a_prompt, n_prompt, *rect_list):
+    n_lines = count_lines(prompt)
+    # Text Generation
+    if mode == 'gen':
+        # create pos_imgs
+        if pos_radio == 'Manual-draw(手绘)':
+            if draw_img is not None:
+                pos_imgs = 255 - draw_img['image']
+                if 'mask' in draw_img:
+                    pos_imgs = pos_imgs.astype(np.float32) + draw_img['mask'][..., 0:3].astype(np.float32)
+                    pos_imgs = pos_imgs.clip(0, 255).astype(np.uint8)
+            else:
+                pos_imgs = np.zeros((w, h, 1))
+        elif pos_radio == 'Manual-rect(拖框)':
+            rect_check = rect_list[:BBOX_MAX_NUM]
+            rect_xywh = rect_list[BBOX_MAX_NUM:]
+            checked_rects = []
+            for idx, c in enumerate(rect_check):
+                if c:
+                    _xywh = rect_xywh[4*idx:4*(idx+1)]
+                    checked_rects += [_xywh]
+            pos_imgs = draw_rects(w, h, checked_rects)
+        elif pos_radio == 'Auto-rand(随机)':
+            pos_imgs = generate_rectangles(w, h, n_lines, max_trys=500)
+    # Text Editing
+    elif mode == 'edit':
+        revise_pos = False  # disable pos revise in edit mode
+        if ref_img is None or ori_img is None:
+            raise gr.Error('No reference image, please upload one for edit!')
+        edit_image = ori_img.clip(1, 255)  # for mask reason
+        edit_image = check_channels(edit_image)
+        edit_image = resize_image(edit_image, max_length=768)
+        h, w = edit_image.shape[:2]
+        if isinstance(ref_img, dict) and 'mask' in ref_img and ref_img['mask'].mean() > 0:
+            pos_imgs = 255 - edit_image
+            edit_mask = cv2.resize(ref_img['mask'][..., 0:3], (w, h))
+            pos_imgs = pos_imgs.astype(np.float32) + edit_mask.astype(np.float32)
+            pos_imgs = pos_imgs.clip(0, 255).astype(np.uint8)
+        else:
+            if isinstance(ref_img, dict) and 'image' in ref_img:
+                ref_img = ref_img['image']
+            pos_imgs = 255 - ref_img  # example input ref_img is used as pos
+    cv2.imwrite('pos_imgs.png', 255-pos_imgs[..., ::-1])
+    params = {
+        "sort_priority": sort_radio,
+        "show_debug": show_debug,
+        "revise_pos": revise_pos,
+        "image_count": img_count,
+        "ddim_steps": ddim_steps,
+        "image_width": w,
+        "image_height": h,
+        "strength": strength,
+        "cfg_scale": cfg_scale,
+        "eta": eta,
+        "a_prompt": a_prompt,
+        "n_prompt": n_prompt
+    }
+    input_data = {
+        "prompt": prompt,
+        "seed": seed,
+        "draw_pos": pos_imgs,
+        "ori_image": ori_img,
+    }
+    results, rtn_code, rtn_warning, debug_info = inference(input_data, mode=mode, **params)
+    if rtn_code >= 0:
+        # save_images(results, img_save_folder)
+        # print(f'Done, result images are saved in: {img_save_folder}')
+        if rtn_warning:
+            gr.Warning(rtn_warning)
+    else:
+        raise gr.Error(rtn_warning)
+    return results, gr.Markdown(debug_info, visible=show_debug)
+
+
+def create_canvas(w=512, h=512, c=3, line=5):
+    image = np.full((h, w, c), 200, dtype=np.uint8)
+    for i in range(h):
+        if i % (w//line) == 0:
+            image[i, :, :] = 150
+    for j in range(w):
+        if j % (w//line) == 0:
+            image[:, j, :] = 150
+    image[h//2-8:h//2+8, w//2-8:w//2+8, :] = [200, 0, 0]
+    return image
+
+
+def resize_w(w, img1, img2):
+    if isinstance(img2, dict):
+        img2 = img2['image']
+    return [cv2.resize(img1, (w, img1.shape[0])), cv2.resize(img2, (w, img2.shape[0]))]
+
+
+def resize_h(h, img1, img2):
+    if isinstance(img2, dict):
+        img2 = img2['image']
+    return [cv2.resize(img1, (img1.shape[1], h)), cv2.resize(img2, (img2.shape[1], h))]
+
+
+is_t2i = 'true'
+block = gr.Blocks(css='style.css', theme=gr.themes.Soft()).queue()
+
+with open('javascript/bboxHint.js', 'r') as file:
+    value = file.read()
+escaped_value = json.dumps(value)
+
+with block:
+    block.load(fn=None,
+               _js=f"""() => {{
+               const script = document.createElement("script");
+               const text =  document.createTextNode({escaped_value});
+               script.appendChild(text);
+               document.head.appendChild(script);
+               }}""")
+    gr.HTML('<div style="text-align: center; margin: 20px auto;"> \
+            <img id="banner" src="https://modelscope.cn/api/v1/studio/damo/studio_anytext/repo?Revision=master&FilePath=example_images/banner.png&View=true" alt="anytext"> <br>  \
+            [<a href="https://arxiv.org/abs/2311.03054" style="color:blue; font-size:18px;">arXiv</a>] \
+            [<a href="https://github.com/tyxsspa/AnyText" style="color:blue; font-size:18px;">Code</a>] \
+            [<a href="https://modelscope.cn/models/damo/cv_anytext_text_generation_editing/summary" style="color:blue; font-size:18px;">ModelScope</a>]\
+            version: 1.1.0 </div>')
+    with gr.Row(variant='compact'):
+        with gr.Column():
+            with gr.Accordion('🕹Instructions(说明)', open=False,):
+                with gr.Tabs():
+                    with gr.Tab("English"):
+                        gr.Markdown('<span style="color:navy;font-size:20px">Run Examples</span>')
+                        gr.Markdown('<span style="color:black;font-size:16px">AnyText has two modes: Text Generation and Text Editing, and we provides a variety of examples. Select one, click on [Run!] button to run.</span>')
+                        gr.Markdown('<span style="color:gray;font-size:12px">Please note, before running examples, ensure the manual draw area is empty, otherwise may get wrong results. Additionally, different examples use \
+                                     different parameters (such as resolution, seed, etc.). When generate your own, please pay attention to the parameter changes, or refresh the page to restore the default parameters.</span>')
+                        gr.Markdown('<span style="color:navy;font-size:20px">Text Generation</span>')
+                        gr.Markdown('<span style="color:black;font-size:16px">Enter the textual description (in Chinese or English) of the image you want to generate in [Prompt]. Each text line that needs to be generated should be \
+                                     enclosed in double quotes. Then, manually draw the specified position for each text line to generate the image.</span>\
+                                     <span style="color:red;font-size:16px">The drawing of text positions is crucial to the quality of the resulting image</span>, \
+                                     <span style="color:black;font-size:16px">please do not draw too casually or too small. The number of positions should match the number of text lines, and the size of each position should be matched \
+                                     as closely as possible to the length or width of the corresponding text line. If [Manual-draw] is inconvenient, you can try dragging rectangles [Manual-rect] or random positions [Auto-rand].</span>')
+                        gr.Markdown('<span style="color:gray;font-size:12px">When generating multiple lines, each position is matched with the text line according to a certain rule. The [Sort Position] option is used to \
+                                     determine whether to prioritize sorting from top to bottom or from left to right. You can open the [Show Debug] option in the parameter settings to observe the text position and glyph image \
+                                     in the result. You can also select the [Revise Position] which uses the bounding box of the rendered text as the revised position. However, it is occasionally found that the creativity of the \
+                                     generated text is slightly lower using this method.</span>')
+                        gr.Markdown('<span style="color:navy;font-size:20px">Text Editing</span>')
+                        gr.Markdown('<span style="color:black;font-size:16px">Please upload an image in [Ref] as a reference image, then adjust the brush size, and mark the area(s) to be edited. Input the textual description and \
+                                     the new text to be modified in [Prompt], then generate the image.</span>')
+                        gr.Markdown('<span style="color:gray;font-size:12px">The reference image can be of any resolution, but it will be internally processed with a limit that the longer side cannot exceed 768 pixels, and the \
+                                     width and height will both be scaled to multiples of 64.</span>')
+                    with gr.Tab("简体中文"):
+                        gr.Markdown('<span style="color:navy;font-size:20px">运行示例</span>')
+                        gr.Markdown('<span style="color:black;font-size:16px">AnyText有两种运行模式:文字生成和文字编辑,每种模式下提供了丰富的示例,选择一个,点击[Run!]即可。</span>')
+                        gr.Markdown('<span style="color:gray;font-size:12px">请注意,运行示例前确保手绘位置区域是空的,防止影响示例结果,另外不同示例使用不同的参数(如分辨率,种子数等),如果要自行生成时,请留意参数变化,或刷新页面恢复到默认参数。</span>')
+                        gr.Markdown('<span style="color:navy;font-size:20px">文字生成</span>')
+                        gr.Markdown('<span style="color:black;font-size:16px">在Prompt中输入描述提示词(支持中英文),需要生成的每一行文字用双引号包裹,然后依次手绘指定每行文字的位置,生成图片。</span>\
+                                     <span style="color:red;font-size:16px">文字位置的绘制对成图质量很关键</span>, \
+                                     <span style="color:black;font-size:16px">请不要画的太随意或太小,位置的数量要与文字行数量一致,每个位置的尺寸要与对应的文字行的长短或宽高尽量匹配。如果手绘(Manual-draw)不方便,\
+                                     可以尝试拖框矩形(Manual-rect)或随机生成(Auto-rand)。</span>')
+                        gr.Markdown('<span style="color:gray;font-size:12px">多行生成时,每个位置按照一定规则排序后与文字行做对应,Sort Position选项用于确定排序时优先从上到下还是从左到右。\
+                                     可以在参数设置中打开Show Debug选项,在结果图像中观察文字位置和字形图。也可以勾选Revise Position选项,这样会用渲染文字的外接矩形作为修正后的位置,不过偶尔发现这样生成的文字创造性略低。</span>')
+                        gr.Markdown('<span style="color:navy;font-size:20px">文字编辑</span>')
+                        gr.Markdown('<span style="color:black;font-size:16px">请上传一张待编辑的图片作为参考图(Ref),然后调整笔触大小后,在参考图上涂抹要编辑的位置,在Prompt中输入描述提示词和要修改的文字内容,生成图片。</span>')
+                        gr.Markdown('<span style="color:gray;font-size:12px">参考图可以为任意分辨率,但内部处理时会限制长边不能超过768,并且宽高都被缩放为64的整数倍。</span>')
+            with gr.Accordion('🛠Parameters(参数)', open=False):
+                with gr.Row(variant='compact'):
+                    img_count = gr.Slider(label="Image Count(图片数)", minimum=1, maximum=12, value=4, step=1)
+                    ddim_steps = gr.Slider(label="Steps(步数)", minimum=1, maximum=100, value=20, step=1)
+                with gr.Row(variant='compact'):
+                    image_width = gr.Slider(label="Image Width(宽度)", minimum=256, maximum=768, value=512, step=64)
+                    image_height = gr.Slider(label="Image Height(高度)", minimum=256, maximum=768, value=512, step=64)
+                with gr.Row(variant='compact'):
+                    strength = gr.Slider(label="Strength(控制力度)", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                    cfg_scale = gr.Slider(label="CFG-Scale(CFG强度)", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                with gr.Row(variant='compact'):
+                    seed = gr.Slider(label="Seed(种子数)", minimum=-1, maximum=99999999, step=1, randomize=False, value=-1)
+                    eta = gr.Number(label="eta (DDIM)", value=0.0)
+                with gr.Row(variant='compact'):
+                    show_debug = gr.Checkbox(label='Show Debug(调试信息)', value=False)
+                    gr.Markdown('<span style="color:silver;font-size:12px">whether show glyph image and debug information in the result(是否在结果中显示glyph图以及调试信息)</span>')
+                a_prompt = gr.Textbox(label="Added Prompt(附加提示词)", value='best quality, extremely detailed,4k, HD, supper legible text,  clear text edges,  clear strokes, neat writing, no watermarks')
+                n_prompt = gr.Textbox(label="Negative Prompt(负向提示词)", value='low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture')
+            prompt = gr.Textbox(label="Prompt(提示词)")
+            with gr.Tabs() as tab_modes:
+                with gr.Tab("🖼Text Generation(文字生成)", elem_id='MD-tab-t2i') as mode_gen:
+                    pos_radio = gr.Radio(["Manual-draw(手绘)", "Manual-rect(拖框)", "Auto-rand(随机)"], value='Manual-draw(手绘)', label="Pos-Method(位置方式)", info="choose a method to specify text positions(选择方法用于指定文字位置).")
+                    with gr.Row():
+                        sort_radio = gr.Radio(["↕", "↔"], value='↕', label="Sort Position(位置排序)", info="position sorting priority(位置排序时的优先级)")
+                        revise_pos = gr.Checkbox(label='Revise Position(修正位置)', value=False)
+                        # gr.Markdown('<span style="color:silver;font-size:12px">try to revise according to text\'s bounding rectangle(尝试通过渲染后的文字行的外接矩形框修正位置)</span>')
+                    with gr.Row(variant='compact'):
+                        rect_cb_list: list[Component] = []
+                        rect_xywh_list: list[Component] = []
+                        for i in range(BBOX_MAX_NUM):
+                            e = gr.Checkbox(label=f'{i}', value=False, visible=False, min_width='10')
+                            x = gr.Slider(label='x', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-t2i-{i}-x', visible=False)
+                            y = gr.Slider(label='y', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-t2i-{i}-y',  visible=False)
+                            w = gr.Slider(label='w', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-t2i-{i}-w',  visible=False)
+                            h = gr.Slider(label='h', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-t2i-{i}-h',  visible=False)
+                            x.change(fn=None, inputs=x, outputs=x, _js=f'v => onBoxChange({is_t2i}, {i}, "x", v)', show_progress=False, queue=False)
+                            y.change(fn=None, inputs=y, outputs=y, _js=f'v => onBoxChange({is_t2i}, {i}, "y", v)', show_progress=False, queue=False)
+                            w.change(fn=None, inputs=w, outputs=w, _js=f'v => onBoxChange({is_t2i}, {i}, "w", v)', show_progress=False, queue=False)
+                            h.change(fn=None, inputs=h, outputs=h, _js=f'v => onBoxChange({is_t2i}, {i}, "h", v)', show_progress=False, queue=False)
+
+                            e.change(fn=None, inputs=e, outputs=e, _js=f'e => onBoxEnableClick({is_t2i}, {i}, e)', queue=False)
+                            rect_cb_list.extend([e])
+                            rect_xywh_list.extend([x, y, w, h])
+
+                    rect_img = gr.Image(value=create_canvas(), label="Rext Position(方框位置)", elem_id="MD-bbox-rect-t2i", show_label=False, visible=False)
+                    draw_img = gr.Image(value=create_canvas(), label="Draw Position(绘制位置)", visible=True, tool='sketch', show_label=False, brush_radius=60)
+
+                    def re_draw():
+                        return [gr.Image(value=create_canvas(), tool='sketch'), gr.Slider(value=512), gr.Slider(value=512)]
+                    draw_img.clear(re_draw, None, [draw_img, image_width, image_height])
+                    image_width.release(resize_w, [image_width, rect_img, draw_img], [rect_img, draw_img])
+                    image_height.release(resize_h, [image_height, rect_img, draw_img], [rect_img, draw_img])
+
+                    def change_options(selected_option):
+                        return [gr.Checkbox(visible=selected_option == 'Manual-rect(拖框)')] * BBOX_MAX_NUM + \
+                                [gr.Image(visible=selected_option == 'Manual-rect(拖框)'),
+                                 gr.Image(visible=selected_option == 'Manual-draw(手绘)'),
+                                 gr.Radio(visible=selected_option != 'Auto-rand(随机)'),
+                                 gr.Checkbox(value=selected_option == 'Auto-rand(随机)')]
+                    pos_radio.change(change_options, pos_radio, rect_cb_list + [rect_img, draw_img, sort_radio, revise_pos], show_progress=False, queue=False)
+                    with gr.Row():
+                        gr.Markdown("")
+                        run_gen = gr.Button(value="Run(运行)!", scale=0.3, elem_classes='run')
+                        gr.Markdown("")
+
+                    def exp_gen_click():
+                        return [gr.Slider(value=512), gr.Slider(value=512)]  # all examples are 512x512, refresh draw_img
+                    exp_gen = gr.Examples(
+                        [
+                            ['一只浣熊站在黑板前,上面写着"深度学习"', "example_images/gen1.png", "Manual-draw(手绘)", "↕", False, 4, 81808278],
+                            ['一个儿童蜡笔画,森林里有一个可爱的蘑菇形状的房子,标题是"森林小屋"', "example_images/gen16.png", "Manual-draw(手绘)", "↕", False, 4, 40173333],
+                            ['一个精美设计的logo,画的是一个黑白风格的厨师,带着厨师帽,logo下方写着“深夜食堂”', "example_images/gen14.png", "Manual-draw(手绘)", "↕", False, 4, 6970544],
+                            ['photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream', "example_images/gen9.png", "Manual-draw(手绘)", "↕", False, 4, 66273235],
+                            ['一张户外雪地靴的电商广告,上面写着 “双12大促!”,“立减50”,“加绒加厚”,“穿脱方便”,“温暖24小时送达”, “包邮”,高级设计感,精美构图', "example_images/gen15.png", "Manual-draw(手绘)", "↕", False, 4, 66980376],
+                            ['Sign on the clean building that reads "科学" and "과학"  and "ステップ" and "SCIENCE"', "example_images/gen6.png", "Manual-draw(手绘)", "↕", True, 4, 13246309],
+                            ['一个精致的马克杯,上面雕刻着一首中国古诗,内容是 "花落知多少" "夜来风雨声" "处处闻啼鸟" "春眠不觉晓"', "example_images/gen3.png", "Manual-draw(手绘)", "↔", False, 4, 60358279],
+                            ['A delicate square cake, cream and fruit, with "CHEERS" "to the" and "GRADUATE" written in chocolate', "example_images/gen8.png", "Manual-draw(手绘)", "↕", False, 4, 93424638],
+                            ['一件精美的毛衣,上面有针织的文字:"通义丹青"', "example_images/gen4.png", "Manual-draw(手绘)", "↕", False, 4, 48769450],
+                            ['一个双肩包的特写照,上面用针织文字写着”为了无法“ ”计算的价值“', "example_images/gen12.png", "Manual-draw(手绘)", "↕", False, 4, 35552323],
+                            ['A nice drawing in pencil of Michael Jackson,  with the words "Micheal" and "Jackson" written on it', "example_images/gen7.png", "Manual-draw(手绘)", "↕", False, 4, 83866922],
+                            ['一个漂亮的蜡笔画,有行星,宇航员,还有宇宙飞船,上面写的是"去火星旅行", "王小明", "11月1日"', "example_images/gen5.png", "Manual-draw(手绘)", "↕", False, 4, 42328250],
+                            ['一个装饰华丽的蛋糕,上面用奶油写着“阿里云”和"APSARA"', "example_images/gen13.png", "Manual-draw(手绘)", "↕", False, 4, 62357019],
+                            ['一张关于墙上的彩色涂鸦艺术的摄影作品,上面写着“人工智能" 和 "神经网络"', "example_images/gen10.png", "Manual-draw(手绘)", "↕", False, 4, 64722007],
+                            ['一枚中国古代铜钱,  上面的文字是 "康"  "寶" "通" "熙"', "example_images/gen2.png", "Manual-draw(手绘)", "↕", False, 4, 24375031],
+                            ['a well crafted ice sculpture that made with "Happy" and "Holidays". Dslr photo, perfect illumination', "example_images/gen11.png", "Manual-draw(手绘)", "↕", True, 4, 64901362],
+                        ],
+                        [prompt, draw_img, pos_radio, sort_radio, revise_pos, img_count, seed],
+                        examples_per_page=5,
+                    )
+                    exp_gen.dataset.click(exp_gen_click, None, [image_width, image_height])
+
+                with gr.Tab("🎨Text Editing(文字编辑)") as mode_edit:
+                    with gr.Row(variant='compact'):
+                        ref_img = gr.Image(label='Ref(参考图)', source='upload')
+                        ori_img = gr.Image(label='Ori(原图)')
+
+                    def upload_ref(x):
+                        return [gr.Image(type="numpy", brush_radius=60, tool='sketch'),
+                                gr.Image(value=x)]
+
+                    def clear_ref(x):
+                        return gr.Image(source='upload', tool=None)
+                    ref_img.upload(upload_ref, ref_img, [ref_img, ori_img])
+                    ref_img.clear(clear_ref, ref_img, ref_img)
+                    with gr.Row():
+                        gr.Markdown("")
+                        run_edit = gr.Button(value="Run(运行)!", scale=0.3, elem_classes='run')
+                        gr.Markdown("")
+                    gr.Examples(
+                        [
+                            ['精美的书法作品,上面写着“志” “存” “高” ”远“', "example_images/ref10.jpg", "example_images/edit10.png", 4, 98053044],
+                            ['一个表情包,小猪说 "下班"', "example_images/ref2.jpg", "example_images/edit2.png", 2, 43304008],
+                            ['Characters written in chalk on the blackboard that says "DADDY"', "example_images/ref8.jpg", "example_images/edit8.png", 4, 73556391],
+                            ['一个中国古代铜钱,上面写着"乾" "隆"', "example_images/ref12.png", "example_images/edit12.png", 4, 89159482],
+                            ['黑板上写着"Here"', "example_images/ref11.jpg", "example_images/edit11.png", 2, 15353513],
+                            ['A letter picture that says "THER"', "example_images/ref6.jpg", "example_images/edit6.png", 4, 72321415],
+                            ['一堆水果, 中间写着“UIT”', "example_images/ref13.jpg", "example_images/edit13.png", 4, 54263567],
+                            ['一个漫画,上面写着" "', "example_images/ref14.png", "example_images/edit14.png", 4, 94081527],
+                            ['一个黄色标志牌,上边写着"不要" 和 "大意"', "example_images/ref3.jpg", "example_images/edit3.png", 2, 64010349],
+                            ['A cake with colorful characters that reads "EVERYDAY"', "example_images/ref7.jpg", "example_images/edit7.png", 4, 8943410],
+                            ['一个青铜鼎,上面写着"  "和"  "', "example_images/ref4.jpg", "example_images/edit4.png", 4, 71139289],
+                            ['一个建筑物前面的字母标牌, 上面写着 " "', "example_images/ref5.jpg", "example_images/edit5.png", 4, 50416289],
+                        ],
+                        [prompt, ori_img, ref_img, img_count, seed],
+                        examples_per_page=5,
+                    )
+        with gr.Column():
+            result_gallery = gr.Gallery(label='Result(结果)', show_label=True, preview=True, columns=2, allow_preview=True, height=600)
+            result_info = gr.Markdown('', visible=False)
+    ips = [prompt, pos_radio, sort_radio, revise_pos, show_debug, draw_img, rect_img, ref_img, ori_img, img_count, ddim_steps, image_width, image_height, strength, cfg_scale, seed, eta, a_prompt, n_prompt, *(rect_cb_list+rect_xywh_list)]
+    run_gen.click(fn=process, inputs=[gr.State('gen')] + ips, outputs=[result_gallery, result_info])
+    run_edit.click(fn=process, inputs=[gr.State('edit')] + ips, outputs=[result_gallery, result_info])
+
+block.launch(
+    server_name='0.0.0.0' if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1",
+    share=False,
+    root_path=f"/{os.getenv('GRADIO_PROXY_PATH')}" if os.getenv('GRADIO_PROXY_PATH') else ""
+)
+# block.launch(server_name='0.0.0.0')
diff --git a/bert_tokenizer.py b/bert_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f7315134f6e8182c50c87d30dd976e144a2cd89
--- /dev/null
+++ b/bert_tokenizer.py
@@ -0,0 +1,421 @@
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes."""
+
+from __future__ import absolute_import, division, print_function
+import collections
+import re
+import unicodedata
+
+import six
+
+
+def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
+    """Checks whether the casing config is consistent with the checkpoint name."""
+
+    # The casing has to be passed in by the user and there is no explicit check
+    # as to whether it matches the checkpoint. The casing information probably
+    # should have been stored in the bert_config.json file, but it's not, so
+    # we have to heuristically detect it to validate.
+
+    if not init_checkpoint:
+        return
+
+    m = re.match('^.*?([A-Za-z0-9_-]+)/bert_model.ckpt', init_checkpoint)
+    if m is None:
+        return
+
+    model_name = m.group(1)
+
+    lower_models = [
+        'uncased_L-24_H-1024_A-16', 'uncased_L-12_H-768_A-12',
+        'multilingual_L-12_H-768_A-12', 'chinese_L-12_H-768_A-12'
+    ]
+
+    cased_models = [
+        'cased_L-12_H-768_A-12', 'cased_L-24_H-1024_A-16',
+        'multi_cased_L-12_H-768_A-12'
+    ]
+
+    is_bad_config = False
+    if model_name in lower_models and not do_lower_case:
+        is_bad_config = True
+        actual_flag = 'False'
+        case_name = 'lowercased'
+        opposite_flag = 'True'
+
+    if model_name in cased_models and do_lower_case:
+        is_bad_config = True
+        actual_flag = 'True'
+        case_name = 'cased'
+        opposite_flag = 'False'
+
+    if is_bad_config:
+        raise ValueError(
+            'You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. '
+            'However, `%s` seems to be a %s model, so you '
+            'should pass in `--do_lower_case=%s` so that the fine-tuning matches '
+            'how the model was pre-training. If this error is wrong, please '
+            'just comment out this check.' %
+            (actual_flag, init_checkpoint, model_name, case_name,
+             opposite_flag))
+
+
+def convert_to_unicode(text):
+    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+    if six.PY3:
+        if isinstance(text, str):
+            return text
+        elif isinstance(text, bytes):
+            return text.decode('utf-8', 'ignore')
+        else:
+            raise ValueError('Unsupported string type: %s' % (type(text)))
+    elif six.PY2:
+        if isinstance(text, str):
+            return text.decode('utf-8', 'ignore')
+        elif isinstance(text, unicode):
+            return text
+        else:
+            raise ValueError('Unsupported string type: %s' % (type(text)))
+    else:
+        raise ValueError('Not running on Python2 or Python 3?')
+
+
+def printable_text(text):
+    """Returns text encoded in a way suitable for print or `tf.logging`."""
+
+    # These functions want `str` for both Python2 and Python3, but in one case
+    # it's a Unicode string and in the other it's a byte string.
+    if six.PY3:
+        if isinstance(text, str):
+            return text
+        elif isinstance(text, bytes):
+            return text.decode('utf-8', 'ignore')
+        else:
+            raise ValueError('Unsupported string type: %s' % (type(text)))
+    elif six.PY2:
+        if isinstance(text, str):
+            return text
+        elif isinstance(text, unicode):
+            return text.encode('utf-8')
+        else:
+            raise ValueError('Unsupported string type: %s' % (type(text)))
+    else:
+        raise ValueError('Not running on Python2 or Python 3?')
+
+
+def load_vocab(vocab_file):
+    """Loads a vocabulary file into a dictionary."""
+    vocab = collections.OrderedDict()
+    index = 0
+    with open(vocab_file, 'r', encoding='utf-8') as reader:
+        while True:
+            token = convert_to_unicode(reader.readline())
+            if not token:
+                break
+            token = token.strip()
+            vocab[token] = index
+            index += 1
+    return vocab
+
+
+def convert_by_vocab(vocab, items):
+    """Converts a sequence of [tokens|ids] using the vocab."""
+    output = []
+    for item in items:
+        output.append(vocab[item])
+    return output
+
+
+def convert_tokens_to_ids(vocab, tokens):
+    return convert_by_vocab(vocab, tokens)
+
+
+def convert_ids_to_tokens(inv_vocab, ids):
+    return convert_by_vocab(inv_vocab, ids)
+
+
+def whitespace_tokenize(text):
+    """Runs basic whitespace cleaning and splitting on a piece of text."""
+    text = text.strip()
+    if not text:
+        return []
+    tokens = text.split()
+    return tokens
+
+
+class FullTokenizer(object):
+    """Runs end-to-end tokenziation."""
+
+    def __init__(self, vocab_file, do_lower_case=True):
+        self.vocab = load_vocab(vocab_file)
+        self.inv_vocab = {v: k for k, v in self.vocab.items()}
+        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+
+    def tokenize(self, text):
+        split_tokens = []
+        for token in self.basic_tokenizer.tokenize(text):
+            for sub_token in self.wordpiece_tokenizer.tokenize(token):
+                split_tokens.append(sub_token)
+
+        return split_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        return convert_by_vocab(self.vocab, tokens)
+
+    def convert_ids_to_tokens(self, ids):
+        return convert_by_vocab(self.inv_vocab, ids)
+
+    @staticmethod
+    def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
+        """ Converts a sequence of tokens (string) in a single string. """
+
+        def clean_up_tokenization(out_string):
+            """ Clean up a list of simple English tokenization artifacts
+            like spaces before punctuations and abreviated forms.
+            """
+            out_string = (
+                out_string.replace(' .', '.').replace(' ?', '?').replace(
+                    ' !', '!').replace(' ,', ',').replace(" ' ", "'").replace(
+                        " n't", "n't").replace(" 'm", "'m").replace(
+                            " 's", "'s").replace(" 've",
+                                                 "'ve").replace(" 're", "'re"))
+            return out_string
+
+        text = ' '.join(tokens).replace(' ##', '').strip()
+        if clean_up_tokenization_spaces:
+            clean_text = clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
+
+    def vocab_size(self):
+        return len(self.vocab)
+
+
+class BasicTokenizer(object):
+    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+    def __init__(self, do_lower_case=True):
+        """Constructs a BasicTokenizer.
+
+        Args:
+          do_lower_case: Whether to lower case the input.
+        """
+        self.do_lower_case = do_lower_case
+
+    def tokenize(self, text):
+        """Tokenizes a piece of text."""
+        text = convert_to_unicode(text)
+        text = self._clean_text(text)
+
+        # This was added on November 1st, 2018 for the multilingual and Chinese
+        # models. This is also applied to the English models now, but it doesn't
+        # matter since the English models were not trained on any Chinese data
+        # and generally don't have any Chinese data in them (there are Chinese
+        # characters in the vocabulary because Wikipedia does have some Chinese
+        # words in the English Wikipedia.).
+        text = self._tokenize_chinese_chars(text)
+
+        orig_tokens = whitespace_tokenize(text)
+        split_tokens = []
+        for token in orig_tokens:
+            if self.do_lower_case:
+                token = token.lower()
+                token = self._run_strip_accents(token)
+            split_tokens.extend(self._run_split_on_punc(token))
+
+        output_tokens = whitespace_tokenize(' '.join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize('NFD', text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == 'Mn':
+                continue
+            output.append(char)
+        return ''.join(output)
+
+    def _run_split_on_punc(self, text):
+        """Splits punctuation on a piece of text."""
+        chars = list(text)
+        i = 0
+        start_new_word = True
+        output = []
+        while i < len(chars):
+            char = chars[i]
+            if _is_punctuation(char):
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+
+        return [''.join(x) for x in output]
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(' ')
+                output.append(char)
+                output.append(' ')
+            else:
+                output.append(char)
+        return ''.join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like the all of the other languages.
+        if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
+                or (cp >= 0x20000 and cp <= 0x2A6DF)
+                or (cp >= 0x2A700 and cp <= 0x2B73F)
+                or (cp >= 0x2B740 and cp <= 0x2B81F)
+                or (cp >= 0x2B820 and cp <= 0x2CEAF)
+                or (cp >= 0xF900 and cp <= 0xFAFF)
+                or (cp >= 0x2F800 and cp <= 0x2FA1F)):
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xfffd or _is_control(char):
+                continue
+            if _is_whitespace(char):
+                output.append(' ')
+            else:
+                output.append(char)
+        return ''.join(output)
+
+
+class WordpieceTokenizer(object):
+    """Runs WordPiece tokenziation."""
+
+    def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=200):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.max_input_chars_per_word = max_input_chars_per_word
+
+    def tokenize(self, text):
+        """Tokenizes a piece of text into its word pieces.
+
+        This uses a greedy longest-match-first algorithm to perform tokenization
+        using the given vocabulary.
+
+        For example:
+          input = "unaffable"
+          output = ["un", "##aff", "##able"]
+
+        Args:
+          text: A single token or whitespace separated tokens. This should have
+            already been passed through `BasicTokenizer.
+
+        Returns:
+          A list of wordpiece tokens.
+        """
+
+        text = convert_to_unicode(text)
+
+        output_tokens = []
+        for token in whitespace_tokenize(text):
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
+            while start < len(chars):
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = ''.join(chars[start:end])
+                    if start > 0:
+                        substr = '##' + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
+
+def _is_whitespace(char):
+    """Checks whether `chars` is a whitespace character."""
+    # \t, \n, and \r are technically contorl characters but we treat them
+    # as whitespace since they are generally considered as such.
+    if char == ' ' or char == '\t' or char == '\n' or char == '\r':
+        return True
+    cat = unicodedata.category(char)
+    if cat == 'Zs':
+        return True
+    return False
+
+
+def _is_control(char):
+    """Checks whether `chars` is a control character."""
+    # These are technically control characters but we count them as whitespace
+    # characters.
+    if char == '\t' or char == '\n' or char == '\r':
+        return False
+    cat = unicodedata.category(char)
+    if cat in ('Cc', 'Cf'):
+        return True
+    return False
+
+
+def _is_punctuation(char):
+    """Checks whether `chars` is a punctuation character."""
+    cp = ord(char)
+    # We treat all non-letter/number ASCII as punctuation.
+    # Characters such as "^", "$", and "`" are not in the Unicode
+    # Punctuation class but we treat them as punctuation anyways, for
+    # consistency.
+    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
+            or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+        return True
+    cat = unicodedata.category(char)
+    if cat.startswith('P'):
+        return True
+    return False
diff --git a/cldm/cldm.py b/cldm/cldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..978f5b1d3fc0b7f4e68ef0f48d74d4e7c026bce8
--- /dev/null
+++ b/cldm/cldm.py
@@ -0,0 +1,617 @@
+import einops
+import torch
+import torch as th
+import torch.nn as nn
+import copy
+from easydict import EasyDict as edict
+
+from ldm.modules.diffusionmodules.util import (
+    conv_nd,
+    linear,
+    zero_module,
+    timestep_embedding,
+)
+
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from ldm.modules.attention import SpatialTransformer
+from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.util import log_txt_as_img, exists, instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from .recognizer import TextRecognizer, create_predictor
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+class ControlledUnetModel(UNetModel):
+    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
+        hs = []
+        with torch.no_grad():
+            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+            emb = self.time_embed(t_emb)
+            h = x.type(self.dtype)
+            for module in self.input_blocks:
+                h = module(h, emb, context)
+                hs.append(h)
+            h = self.middle_block(h, emb, context)
+
+        if control is not None:
+            h += control.pop()
+
+        for i, module in enumerate(self.output_blocks):
+            if only_mid_control or control is None:
+                h = torch.cat([h, hs.pop()], dim=1)
+            else:
+                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
+            h = module(h, emb, context)
+
+        h = h.type(x.dtype)
+        return self.out(h)
+
+
+class ControlNet(nn.Module):
+    def __init__(
+            self,
+            image_size,
+            in_channels,
+            model_channels,
+            glyph_channels,
+            position_channels,
+            num_res_blocks,
+            attention_resolutions,
+            dropout=0,
+            channel_mult=(1, 2, 4, 8),
+            conv_resample=True,
+            dims=2,
+            use_checkpoint=False,
+            use_fp16=False,
+            num_heads=-1,
+            num_head_channels=-1,
+            num_heads_upsample=-1,
+            use_scale_shift_norm=False,
+            resblock_updown=False,
+            use_new_attention_order=False,
+            use_spatial_transformer=False,  # custom transformer support
+            transformer_depth=1,  # custom transformer support
+            context_dim=None,  # custom transformer support
+            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+            legacy=True,
+            disable_self_attentions=None,
+            num_attention_blocks=None,
+            disable_middle_self_attn=False,
+            use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+        if context_dim is not None:
+            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            from omegaconf.listconfig import ListConfig
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+        self.dims = dims
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+        self.glyph_block = TimestepEmbedSequential(
+            conv_nd(dims, glyph_channels, 8, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 8, 8, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 96, 96, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            nn.SiLU(),
+        )
+
+        self.position_block = TimestepEmbedSequential(
+            conv_nd(dims, position_channels, 8, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 8, 8, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 32, 64, 3, padding=1, stride=2),
+            nn.SiLU(),
+        )
+
+        self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1))
+
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.zero_convs.append(self.make_zero_conv(ch))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                self.zero_convs.append(self.make_zero_conv(ch))
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            # num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self.middle_block_out = self.make_zero_conv(ch)
+        self._feature_size += ch
+
+    def make_zero_conv(self, channels):
+        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+    def forward(self, x, hint, text_info, timesteps, context, **kwargs):
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        # guided_hint from text_info
+        B, C, H, W = x.shape
+        glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True)
+        positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True)
+        enc_glyph = self.glyph_block(glyphs, emb, context)
+        enc_pos = self.position_block(positions, emb, context)
+        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1))
+
+        outs = []
+
+        h = x.type(self.dtype)
+        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+            if guided_hint is not None:
+                h = module(h, emb, context)
+                h += guided_hint
+                guided_hint = None
+            else:
+                h = module(h, emb, context)
+            outs.append(zero_conv(h, emb, context))
+
+        h = self.middle_block(h, emb, context)
+        outs.append(self.middle_block_out(h, emb, context))
+
+        return outs
+
+
+class ControlLDM(LatentDiffusion):
+
+    def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.control_model = instantiate_from_config(control_stage_config)
+        self.control_key = control_key
+        self.glyph_key = glyph_key
+        self.position_key = position_key
+        self.only_mid_control = only_mid_control
+        self.control_scales = [1.0] * 13
+        self.loss_alpha = loss_alpha
+        self.loss_beta = loss_beta
+        self.with_step_weight = with_step_weight
+        self.use_vae_upsample = use_vae_upsample
+        self.latin_weight = latin_weight
+        if embedding_manager_config is not None and embedding_manager_config.params.valid:
+            self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
+            for param in self.embedding_manager.embedding_parameters():
+                param.requires_grad = True
+        else:
+            self.embedding_manager = None
+        if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
+            if embedding_manager_config.params.emb_type == 'ocr':
+                self.text_predictor = create_predictor().eval()
+                args = edict()
+                args.rec_image_shape = "3, 48, 320"
+                args.rec_batch_num = 6
+                args.rec_char_dict_path = './ocr_recog/ppocr_keys_v1.txt'
+                self.cn_recognizer = TextRecognizer(args, self.text_predictor)
+                for param in self.text_predictor.parameters():
+                    param.requires_grad = False
+                if self.embedding_manager:
+                    self.embedding_manager.recog = self.cn_recognizer
+
+    @torch.no_grad()
+    def get_input(self, batch, k, bs=None, *args, **kwargs):
+        if self.embedding_manager is None:  # fill in full caption
+            self.fill_caption(batch)
+        x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
+        control = batch[self.control_key]  # for log_images and loss_alpha, not real control
+        if bs is not None:
+            control = control[:bs]
+        control = control.to(self.device)
+        control = einops.rearrange(control, 'b h w c -> b c h w')
+        control = control.to(memory_format=torch.contiguous_format).float()
+
+        inv_mask = batch['inv_mask']
+        if bs is not None:
+            inv_mask = inv_mask[:bs]
+        inv_mask = inv_mask.to(self.device)
+        inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
+        inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
+
+        glyphs = batch[self.glyph_key]
+        gly_line = batch['gly_line']
+        positions = batch[self.position_key]
+        n_lines = batch['n_lines']
+        language = batch['language']
+        texts = batch['texts']
+        assert len(glyphs) == len(positions)
+        for i in range(len(glyphs)):
+            if bs is not None:
+                glyphs[i] = glyphs[i][:bs]
+                gly_line[i] = gly_line[i][:bs]
+                positions[i] = positions[i][:bs]
+                n_lines = n_lines[:bs]
+            glyphs[i] = glyphs[i].to(self.device)
+            gly_line[i] = gly_line[i].to(self.device)
+            positions[i] = positions[i].to(self.device)
+            glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
+            gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
+            positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
+            glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
+            gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
+            positions[i] = positions[i].to(memory_format=torch.contiguous_format).float()
+        info = {}
+        info['glyphs'] = glyphs
+        info['positions'] = positions
+        info['n_lines'] = n_lines
+        info['language'] = language
+        info['texts'] = texts
+        info['img'] = batch['img']  # nhwc, (-1,1)
+        info['masked_x'] = mx
+        info['gly_line'] = gly_line
+        info['inv_mask'] = inv_mask
+        return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
+
+    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
+        assert isinstance(cond, dict)
+        diffusion_model = self.model.diffusion_model
+        _cond = torch.cat(cond['c_crossattn'], 1)
+        _hint = torch.cat(cond['c_concat'], 1)
+        control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
+        control = [c * scale for c, scale in zip(control, self.control_scales)]
+        eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
+
+        return eps
+
+    def instantiate_embedding_manager(self, config, embedder):
+        model = instantiate_from_config(config, embedder=embedder)
+        return model
+
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, N):
+        return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None))
+
+    def get_learned_conditioning(self, c):
+        if self.cond_stage_forward is None:
+            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+                if self.embedding_manager is not None and c['text_info'] is not None:
+                    self.embedding_manager.encode_text(c['text_info'])
+                if isinstance(c, dict):
+                    cond_txt = c['c_crossattn'][0]
+                else:
+                    cond_txt = c
+                if self.embedding_manager is not None:
+                    cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager)
+                else:
+                    cond_txt = self.cond_stage_model.encode(cond_txt)
+                if isinstance(c, dict):
+                    c['c_crossattn'][0] = cond_txt
+                else:
+                    c = cond_txt
+                if isinstance(c, DiagonalGaussianDistribution):
+                    c = c.mode()
+            else:
+                c = self.cond_stage_model(c)
+        else:
+            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+        return c
+
+    def fill_caption(self, batch, place_holder='*'):
+        bs = len(batch['n_lines'])
+        cond_list = copy.deepcopy(batch[self.cond_stage_key])
+        for i in range(bs):
+            n_lines = batch['n_lines'][i]
+            if n_lines == 0:
+                continue
+            cur_cap = cond_list[i]
+            for j in range(n_lines):
+                r_txt = batch['texts'][j][i]
+                cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
+            cond_list[i] = cur_cap
+        batch[self.cond_stage_key] = cond_list
+
+    @torch.no_grad()
+    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c = self.get_input(batch, self.first_stage_key, bs=N)
+        if self.cond_stage_trainable:
+            with torch.no_grad():
+                c = self.get_learned_conditioning(c)
+        c_crossattn = c["c_crossattn"][0][:N]
+        c_cat = c["c_concat"][0][:N]
+        text_info = c["text_info"]
+        text_info['glyphs'] = [i[:N] for i in text_info['glyphs']]
+        text_info['gly_line'] = [i[:N] for i in text_info['gly_line']]
+        text_info['positions'] = [i[:N] for i in text_info['positions']]
+        text_info['n_lines'] = text_info['n_lines'][:N]
+        text_info['masked_x'] = text_info['masked_x'][:N]
+        text_info['img'] = text_info['img'][:N]
+
+        N = min(z.shape[0], N)
+        n_row = min(z.shape[0], n_row)
+        log["reconstruction"] = self.decode_first_stage(z)
+        log["masked_image"] = self.decode_first_stage(text_info['masked_x'])
+        log["control"] = c_cat * 2.0 - 1.0
+        log["img"] = text_info['img'].permute(0, 3, 1, 2)  # log source image if needed
+        # get glyph
+        glyph_bs = torch.stack(text_info['glyphs'])
+        glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
+        log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,)
+        # fill caption
+        if not self.embedding_manager:
+            self.fill_caption(batch)
+        captions = batch[self.cond_stage_key]
+        log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info},
+                                                     batch_size=N, ddim=use_ddim,
+                                                     ddim_steps=ddim_steps, eta=ddim_eta)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(N)
+            uc_cat = c_cat  # torch.zeros_like(c_cat)
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info}
+            samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
+                                                batch_size=N, ddim=use_ddim,
+                                                ddim_steps=ddim_steps, eta=ddim_eta,
+                                                unconditional_guidance_scale=unconditional_guidance_scale,
+                                                unconditional_conditioning=uc_full,
+                                                )
+            x_samples_cfg = self.decode_first_stage(samples_cfg)
+            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+            pred_x0 = False  # wether log pred_x0
+            if pred_x0:
+                for idx in range(len(tmps['pred_x0'])):
+                    pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx])
+                    log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
+
+        return log
+
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        ddim_sampler = DDIMSampler(self)
+        b, c, h, w = cond["c_concat"][0].shape
+        shape = (self.channels, h // 8, w // 8)
+        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs)
+        return samples, intermediates
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.control_model.parameters())
+        if self.embedding_manager:
+            params += list(self.embedding_manager.embedding_parameters())
+        if not self.sd_locked:
+            # params += list(self.model.diffusion_model.input_blocks.parameters())
+            # params += list(self.model.diffusion_model.middle_block.parameters())
+            params += list(self.model.diffusion_model.output_blocks.parameters())
+            params += list(self.model.diffusion_model.out.parameters())
+        if self.unlockKV:
+            nCount = 0
+            for name, param in self.model.diffusion_model.named_parameters():
+                if 'attn2.to_k' in name or 'attn2.to_v' in name:
+                    params += [param]
+                    nCount += 1
+            print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to potimizers!!!')
+
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+    def low_vram_shift(self, is_diffusing):
+        if is_diffusing:
+            self.model = self.model.cuda()
+            self.control_model = self.control_model.cuda()
+            self.first_stage_model = self.first_stage_model.cpu()
+            self.cond_stage_model = self.cond_stage_model.cpu()
+        else:
+            self.model = self.model.cpu()
+            self.control_model = self.control_model.cpu()
+            self.first_stage_model = self.first_stage_model.cuda()
+            self.cond_stage_model = self.cond_stage_model.cuda()
diff --git a/cldm/ddim_hacked.py b/cldm/ddim_hacked.py
new file mode 100644
index 0000000000000000000000000000000000000000..25b1bc947272ad14d7f7e5e4d1809005253b63d0
--- /dev/null
+++ b/cldm/ddim_hacked.py
@@ -0,0 +1,317 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0 = outs
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            model_t = self.model.apply_model(x, t, c)
+            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", 'not implemented'
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        num_reference_steps = timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback: callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False, callback=None):
+
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+            if callback: callback(i)
+        return x_dec
diff --git a/cldm/embedding_manager.py b/cldm/embedding_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..be397d23d1db1d97d536d1bfdf0a2301807ad77e
--- /dev/null
+++ b/cldm/embedding_manager.py
@@ -0,0 +1,165 @@
+'''
+Copyright (c) Alibaba, Inc. and its affiliates.
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from ldm.modules.diffusionmodules.util import conv_nd, linear
+
+
+def get_clip_token_for_string(tokenizer, string):
+    batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
+                               return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+    tokens = batch_encoding["input_ids"]
+    assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
+    return tokens[0, 1]
+
+
+def get_bert_token_for_string(tokenizer, string):
+    token = tokenizer(string)
+    assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
+    token = token[0, 1]
+    return token
+
+
+def get_clip_vision_emb(encoder, processor, img):
+    _img = img.repeat(1, 3, 1, 1)*255
+    inputs = processor(images=_img, return_tensors="pt")
+    inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
+    outputs = encoder(**inputs)
+    emb = outputs.image_embeds
+    return emb
+
+
+def get_recog_emb(encoder, img_list):
+    _img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
+    encoder.predictor.eval()
+    _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
+    return preds_neck
+
+
+def pad_H(x):
+    _, _, H, W = x.shape
+    p_top = (W - H) // 2
+    p_bot = W - H - p_top
+    return F.pad(x, (0, 0, p_top, p_bot))
+
+
+class EncodeNet(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(EncodeNet, self).__init__()
+        chan = 16
+        n_layer = 4  # downsample
+
+        self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
+        self.conv_list = nn.ModuleList([])
+        _c = chan
+        for i in range(n_layer):
+            self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
+            _c *= 2
+        self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        x = self.act(self.conv1(x))
+        for layer in self.conv_list:
+            x = self.act(layer(x))
+        x = self.act(self.conv2(x))
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        return x
+
+
+class EmbeddingManager(nn.Module):
+    def __init__(
+            self,
+            embedder,
+            valid=True,
+            glyph_channels=20,
+            position_channels=1,
+            placeholder_string='*',
+            add_pos=False,
+            emb_type='ocr',
+            **kwargs
+    ):
+        super().__init__()
+        if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
+            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
+            token_dim = 768
+            if hasattr(embedder, 'vit'):
+                assert emb_type == 'vit'
+                self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
+            self.get_recog_emb = None
+        else:  # using LDM's BERT encoder
+            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
+            token_dim = 1280
+        self.token_dim = token_dim
+        self.emb_type = emb_type
+
+        self.add_pos = add_pos
+        if add_pos:
+            self.position_encoder = EncodeNet(position_channels, token_dim)
+        if emb_type == 'ocr':
+            self.proj = linear(40*64, token_dim)
+        if emb_type == 'conv':
+            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
+
+        self.placeholder_token = get_token_for_string(placeholder_string)
+
+    def encode_text(self, text_info):
+        if self.get_recog_emb is None and self.emb_type == 'ocr':
+            self.get_recog_emb = partial(get_recog_emb, self.recog)
+
+        gline_list = []
+        pos_list = []
+        for i in range(len(text_info['n_lines'])):  # sample index in a batch
+            n_lines = text_info['n_lines'][i]
+            for j in range(n_lines):  # line
+                gline_list += [text_info['gly_line'][j][i:i+1]]
+                if self.add_pos:
+                    pos_list += [text_info['positions'][j][i:i+1]]
+
+        if len(gline_list) > 0:
+            if self.emb_type == 'ocr':
+                recog_emb = self.get_recog_emb(gline_list)
+                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
+            elif self.emb_type == 'vit':
+                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
+            elif self.emb_type == 'conv':
+                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
+            if self.add_pos:
+                enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
+                enc_glyph = enc_glyph+enc_pos
+
+        self.text_embs_all = []
+        n_idx = 0
+        for i in range(len(text_info['n_lines'])):  # sample index in a batch
+            n_lines = text_info['n_lines'][i]
+            text_embs = []
+            for j in range(n_lines):  # line
+                text_embs += [enc_glyph[n_idx:n_idx+1]]
+                n_idx += 1
+            self.text_embs_all += [text_embs]
+
+    def forward(
+            self,
+            tokenized_text,
+            embedded_text,
+    ):
+        b, device = tokenized_text.shape[0], tokenized_text.device
+        for i in range(b):
+            idx = tokenized_text[i] == self.placeholder_token.to(device)
+            if sum(idx) > 0:
+                if i >= len(self.text_embs_all):
+                    print('truncation for log images...')
+                    break
+                text_emb = torch.cat(self.text_embs_all[i], dim=0)
+                if sum(idx) != len(text_emb):
+                    print('truncation for long caption...')
+                embedded_text[i][idx] = text_emb[:sum(idx)]
+        return embedded_text
+
+    def embedding_parameters(self):
+        return self.parameters()
diff --git a/cldm/hack.py b/cldm/hack.py
new file mode 100644
index 0000000000000000000000000000000000000000..454361e9d036cd1a6a79122c2fd16b489e4767b1
--- /dev/null
+++ b/cldm/hack.py
@@ -0,0 +1,111 @@
+import torch
+import einops
+
+import ldm.modules.encoders.modules
+import ldm.modules.attention
+
+from transformers import logging
+from ldm.modules.attention import default
+
+
+def disable_verbosity():
+    logging.set_verbosity_error()
+    print('logging improved.')
+    return
+
+
+def enable_sliced_attention():
+    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
+    print('Enabled sliced_attention.')
+    return
+
+
+def hack_everything(clip_skip=0):
+    disable_verbosity()
+    ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
+    ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
+    print('Enabled clip hacks.')
+    return
+
+
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+    PAD = self.tokenizer.pad_token_id
+    EOS = self.tokenizer.eos_token_id
+    BOS = self.tokenizer.bos_token_id
+
+    def tokenize(t):
+        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
+
+    def transformer_encode(t):
+        if self.clip_skip > 1:
+            rt = self.transformer(input_ids=t, output_hidden_states=True)
+            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
+        else:
+            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
+
+    def split(x):
+        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
+
+    def pad(x, p, i):
+        return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+    raw_tokens_list = tokenize(text)
+    tokens_list = []
+
+    for raw_tokens in raw_tokens_list:
+        raw_tokens_123 = split(raw_tokens)
+        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
+        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+        tokens_list.append(raw_tokens_123)
+
+    tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
+    y = transformer_encode(feed)
+    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
+
+    return z
+
+
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+    k = self.to_k(context)
+    v = self.to_v(context)
+    del context, x
+
+    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    limit = k.shape[0]
+    att_step = 1
+    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+    q_chunks.reverse()
+    k_chunks.reverse()
+    v_chunks.reverse()
+    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+    del k, q, v
+    for i in range(0, limit, att_step):
+        q_buffer = q_chunks.pop()
+        k_buffer = k_chunks.pop()
+        v_buffer = v_chunks.pop()
+        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+        del k_buffer, q_buffer
+        # attention, what we cannot get enough of, by chunks
+
+        sim_buffer = sim_buffer.softmax(dim=-1)
+
+        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+        del v_buffer
+        sim[i:i + att_step, :, :] = sim_buffer
+
+        del sim_buffer
+    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(sim)
diff --git a/cldm/logger.py b/cldm/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8803846f2a8979f87f3cf9ea5b12869439e62f
--- /dev/null
+++ b/cldm/logger.py
@@ -0,0 +1,76 @@
+import os
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+
+class ImageLogger(Callback):
+    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
+                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+                 log_images_kwargs=None):
+        super().__init__()
+        self.rescale = rescale
+        self.batch_freq = batch_frequency
+        self.max_images = max_images
+        if not increase_log_steps:
+            self.log_steps = [self.batch_freq]
+        self.clamp = clamp
+        self.disabled = disabled
+        self.log_on_batch_idx = log_on_batch_idx
+        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+        self.log_first_step = log_first_step
+
+    @rank_zero_only
+    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
+        root = os.path.join(save_dir, "image_log", split)
+        for k in images:
+            grid = torchvision.utils.make_grid(images[k], nrow=4)
+            if self.rescale:
+                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+            grid = grid.numpy()
+            grid = (grid * 255).astype(np.uint8)
+            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+            path = os.path.join(root, filename)
+            os.makedirs(os.path.split(path)[0], exist_ok=True)
+            Image.fromarray(grid).save(path)
+
+    def log_img(self, pl_module, batch, batch_idx, split="train"):
+        check_idx = batch_idx  # if self.log_on_batch_idx else pl_module.global_step
+        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
+                hasattr(pl_module, "log_images") and
+                callable(pl_module.log_images) and
+                self.max_images > 0):
+            logger = type(pl_module.logger)
+
+            is_train = pl_module.training
+            if is_train:
+                pl_module.eval()
+
+            with torch.no_grad():
+                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+            for k in images:
+                N = min(images[k].shape[0], self.max_images)
+                images[k] = images[k][:N]
+                if isinstance(images[k], torch.Tensor):
+                    images[k] = images[k].detach().cpu()
+                    if self.clamp:
+                        images[k] = torch.clamp(images[k], -1., 1.)
+
+            self.log_local(pl_module.logger.save_dir, split, images,
+                           pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+            if is_train:
+                pl_module.train()
+
+    def check_frequency(self, check_idx):
+        return check_idx % self.batch_freq == 0
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        if not self.disabled:
+            self.log_img(pl_module, batch, batch_idx, split="train")
diff --git a/cldm/model.py b/cldm/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4404622322e05fd8fea4ba0a15912b88c3b2e2a
--- /dev/null
+++ b/cldm/model.py
@@ -0,0 +1,30 @@
+import os
+import torch
+
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+    return d.get('state_dict', d)
+
+
+def load_state_dict(ckpt_path, location='cpu'):
+    _, extension = os.path.splitext(ckpt_path)
+    if extension.lower() == ".safetensors":
+        import safetensors.torch
+        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+    else:
+        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+    state_dict = get_state_dict(state_dict)
+    print(f'Loaded state_dict from [{ckpt_path}]')
+    return state_dict
+
+
+def create_model(config_path, cond_stage_path=None):
+    config = OmegaConf.load(config_path)
+    if cond_stage_path:
+        config.model.params.cond_stage_config.params.version = cond_stage_path  # use pre-downloaded ckpts, in case blocked
+    model = instantiate_from_config(config.model).cpu()
+    print(f'Loaded model config from [{config_path}]')
+    return model
diff --git a/cldm/recognizer.py b/cldm/recognizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..2defa9117b73bd741ef2bdaecb320a33f58e6933
--- /dev/null
+++ b/cldm/recognizer.py
@@ -0,0 +1,303 @@
+'''
+Copyright (c) Alibaba, Inc. and its affiliates.
+'''
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+import cv2
+import numpy as np
+import math
+import traceback
+from easydict import EasyDict as edict
+import time
+from ocr_recog.RecModel import RecModel
+import torch
+import torch.nn.functional as F
+from skimage.transform._geometric import _umeyama as get_sym_mat
+
+
+def min_bounding_rect(img):
+    ret, thresh = cv2.threshold(img, 127, 255, 0)
+    contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    if len(contours) == 0:
+        print('Bad contours, using fake bbox...')
+        return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
+    max_contour = max(contours, key=cv2.contourArea)
+    rect = cv2.minAreaRect(max_contour)
+    box = cv2.boxPoints(rect)
+    box = np.int0(box)
+    # sort
+    x_sorted = sorted(box, key=lambda x: x[0])
+    left = x_sorted[:2]
+    right = x_sorted[2:]
+    left = sorted(left, key=lambda x: x[1])
+    (tl, bl) = left
+    right = sorted(right, key=lambda x: x[1])
+    (tr, br) = right
+    if tl[1] > bl[1]:
+        (tl, bl) = (bl, tl)
+    if tr[1] > br[1]:
+        (tr, br) = (br, tr)
+    return np.array([tl, tr, br, bl])
+
+
+def adjust_image(box, img):
+    pts1 = np.float32([box[0], box[1], box[2], box[3]])
+    width = max(np.linalg.norm(pts1[0]-pts1[1]), np.linalg.norm(pts1[2]-pts1[3]))
+    height = max(np.linalg.norm(pts1[0]-pts1[3]), np.linalg.norm(pts1[1]-pts1[2]))
+    pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
+    # get transform matrix
+    M = get_sym_mat(pts1, pts2, estimate_scale=True)
+    C, H, W = img.shape
+    T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]])
+    theta = np.linalg.inv(T @ M @ np.linalg.inv(T))
+    theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device)
+    grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True)
+    result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True)
+    result = torch.clamp(result.squeeze(0), 0, 255)
+    # crop
+    result = result[:, :int(height), :int(width)]
+    return result
+
+
+'''
+mask: numpy.ndarray, mask of textual, HWC
+src_img: torch.Tensor, source image, CHW
+'''
+def crop_image(src_img, mask):
+    box = min_bounding_rect(mask)
+    result = adjust_image(box, src_img)
+    if len(result.shape) == 2:
+        result = torch.stack([result]*3, axis=-1)
+    return result
+
+
+def create_predictor(model_dir=None, model_lang='ch', is_onnx=False):
+    model_file_path = model_dir
+    if model_file_path is not None and not os.path.exists(model_file_path):
+        raise ValueError("not find model file path {}".format(model_file_path))
+
+    if is_onnx:
+        import onnxruntime as ort
+        sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])  # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
+        return sess
+    else:
+        if model_lang == 'ch':
+            n_class = 6625
+        elif model_lang == 'en':
+            n_class = 97
+        else:
+            raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
+        rec_config = edict(
+            in_channels=3,
+            backbone=edict(type='MobileNetV1Enhance', scale=0.5, last_conv_stride=[1, 2], last_pool_type='avg'),
+            neck=edict(type='SequenceEncoder', encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True),
+            head=edict(type='CTCHead', fc_decay=0.00001, out_channels=n_class, return_feats=True)
+        )
+
+        rec_model = RecModel(rec_config)
+        if model_file_path is not None:
+            rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
+            rec_model.eval()
+        return rec_model.eval()
+
+
+def _check_image_file(path):
+    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'}
+    return any([path.lower().endswith(e) for e in img_end])
+
+
+def get_image_file_list(img_file):
+    imgs_lists = []
+    if img_file is None or not os.path.exists(img_file):
+        raise Exception("not found any img file in {}".format(img_file))
+    if os.path.isfile(img_file) and _check_image_file(img_file):
+        imgs_lists.append(img_file)
+    elif os.path.isdir(img_file):
+        for single_file in os.listdir(img_file):
+            file_path = os.path.join(img_file, single_file)
+            if os.path.isfile(file_path) and _check_image_file(file_path):
+                imgs_lists.append(file_path)
+    if len(imgs_lists) == 0:
+        raise Exception("not found any img file in {}".format(img_file))
+    imgs_lists = sorted(imgs_lists)
+    return imgs_lists
+
+
+class TextRecognizer(object):
+    def __init__(self, args, predictor):
+        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
+        self.rec_batch_num = args.rec_batch_num
+        self.predictor = predictor
+        self.chars = self.get_char_dict(args.rec_char_dict_path)
+        self.char2id = {x: i for i, x in enumerate(self.chars)}
+        self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
+
+    # img: CHW
+    def resize_norm_img(self, img, max_wh_ratio):
+        imgC, imgH, imgW = self.rec_image_shape
+        assert imgC == img.shape[0]
+        imgW = int((imgH * max_wh_ratio))
+
+        h, w = img.shape[1:]
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = torch.nn.functional.interpolate(
+            img.unsqueeze(0),
+            size=(imgH, resized_w),
+            mode='bilinear',
+            align_corners=True,
+        )
+        resized_image /= 255.0
+        resized_image -= 0.5
+        resized_image /= 0.5
+        padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
+        padding_im[:, :, 0:resized_w] = resized_image[0]
+        return padding_im
+
+    # img_list: list of tensors with shape chw 0-255
+    def pred_imglist(self, img_list, show_debug=False, is_ori=False):
+        img_num = len(img_list)
+        assert img_num > 0
+        # Calculate the aspect ratio of all text bars
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[2] / float(img.shape[1]))
+        # Sorting can speed up the recognition process
+        indices = torch.from_numpy(np.argsort(np.array(width_list)))
+        batch_num = self.rec_batch_num
+        preds_all = [None] * img_num
+        preds_neck_all = [None] * img_num
+        for beg_img_no in range(0, img_num, batch_num):
+            end_img_no = min(img_num, beg_img_no + batch_num)
+            norm_img_batch = []
+
+            imgC, imgH, imgW = self.rec_image_shape[:3]
+            max_wh_ratio = imgW / imgH
+            for ino in range(beg_img_no, end_img_no):
+                h, w = img_list[indices[ino]].shape[1:]
+                if h > w * 1.2:
+                    img = img_list[indices[ino]]
+                    img = torch.transpose(img, 1, 2).flip(dims=[1])
+                    img_list[indices[ino]] = img
+                    h, w = img.shape[1:]
+                # wh_ratio = w * 1.0 / h
+                # max_wh_ratio = max(max_wh_ratio, wh_ratio)  # comment to not use different ratio
+            for ino in range(beg_img_no, end_img_no):
+                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
+                norm_img = norm_img.unsqueeze(0)
+                norm_img_batch.append(norm_img)
+            norm_img_batch = torch.cat(norm_img_batch, dim=0)
+            if show_debug:
+                for i in range(len(norm_img_batch)):
+                    _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
+                    _img = (_img + 0.5)*255
+                    _img = _img[:, :, ::-1]
+                    file_name = f'{indices[beg_img_no + i]}'
+                    file_name = file_name + '_ori' if is_ori else file_name
+                    cv2.imwrite(file_name + '.jpg', _img)
+            if self.is_onnx:
+                input_dict = {}
+                input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy()
+                outputs = self.predictor.run(None, input_dict)
+                preds = {}
+                preds['ctc'] = torch.from_numpy(outputs[0])
+                preds['ctc_neck'] = [torch.zeros(1)] * img_num
+            else:
+                preds = self.predictor(norm_img_batch)
+            for rno in range(preds['ctc'].shape[0]):
+                preds_all[indices[beg_img_no + rno]] = preds['ctc'][rno]
+                preds_neck_all[indices[beg_img_no + rno]] = preds['ctc_neck'][rno]
+
+        return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
+
+    def get_char_dict(self, character_dict_path):
+        character_str = []
+        with open(character_dict_path, "rb") as fin:
+            lines = fin.readlines()
+            for line in lines:
+                line = line.decode('utf-8').strip("\n").strip("\r\n")
+                character_str.append(line)
+        dict_character = list(character_str)
+        dict_character = ['sos'] + dict_character + [' ']  # eos is space
+        return dict_character
+
+    def get_text(self, order):
+        char_list = [self.chars[text_id] for text_id in order]
+        return ''.join(char_list)
+
+    def decode(self, mat):
+        text_index = mat.detach().cpu().numpy().argmax(axis=1)
+        ignored_tokens = [0]
+        selection = np.ones(len(text_index), dtype=bool)
+        selection[1:] = text_index[1:] != text_index[:-1]
+        for ignored_token in ignored_tokens:
+            selection &= text_index != ignored_token
+        return text_index[selection], np.where(selection)[0]
+
+    def get_ctcloss(self, preds, gt_text, weight):
+        if not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight).to(preds.device)
+        ctc_loss = torch.nn.CTCLoss(reduction='none')
+        log_probs = preds.log_softmax(dim=2).permute(1, 0, 2)  # NTC-->TNC
+        targets = []
+        target_lengths = []
+        for t in gt_text:
+            targets += [self.char2id.get(i, len(self.chars)-1) for i in t]
+            target_lengths += [len(t)]
+        targets = torch.tensor(targets).to(preds.device)
+        target_lengths = torch.tensor(target_lengths).to(preds.device)
+        input_lengths = torch.tensor([log_probs.shape[0]]*(log_probs.shape[1])).to(preds.device)
+        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
+        loss = loss / input_lengths * weight
+        return loss
+
+
+def main():
+    rec_model_dir = "./ocr_weights/ppv3_rec.pth"
+    predictor = create_predictor(rec_model_dir)
+    args = edict()
+    args.rec_image_shape = "3, 48, 320"
+    args.rec_char_dict_path = './ocr_weights/ppocr_keys_v1.txt'
+    args.rec_batch_num = 6
+    text_recognizer = TextRecognizer(args, predictor)
+    image_dir = './test_imgs_cn'
+    gt_text = ['韩国小馆']*14
+
+    image_file_list = get_image_file_list(image_dir)
+    valid_image_file_list = []
+    img_list = []
+
+    for image_file in image_file_list:
+        img = cv2.imread(image_file)
+        if img is None:
+            print("error in loading image:{}".format(image_file))
+            continue
+        valid_image_file_list.append(image_file)
+        img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
+    try:
+        tic = time.time()
+        times = []
+        for i in range(10):
+            preds, _ = text_recognizer.pred_imglist(img_list)  # get text
+            preds_all = preds.softmax(dim=2)
+            times += [(time.time()-tic)*1000.]
+            tic = time.time()
+        print(times)
+        print(np.mean(times[1:]) / len(preds_all))
+        weight = np.ones(len(gt_text))
+        loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
+        for i in range(len(valid_image_file_list)):
+            pred = preds_all[i]
+            order, idx = text_recognizer.decode(pred)
+            text = text_recognizer.get_text(order)
+            print(f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}')
+    except Exception as E:
+        print(traceback.format_exc(), E)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/dataset_util.py b/dataset_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1e44640e88d6bf99ffbaf81051a0677bf9c11ba
--- /dev/null
+++ b/dataset_util.py
@@ -0,0 +1,77 @@
+
+import json
+import pathlib
+
+__all__ = ['load', 'save', 'show_bbox_on_image']
+
+
+def load(file_path: str):
+    file_path = pathlib.Path(file_path)
+    func_dict = {'.txt': load_txt, '.json': load_json, '.list': load_txt}
+    assert file_path.suffix in func_dict
+    return func_dict[file_path.suffix](file_path)
+
+
+def load_txt(file_path: str):
+    with open(file_path, 'r', encoding='utf8') as f:
+        content = [x.strip().strip('\ufeff').strip('\xef\xbb\xbf') for x in f.readlines()]
+    return content
+
+
+def load_json(file_path: str):
+    with open(file_path, 'r', encoding='utf8') as f:
+        content = json.load(f)
+    return content
+
+
+def save(data, file_path):
+    file_path = pathlib.Path(file_path)
+    func_dict = {'.txt': save_txt, '.json': save_json}
+    assert file_path.suffix in func_dict
+    return func_dict[file_path.suffix](data, file_path)
+
+
+def save_txt(data, file_path):
+    if not isinstance(data, list):
+        data = [data]
+    with open(file_path, mode='w', encoding='utf8') as f:
+        f.write('\n'.join(data))
+
+
+def save_json(data, file_path):
+    with open(file_path, 'w', encoding='utf-8') as json_file:
+        json.dump(data, json_file, ensure_ascii=False, indent=4)
+
+
+def show_bbox_on_image(image, polygons=None, txt=None, color=None, font_path='./font/Arial_Unicode.ttf'):
+    from PIL import ImageDraw, ImageFont
+    image = image.convert('RGB')
+    draw = ImageDraw.Draw(image)
+    if len(txt) == 0:
+        txt = None
+    if color is None:
+        color = (255, 0, 0)
+    if txt is not None:
+        font = ImageFont.truetype(font_path, 20)
+    for i, box in enumerate(polygons):
+        box = box[0]
+        if txt is not None:
+            draw.text((int(box[0][0]) + 20, int(box[0][1]) - 20), str(txt[i]), fill='red', font=font)
+        for j in range(len(box) - 1):
+            draw.line((box[j][0], box[j][1], box[j + 1][0], box[j + 1][1]), fill=color, width=2)
+        draw.line((box[-1][0], box[-1][1], box[0][0], box[0][1]), fill=color, width=2)
+    return image
+
+
+def show_glyphs(glyphs, name):
+    import numpy as np
+    import cv2
+    size = 64
+    gap = 5
+    n_char = 20
+    canvas = np.ones((size, size*n_char + gap*(n_char-1), 1))*0.5
+    x = 0
+    for i in range(glyphs.shape[-1]):
+        canvas[:, x:x + size, :] = glyphs[..., i:i+1]
+        x += size+gap
+    cv2.imwrite(name, canvas*255)
diff --git a/example_images/banner.png b/example_images/banner.png
new file mode 100644
index 0000000000000000000000000000000000000000..4dc7f1ac0828b1b97ac49598c694bfe94d0bd0c6
Binary files /dev/null and b/example_images/banner.png differ
diff --git a/example_images/edit1.png b/example_images/edit1.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9eb8bd3fe5c124b73850d8f559d1bfad559f78e
Binary files /dev/null and b/example_images/edit1.png differ
diff --git a/example_images/edit10.png b/example_images/edit10.png
new file mode 100644
index 0000000000000000000000000000000000000000..2b953032224d6b466c4225c14b91e48cc8cb92dc
Binary files /dev/null and b/example_images/edit10.png differ
diff --git a/example_images/edit11.png b/example_images/edit11.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0dc5d3c961e80eb865c82456bc21070198aef0b
Binary files /dev/null and b/example_images/edit11.png differ
diff --git a/example_images/edit12.png b/example_images/edit12.png
new file mode 100644
index 0000000000000000000000000000000000000000..2e2b417c9cad122e7f12f39655e1e4c5443849f2
Binary files /dev/null and b/example_images/edit12.png differ
diff --git a/example_images/edit13.png b/example_images/edit13.png
new file mode 100644
index 0000000000000000000000000000000000000000..77840c4cf7801c7b231ad743bece6ff7f9eaeb8b
Binary files /dev/null and b/example_images/edit13.png differ
diff --git a/example_images/edit14.png b/example_images/edit14.png
new file mode 100644
index 0000000000000000000000000000000000000000..c823c41380a9c28330ffe555bdab48111749e387
Binary files /dev/null and b/example_images/edit14.png differ
diff --git a/example_images/edit2.png b/example_images/edit2.png
new file mode 100644
index 0000000000000000000000000000000000000000..19bc46f6db4f98c650db038009717702872c2cb3
Binary files /dev/null and b/example_images/edit2.png differ
diff --git a/example_images/edit3.png b/example_images/edit3.png
new file mode 100644
index 0000000000000000000000000000000000000000..343c0dbfff29da58ceab148f907bf30e5fe49e63
Binary files /dev/null and b/example_images/edit3.png differ
diff --git a/example_images/edit4.png b/example_images/edit4.png
new file mode 100644
index 0000000000000000000000000000000000000000..f6cff25d2383f08eb548927d74529a64bcfc221c
Binary files /dev/null and b/example_images/edit4.png differ
diff --git a/example_images/edit5.png b/example_images/edit5.png
new file mode 100644
index 0000000000000000000000000000000000000000..8d7b41d87a04b391d69d8dbcdc66f8554c46f551
Binary files /dev/null and b/example_images/edit5.png differ
diff --git a/example_images/edit6.png b/example_images/edit6.png
new file mode 100644
index 0000000000000000000000000000000000000000..6eedce647acf5fdfcf1bff80d4c91d95a003fded
Binary files /dev/null and b/example_images/edit6.png differ
diff --git a/example_images/edit7.png b/example_images/edit7.png
new file mode 100644
index 0000000000000000000000000000000000000000..5390d6fd193f47937ba58df0b9bb41344e4df321
Binary files /dev/null and b/example_images/edit7.png differ
diff --git a/example_images/edit8.png b/example_images/edit8.png
new file mode 100644
index 0000000000000000000000000000000000000000..036c1d27f326fac76beb64044591a4143b845e22
Binary files /dev/null and b/example_images/edit8.png differ
diff --git a/example_images/edit9.png b/example_images/edit9.png
new file mode 100644
index 0000000000000000000000000000000000000000..7010574e488c3dea76d1d4e82fbee1d5fbbafd06
Binary files /dev/null and b/example_images/edit9.png differ
diff --git a/example_images/gen1.png b/example_images/gen1.png
new file mode 100644
index 0000000000000000000000000000000000000000..e0ca934937e6785f5ac19667a6cee5b82ace3dea
Binary files /dev/null and b/example_images/gen1.png differ
diff --git a/example_images/gen10.png b/example_images/gen10.png
new file mode 100644
index 0000000000000000000000000000000000000000..13a440cc173de1289f195c98037e20bce1ded951
Binary files /dev/null and b/example_images/gen10.png differ
diff --git a/example_images/gen11.png b/example_images/gen11.png
new file mode 100644
index 0000000000000000000000000000000000000000..fbb56162d4b2862d50c8401883ddc7965f9d0b9c
Binary files /dev/null and b/example_images/gen11.png differ
diff --git a/example_images/gen12.png b/example_images/gen12.png
new file mode 100644
index 0000000000000000000000000000000000000000..b224a4638c16f6316f40bcbe7f2e59c3167ed5c9
Binary files /dev/null and b/example_images/gen12.png differ
diff --git a/example_images/gen13.png b/example_images/gen13.png
new file mode 100644
index 0000000000000000000000000000000000000000..b99f87a2fbc368097d66440a52cda01792778396
Binary files /dev/null and b/example_images/gen13.png differ
diff --git a/example_images/gen14.png b/example_images/gen14.png
new file mode 100644
index 0000000000000000000000000000000000000000..50695245ce9f9e55646b018f45ee29a19cc99b20
Binary files /dev/null and b/example_images/gen14.png differ
diff --git a/example_images/gen15.png b/example_images/gen15.png
new file mode 100644
index 0000000000000000000000000000000000000000..cf096e1c30ce81b783b0827f2706611b95be34d3
Binary files /dev/null and b/example_images/gen15.png differ
diff --git a/example_images/gen16.png b/example_images/gen16.png
new file mode 100644
index 0000000000000000000000000000000000000000..77f06a6eab3cef917160164b87d229fe49ec5073
Binary files /dev/null and b/example_images/gen16.png differ
diff --git a/example_images/gen2.png b/example_images/gen2.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc3285d63614e962c8d5702b8302ed25b8aed4ee
Binary files /dev/null and b/example_images/gen2.png differ
diff --git a/example_images/gen3.png b/example_images/gen3.png
new file mode 100644
index 0000000000000000000000000000000000000000..fbdf379f40d4687221b20b98e8ba612877fd00a9
Binary files /dev/null and b/example_images/gen3.png differ
diff --git a/example_images/gen4.png b/example_images/gen4.png
new file mode 100644
index 0000000000000000000000000000000000000000..ffb94dfdc2c3b307d559a2c933e1f89a70f80b02
Binary files /dev/null and b/example_images/gen4.png differ
diff --git a/example_images/gen5.png b/example_images/gen5.png
new file mode 100644
index 0000000000000000000000000000000000000000..335f9bd2fe0a493d3e343d50d5765f4033a968de
Binary files /dev/null and b/example_images/gen5.png differ
diff --git a/example_images/gen6.png b/example_images/gen6.png
new file mode 100644
index 0000000000000000000000000000000000000000..b6f03092edf35715294978973ea4d9b9b14cdf8e
Binary files /dev/null and b/example_images/gen6.png differ
diff --git a/example_images/gen7.png b/example_images/gen7.png
new file mode 100644
index 0000000000000000000000000000000000000000..385d1dbe4243e66be718359f2616b99037a62376
Binary files /dev/null and b/example_images/gen7.png differ
diff --git a/example_images/gen8.png b/example_images/gen8.png
new file mode 100644
index 0000000000000000000000000000000000000000..4b7dcf6e4f586f860f45b25bdc4d1fb7ad1be465
Binary files /dev/null and b/example_images/gen8.png differ
diff --git a/example_images/gen9.png b/example_images/gen9.png
new file mode 100644
index 0000000000000000000000000000000000000000..a6c68dadc3d686d359cdb69a35a81aeecd10ccc1
Binary files /dev/null and b/example_images/gen9.png differ
diff --git a/example_images/ref1.jpg b/example_images/ref1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e051af3cd9bcb77805d3e95ab9f9e88f65952773
Binary files /dev/null and b/example_images/ref1.jpg differ
diff --git a/example_images/ref10.jpg b/example_images/ref10.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5515ac2cc041a257cf7ae2944182eadecc162ec7
Binary files /dev/null and b/example_images/ref10.jpg differ
diff --git a/example_images/ref11.jpg b/example_images/ref11.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f44a8fe449fd532b84d519a09734f5f76c9709e9
Binary files /dev/null and b/example_images/ref11.jpg differ
diff --git a/example_images/ref12.png b/example_images/ref12.png
new file mode 100644
index 0000000000000000000000000000000000000000..2182c327c8582eba88dea839211769105c154f19
Binary files /dev/null and b/example_images/ref12.png differ
diff --git a/example_images/ref13.jpg b/example_images/ref13.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..227fc9b2b8a21d0671cb66468c7ef9a1113736c6
Binary files /dev/null and b/example_images/ref13.jpg differ
diff --git a/example_images/ref14.png b/example_images/ref14.png
new file mode 100644
index 0000000000000000000000000000000000000000..78a2b6af97727e743e3873b5076f65827d1bb0ff
Binary files /dev/null and b/example_images/ref14.png differ
diff --git a/example_images/ref2.jpg b/example_images/ref2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f4878c58a7b09c540a8deb1a7685b8c017cc1dbf
Binary files /dev/null and b/example_images/ref2.jpg differ
diff --git a/example_images/ref3.jpg b/example_images/ref3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..67aac42212aac484d6fed6b032dc3aca43189620
Binary files /dev/null and b/example_images/ref3.jpg differ
diff --git a/example_images/ref4.jpg b/example_images/ref4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..36897cc34399c3c3e849e70b161090a34920c309
Binary files /dev/null and b/example_images/ref4.jpg differ
diff --git a/example_images/ref5.jpg b/example_images/ref5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..33cdfe4c95a124b06f829d1f774e847fecefd6ad
Binary files /dev/null and b/example_images/ref5.jpg differ
diff --git a/example_images/ref6.jpg b/example_images/ref6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d375a3f1bb966114237ba2eb36ef6f8b506835df
Binary files /dev/null and b/example_images/ref6.jpg differ
diff --git a/example_images/ref7.jpg b/example_images/ref7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8eb7ef4901eac5b727d775fef83a6cace00eebab
Binary files /dev/null and b/example_images/ref7.jpg differ
diff --git a/example_images/ref8.jpg b/example_images/ref8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c36b3c5b2e340b665e6547eac85044a6af72392a
Binary files /dev/null and b/example_images/ref8.jpg differ
diff --git a/example_images/ref9.jpg b/example_images/ref9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8c23e56cf899f4874eb7d0e4b9adf056aa0ea5ef
Binary files /dev/null and b/example_images/ref9.jpg differ
diff --git a/font/Arial_Unicode.ttf b/font/Arial_Unicode.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..2bc3e08a7fbd6d2c40fad21b056a1ed865c1a2b9
--- /dev/null
+++ b/font/Arial_Unicode.ttf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:876af2cd4854644e7f3e7feb2f688997fdb3343c6df6693611209c9dfb47ccec
+size 23278008
diff --git a/javascript/bboxHint.js b/javascript/bboxHint.js
new file mode 100644
index 0000000000000000000000000000000000000000..a17a233c8472cbcd960f2b1f949cf9a5a8021d05
--- /dev/null
+++ b/javascript/bboxHint.js
@@ -0,0 +1,554 @@
+/*
+Part of the implementation is borrowed and modified from multidiffusion-upscaler-for-automatic1111,
+publicly available at https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111
+*/
+
+const BBOX_MAX_NUM = 16;
+const BBOX_WARNING_SIZE = 1280;
+const DEFAULT_X = 0.4;
+const DEFAULT_Y = 0.4;
+const DEFAULT_H = 0.2;
+const DEFAULT_W = 0.2;
+
+// ref: https://html-color.codes/
+const COLOR_MAP = [
+    ['#ff0000', 'rgba(255, 0, 0, 0.3)'],          // red
+    ['#ff9900', 'rgba(255, 153, 0, 0.3)'],        // orange
+    ['#ffff00', 'rgba(255, 255, 0, 0.3)'],        // yellow
+    ['#33cc33', 'rgba(51, 204, 51, 0.3)'],        // green
+    ['#33cccc', 'rgba(51, 204, 204, 0.3)'],       // indigo
+    ['#0066ff', 'rgba(0, 102, 255, 0.3)'],        // blue
+    ['#6600ff', 'rgba(102, 0, 255, 0.3)'],        // purple
+    ['#cc00cc', 'rgba(204, 0, 204, 0.3)'],        // dark pink
+    ['#ff6666', 'rgba(255, 102, 102, 0.3)'],      // light red
+    ['#ffcc66', 'rgba(255, 204, 102, 0.3)'],      // light orange
+    ['#99cc00', 'rgba(153, 204, 0, 0.3)'],        // lime green
+    ['#00cc99', 'rgba(0, 204, 153, 0.3)'],        // teal
+    ['#0099cc', 'rgba(0, 153, 204, 0.3)'],        // steel blue
+    ['#9933cc', 'rgba(153, 51, 204, 0.3)'],       // lavender
+    ['#ff3399', 'rgba(255, 51, 153, 0.3)'],       // hot pink
+    ['#996633', 'rgba(153, 102, 51, 0.3)'],       // brown
+];
+
+const RESIZE_BORDER = 5;
+const MOVE_BORDER = 5;
+
+const t2i_bboxes = new Array(BBOX_MAX_NUM).fill(null);
+const i2i_bboxes = new Array(BBOX_MAX_NUM).fill(null);
+
+function gradioApp() {
+    const elems = document.getElementsByTagName('gradio-app')
+    const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
+    return !!gradioShadowRoot ? gradioShadowRoot : document;
+}
+
+// ↓↓↓ called from gradio ↓↓↓
+
+function onCreateT2IRefClick(overwrite) {
+    let width, height;
+    if (overwrite) {
+        const overwriteInputs = gradioApp().querySelectorAll('#MD-overwrite-width-t2i input, #MD-overwrite-height-t2i input');
+        width  = parseInt(overwriteInputs[0].value);
+        height = parseInt(overwriteInputs[2].value);
+    } else {
+        const sizeInputs = gradioApp().querySelectorAll('#txt2img_width input, #txt2img_height input');
+        width  = parseInt(sizeInputs[0].value);
+        height = parseInt(sizeInputs[2].value);
+    }
+
+    if (isNaN(width))  width  = 512;
+    if (isNaN(height)) height = 512;
+
+    // Concat it to string to bypass the gradio bug
+    // 向黑恶势力低头
+    return width.toString() + 'x' + height.toString();
+}
+
+function onCreateI2IRefClick() {
+    const canvas = gradioApp().querySelector('#img2img_image img');
+    return canvas.src;
+}
+
+function onBoxEnableClick(is_t2i, idx, enable) {
+    let canvas = null;
+    let bboxes = null;
+    let locator = null;
+    if (is_t2i) {
+        // locator = () => gradioApp().querySelector('#MD-bbox-ref-t2i');
+        locator = () => gradioApp().querySelector('#MD-bbox-rect-t2i');
+        bboxes = t2i_bboxes;
+        
+    } else {
+        locator = () => gradioApp().querySelector('#MD-bbox-ref-i2i');
+        bboxes = i2i_bboxes;
+    }
+    ref_div = locator();
+    canvas = ref_div.querySelector('img');
+    if (!canvas) { return false; }
+
+    if (enable) {
+        // Check if the bounding box already exists
+        if (!bboxes[idx]) {
+            // Initialize bounding box
+            const bbox = [DEFAULT_X, DEFAULT_Y, DEFAULT_W, DEFAULT_H];
+            const colorMap = COLOR_MAP[idx % COLOR_MAP.length];
+            const div = document.createElement('div');
+            div.id = 'MD-bbox-' + (is_t2i ? 't2i-' : 'i2i-') + idx;
+            div.style.left       = '0px';
+            div.style.top        = '0px';
+            div.style.width      = '0px';
+            div.style.height     = '0px';
+            div.style.position   = 'absolute';
+            div.style.border     = '2px solid ' + colorMap[0];
+            div.style.background = colorMap[1];
+            div.style.zIndex     = '900';
+            div.style.display    = 'none';
+            // A text tip to warn the user if bbox is too large
+            const tip = document.createElement('span');
+            tip.id = 'MD-tip-' + (is_t2i ? 't2i-' : 'i2i-') + idx;
+            tip.style.left       = '50%';
+            tip.style.top        = '50%';
+            tip.style.position   = 'absolute';
+            tip.style.transform  = 'translate(-50%, -50%)';
+            tip.style.fontSize   = '12px';
+            tip.style.fontWeight = 'bold';
+            tip.style.textAlign  = 'center';
+            tip.style.color      = colorMap[0];
+            tip.style.zIndex     = '901';
+            tip.style.display    = 'none';
+            tip.innerHTML        = 'Warning: Region very large!<br>Take care of VRAM usage!';
+            div.appendChild(tip);
+            div.addEventListener('mousedown', function (e) {
+                if (e.button === 0) { onBoxMouseDown(e, is_t2i, idx); }
+            });
+            div.addEventListener('mousemove', function (e) {
+                updateCursorStyle(e, is_t2i, idx);
+            });
+
+            const shower = function() { // insert to DOM if necessary
+                if (!gradioApp().querySelector('#' + div.id)) {
+                    locator().appendChild(div);
+                }
+            }
+            bboxes[idx] = [div, bbox, shower];
+        }
+
+        // Show the bounding box
+        displayBox(canvas, is_t2i, bboxes[idx]);
+        return true;
+    } else {
+        if (!bboxes[idx]) { return false; }
+        const [div, bbox, shower] = bboxes[idx];
+        div.style.display = 'none';
+    }
+    return false;
+}
+
+function onBoxChange(is_t2i, idx, what, v) {
+    // This function handles all the changes of the bounding box
+    // Including the rendering and python slider update
+    let bboxes = null;
+    let canvas = null;
+    if (is_t2i) {
+        bboxes = t2i_bboxes;
+        canvas = gradioApp().querySelector('#MD-bbox-rect-t2i img');
+    } else {
+        bboxes = i2i_bboxes;
+        canvas = gradioApp().querySelector('#MD-bbox-ref-i2i img');
+    }
+    if (!bboxes[idx] || !canvas) {
+        switch (what) {
+            case 'x': return DEFAULT_X;
+            case 'y': return DEFAULT_Y;
+            case 'w': return DEFAULT_W;
+            case 'h': return DEFAULT_H;
+        }
+    }
+    const [div, bbox, shower] = bboxes[idx];
+    if (div.style.display === 'none') { return v; }
+
+    // parse trigger
+    switch (what) {
+        case 'x': bbox[0] = v; break;
+        case 'y': bbox[1] = v; break;
+        case 'w': bbox[2] = v; break;
+        case 'h': bbox[3] = v; break;
+    }
+    displayBox(canvas, is_t2i, bboxes[idx]);
+    return v;
+}
+
+// ↓↓↓ called from js ↓↓↓
+
+function getSeedInfo(is_t2i, id, current_seed) {
+    const info_id = is_t2i ? '#html_info_txt2img' : '#html_info_img2img';
+    const info_div = gradioApp().querySelector(info_id);
+    try{
+        current_seed = parseInt(current_seed);
+    } catch(e) {
+        current_seed = -1;
+    }
+    if (!info_div) return current_seed;
+    let info = info_div.innerHTML;
+    if (!info) return current_seed;
+    // remove all html tags
+    info = info.replace(/<[^>]*>/g, '');
+    // Find a json string 'region control:' in the info
+    // get its index
+    idx = info.indexOf('Region control');
+    if (idx == -1) return current_seed;
+    // get the json string (detect the bracket)
+    // find the first '{'
+    let start_idx = info.indexOf('{', idx);
+    let bracket = 1;
+    let end_idx = start_idx + 1;
+    while (bracket > 0 && end_idx < info.length) {
+        if (info[end_idx] == '{') bracket++;
+        if (info[end_idx] == '}') bracket--;
+        end_idx++;
+    }
+    if (bracket > 0) {
+        return current_seed;
+    }
+    // get the json string
+    let json_str = info.substring(start_idx, end_idx);
+    // replace the single quote to double quote
+    json_str = json_str.replace(/'/g, '"');
+    // replace python True to javascript true, False to false
+    json_str = json_str.replace(/True/g, 'true');
+    // parse the json string
+    let json = JSON.parse(json_str);
+    // get the seed if the region id is in the json
+    const region_id = 'Region ' + id.toString();
+    if (!(region_id in json)) return current_seed;
+    const region = json[region_id];
+    if (!('seed' in region)) return current_seed;
+    let seed = region['seed'];
+    try{
+        seed = parseInt(seed);
+    } catch(e) {
+        return current_seed;
+    }
+    return seed;
+}
+
+function displayBox(canvas, is_t2i, bbox_info) {
+    // check null input
+    const [div, bbox, shower] = bbox_info;
+    const [x, y, w, h] = bbox;
+    if (!canvas || !div || x == null || y == null || w == null || h == null) { return; }
+
+    // client: canvas widget display size
+    // natural: content image real size
+    let vpScale = Math.min(canvas.clientWidth / canvas.naturalWidth, canvas.clientHeight / canvas.naturalHeight);
+    let canvasCenterX = canvas.clientWidth  / 2;
+    let canvasCenterY = canvas.clientHeight / 2;
+    let scaledX = canvas.naturalWidth  * vpScale;
+    let scaledY = canvas.naturalHeight * vpScale;
+    let viewRectLeft  = canvasCenterX - scaledX / 2;
+    let viewRectRight = canvasCenterX + scaledX / 2;
+    let viewRectTop   = canvasCenterY - scaledY / 2;
+    let viewRectDown  = canvasCenterY + scaledY / 2;
+
+    let xDiv = viewRectLeft + scaledX * x;
+    let yDiv = viewRectTop  + scaledY * y;
+    let wDiv = Math.min(scaledX * w, viewRectRight - xDiv);
+    let hDiv = Math.min(scaledY * h, viewRectDown - yDiv);
+
+    // Calculate warning bbox size
+    let upscalerFactor = 1.0;
+    if (!is_t2i) {
+        const upscalerInput = parseFloat(gradioApp().querySelector('#MD-i2i-upscaler-factor input').value);
+        if (!isNaN(upscalerInput)) upscalerFactor = upscalerInput;
+    }
+    let maxSize = BBOX_WARNING_SIZE / upscalerFactor * vpScale;
+    let maxW = maxSize / scaledX;
+    let maxH = maxSize / scaledY;
+    if (w > maxW || h > maxH) {
+        div.querySelector('span').style.display = 'block';
+    } else {
+        div.querySelector('span').style.display = 'none';
+    }
+
+    // update <div> when not equal
+    div.style.left    = xDiv + 'px';
+    div.style.top     = yDiv + 'px';
+    div.style.width   = wDiv + 'px';
+    div.style.height  = hDiv + 'px';
+    div.style.display = 'block';
+
+    // insert it to DOM if not appear
+    shower();
+}
+
+function onBoxMouseDown(e, is_t2i, idx) {
+    let bboxes = null;
+    let canvas = null;
+    if (is_t2i) {
+        bboxes = t2i_bboxes;
+        canvas = gradioApp().querySelector('#MD-bbox-rect-t2i img');
+    } else {
+        bboxes = i2i_bboxes;
+        canvas = gradioApp().querySelector('#MD-bbox-ref-i2i img');
+    }
+    // Get the bounding box
+    if (!canvas || !bboxes[idx]) { return; }
+    const [div, bbox, shower] = bboxes[idx];
+
+    // Check if the click is inside the bounding box
+    const boxRect = div.getBoundingClientRect();
+    let mouseX = e.clientX;
+    let mouseY = e.clientY;
+
+    const resizeLeft   = mouseX >= boxRect.left && mouseX <= boxRect.left + RESIZE_BORDER;
+    const resizeRight  = mouseX >= boxRect.right - RESIZE_BORDER && mouseX <= boxRect.right;
+    const resizeTop    = mouseY >= boxRect.top && mouseY <= boxRect.top + RESIZE_BORDER;
+    const resizeBottom = mouseY >= boxRect.bottom - RESIZE_BORDER && mouseY <= boxRect.bottom;
+
+    const moveHorizontal = mouseX >= boxRect.left + MOVE_BORDER && mouseX <= boxRect.right  - MOVE_BORDER;
+    const moveVertical   = mouseY >= boxRect.top  + MOVE_BORDER && mouseY <= boxRect.bottom - MOVE_BORDER;
+
+    if (!resizeLeft && !resizeRight && !resizeTop && !resizeBottom && !moveHorizontal && !moveVertical) { return; }
+
+    const horizontalPivot = resizeLeft ? bbox[0] + bbox[2] : bbox[0];
+    const verticalPivot   = resizeTop  ? bbox[1] + bbox[3] : bbox[1];
+
+    // Canvas can be regarded as invariant during the drag operation
+    // Calculate in advance to reduce overhead
+
+    // Calculate viewport scale based on the current canvas size and the natural image size
+    let vpScale = Math.min(canvas.clientWidth / canvas.naturalWidth, canvas.clientHeight / canvas.naturalHeight);
+    let vpOffset = canvas.getBoundingClientRect();
+
+    // Calculate scaled dimensions of the canvas
+    let scaledX = canvas.naturalWidth * vpScale;
+    let scaledY = canvas.naturalHeight * vpScale;
+
+    // Calculate the canvas center and view rectangle coordinates
+    let canvasCenterX = (vpOffset.left + window.scrollX) + canvas.clientWidth  / 2;
+    let canvasCenterY = (vpOffset.top  + window.scrollY) + canvas.clientHeight / 2;
+    let viewRectLeft  = canvasCenterX - scaledX / 2 - window.scrollX;
+    let viewRectRight = canvasCenterX + scaledX / 2 - window.scrollX;
+    let viewRectTop   = canvasCenterY - scaledY / 2 - window.scrollY;
+    let viewRectDown  = canvasCenterY + scaledY / 2 - window.scrollY;
+
+    mouseX = Math.min(Math.max(mouseX, viewRectLeft), viewRectRight);
+    mouseY = Math.min(Math.max(mouseY, viewRectTop),  viewRectDown);
+    
+    //const accordion = gradioApp().querySelector(`#MD-accordion-${is_t2i ? 't2i' : 'i2i'}-${idx}`);
+    const accordion = gradioApp().querySelector('#MD-tab-t2i');
+
+    // Move or resize the bounding box on mousemove
+    function onMouseMove(e) {
+        // Prevent selecting anything irrelevant
+        e.preventDefault();
+
+        // Get the new mouse position
+        let newMouseX = e.clientX;
+        let newMouseY = e.clientY;
+
+        // clamp the mouse position to the view rectangle
+        newMouseX = Math.min(Math.max(newMouseX, viewRectLeft), viewRectRight);
+        newMouseY = Math.min(Math.max(newMouseY, viewRectTop),  viewRectDown);
+
+        // Calculate the mouse movement delta
+        const dx = (newMouseX - mouseX) / scaledX;
+        const dy = (newMouseY - mouseY) / scaledY;
+
+        // Update the mouse position
+        mouseX = newMouseX;
+        mouseY = newMouseY;
+
+        // if no move just return
+        if (dx === 0 && dy === 0) { return; }
+
+        // Update the mouse position
+        let [x, y, w, h] = bbox;
+        if (moveHorizontal && moveVertical) {
+            // If moving the bounding box
+            x = Math.min(Math.max(x + dx, 0), 1 - w);
+            y = Math.min(Math.max(y + dy, 0), 1 - h);
+        } else {
+            // If resizing the bounding box
+            if (resizeLeft || resizeRight) {
+                if (x < horizontalPivot) {
+                    if (dx <= w) {
+                        // If still within the left side of the pivot
+                        x = x + dx;
+                        w = w - dx;
+                    } else {
+                        // If crossing the pivot
+                        w = dx - w;
+                        x = horizontalPivot;
+                    }
+                } else {
+                    if (w + dx < 0) {
+                        // If still within the right side of the pivot
+                        x = horizontalPivot + w + dx;
+                        w = - dx - w;
+                    } else {
+                        // If crossing the pivot
+                        x = horizontalPivot;
+                        w = w + dx;
+                    }
+                }
+
+                // Clamp the bounding box to the image
+                if (x < 0) {
+                    w = w + x;
+                    x = 0;
+                } else if (x + w > 1) {
+                    w = 1 - x;
+                }
+            }
+            // Same as above, but for the vertical axis
+            if (resizeTop || resizeBottom) {
+                if (y < verticalPivot) {
+                    if (dy <= h) {
+                        y = y + dy;
+                        h = h - dy;
+                    } else {
+                        h = dy - h;
+                        y = verticalPivot;
+                    }
+                } else {
+                    if (h + dy < 0) {
+                        y = verticalPivot + h + dy;
+                        h = - dy - h;
+                    } else {
+                        y = verticalPivot;
+                        h = h + dy;
+                    }
+                }
+                if (y < 0) {
+                    h = h + y;
+                    y = 0;
+                } else if (y + h > 1) {
+                    h = 1 - y;
+                }
+            }
+        }
+        const [div, old_bbox, _] = bboxes[idx];
+
+        // If all the values are the same, just return
+        if (old_bbox[0] === x && old_bbox[1] === y && old_bbox[2] === w && old_bbox[3] === h) { return; }
+        // else update the bbox
+        const event = new Event('input');
+        const coords = [x, y, w, h];
+        // <del>The querySelector is not very efficient, so we query it once and reuse it</del>
+        // caching will result gradio bugs that stucks bbox and cannot move & drag
+        const sliderIds = ['x', 'y', 'w', 'h'];
+        // We try to select the input sliders
+        const sliderSelectors = sliderIds.map(id => `#MD-${is_t2i ? 't2i' : 'i2i'}-${idx}-${id} input`).join(', ');
+        let sliderInputs = accordion.querySelectorAll(sliderSelectors);
+        // alert(sliderInputs.length)
+        if (sliderInputs.length == 0) { 
+            // If we failed, the accordion is probably closed and sliders are removed in the dom, so we open it
+            accordion.querySelector('.label-wrap').click();
+            // and try again
+            sliderInputs = accordion.querySelectorAll(sliderSelectors);
+            // If we still failed, we just return
+            if (sliderInputs.length == 0) { return; }
+        }
+        for (let i = 0; i < 4; i++) {
+            if (old_bbox[i] !== coords[i]) {
+                sliderInputs[2*i].value = coords[i];
+                sliderInputs[2*i].dispatchEvent(event);
+            }
+        }
+    }
+
+    // Remove the mousemove and mouseup event listeners
+    function onMouseUp() {
+        document.removeEventListener('mousemove', onMouseMove);
+        document.removeEventListener('mouseup',   onMouseUp);
+    }
+
+    // Add the event listeners
+    document.addEventListener('mousemove', onMouseMove);
+    document.addEventListener('mouseup',   onMouseUp);
+}
+
+function updateCursorStyle(e, is_t2i, idx) {
+    // This function changes the cursor style when hovering over the bounding box
+    const bboxes = is_t2i ? t2i_bboxes : i2i_bboxes;
+    if (!bboxes[idx]) return;
+
+    const div = bboxes[idx][0];
+    const boxRect = div.getBoundingClientRect();
+    const mouseX = e.clientX;
+    const mouseY = e.clientY;
+
+    const resizeLeft   = mouseX >= boxRect.left && mouseX <= boxRect.left + RESIZE_BORDER;
+    const resizeRight  = mouseX >= boxRect.right - RESIZE_BORDER && mouseX <= boxRect.right;
+    const resizeTop    = mouseY >= boxRect.top && mouseY <= boxRect.top + RESIZE_BORDER;
+    const resizeBottom = mouseY >= boxRect.bottom - RESIZE_BORDER && mouseY <= boxRect.bottom;
+
+    if ((resizeLeft && resizeTop) || (resizeRight && resizeBottom)) {
+        div.style.cursor = 'nwse-resize';
+    } else if ((resizeLeft && resizeBottom) || (resizeRight && resizeTop)) {
+        div.style.cursor = 'nesw-resize';
+    } else if (resizeLeft || resizeRight) {
+        div.style.cursor = 'ew-resize';
+    } else if (resizeTop || resizeBottom) {
+        div.style.cursor = 'ns-resize';
+    } else {
+        div.style.cursor = 'move';
+    }
+}
+
+// ↓↓↓ auto called event listeners ↓↓↓
+
+function updateBoxes(is_t2i) {
+    // This function redraw all bounding boxes
+    let bboxes = null;
+    let canvas = null;
+    if (is_t2i) {
+        bboxes = t2i_bboxes;
+        canvas = gradioApp().querySelector('#MD-bbox-rect-t2i img');
+    } else {
+        bboxes = i2i_bboxes;
+        canvas = gradioApp().querySelector('#MD-bbox-ref-i2i img');
+    }
+    if (!canvas) return;
+
+    for (let idx = 0; idx < bboxes.length; idx++) {
+        if (!bboxes[idx]) continue;
+        const [div, bbox, shower] = bboxes[idx];
+        if (div.style.display === 'none') { return; }
+
+        displayBox(canvas, is_t2i, bboxes[idx]);
+    }
+}
+
+window.addEventListener('resize', _ => {
+    updateBoxes(true);
+    updateBoxes(false);
+});
+
+// ======== Gradio Bug Fix ========
+// For Gradio versions > 3.16.0 and < 3.29.0, the accordion DOM will be deleted when it is closed.
+// We need to judge the versions and listen to the accordion open event, rerender the bbox at that time.
+// This silly bug fix is only for compatibility, we recommend to update the gradio version to 3.29.0 or higher.
+try {
+    const GRADIO_VERSIONS = window.gradio_config["version"].split(".");
+    const gradio_major_version = parseInt(GRADIO_VERSIONS[0]);
+    const gradio_minor_version = parseInt(GRADIO_VERSIONS[1]);
+    if (gradio_major_version == 3 && gradio_minor_version > 16 && gradio_minor_version < 29) {
+        let listener = e => {
+            if (!e) { return; }
+            if (!e.target) { return; }
+            if (!e.target.classList) { return; }
+            if (!e.target.classList.contains('label-wrap')) { return; }
+            for (let tab of ['t2i', 'i2i']) {
+                const div = gradioApp().querySelector('#MD-bbox-control-' + tab +' div.label-wrap');
+                if (!div) { continue; }
+                updateBoxes(tab === 't2i');
+            }
+        };
+        window.addEventListener('DOMNodeInserted', listener);
+    }
+} catch (ignored) {
+    // If the above code failed, the gradio version shouldn't be in the range of 3.16.0 to 3.29.0, so we just return.
+}
+// ======== Gradio Bug Fix ========
diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/data/util.py b/ldm/data/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b60ceb2349e3bd7900ff325740e2022d2903b1c
--- /dev/null
+++ b/ldm/data/util.py
@@ -0,0 +1,24 @@
+import torch
+
+from ldm.modules.midas.api import load_midas_transform
+
+
+class AddMiDaS(object):
+    def __init__(self, model_type):
+        super().__init__()
+        self.transform = load_midas_transform(model_type)
+
+    def pt2np(self, x):
+        x = ((x + 1.0) * .5).detach().cpu().numpy()
+        return x
+
+    def np2pt(self, x):
+        x = torch.from_numpy(x) * 2 - 1.
+        return x
+
+    def __call__(self, sample):
+        # sample['jpg'] is tensor hwc in [-1, 1] at this point
+        x = self.pt2np(sample['jpg'])
+        x = self.transform({"image": x})["image"]
+        sample['midas_in'] = x
+        return sample
\ No newline at end of file
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd
--- /dev/null
+++ b/ldm/models/autoencoder.py
@@ -0,0 +1,219 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(pl.LightningModule):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 ema_decay=None,
+                 learn_logvar=False
+                 ):
+        super().__init__()
+        self.learn_logvar = learn_logvar
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels)==int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+
+        self.use_ema = ema_decay is not None
+        if self.use_ema:
+            self.ema_decay = ema_decay
+            assert 0. < ema_decay < 1.
+            self.model_ema = LitEma(self, decay=ema_decay)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.parameters())
+            self.model_ema.copy_to(self)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self)
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+
+        if optimizer_idx == 0:
+            # train encoder+decoder+logvar
+            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            return aeloss
+
+        if optimizer_idx == 1:
+            # train the discriminator
+            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                                last_layer=self.get_last_layer(), split="train")
+
+            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, postfix=""):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+                                        last_layer=self.get_last_layer(), split="val"+postfix)
+
+        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+                                            last_layer=self.get_last_layer(), split="val"+postfix)
+
+        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+        if self.learn_logvar:
+            print(f"{self.__class__.__name__}: Learning logvar")
+            ae_params_list.append(self.loss.logvar)
+        opt_ae = torch.optim.Adam(ae_params_list,
+                                  lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr, betas=(0.5, 0.9))
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+            log["reconstructions"] = xrec
+            if log_ema or self.use_ema:
+                with self.ema_scope():
+                    xrec_ema, posterior_ema = self(x)
+                    if x.shape[1] > 3:
+                        # colorize with random projection
+                        assert xrec_ema.shape[1] > 3
+                        xrec_ema = self.to_rgb(xrec_ema)
+                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+                    log["reconstructions_ema"] = xrec_ema
+        log["inputs"] = x
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        self.vq_interface = vq_interface
+        super().__init__()
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
+
diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..676813b7ecaa3f82338d4e99d6549251872d31a1
--- /dev/null
+++ b/ldm/models/diffusion/ddim.py
@@ -0,0 +1,354 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                # cbs = len(ctmp[0])
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img], "index": [10000]}
+        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0 = outs
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+                intermediates['index'].append(index)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [torch.cat([
+                            unconditional_conditioning[k][i],
+                            c[k][i]]) for i in range(len(c[k]))]
+                    elif isinstance(c[k], dict):
+                        c_in[k] = dict()
+                        for key in c[k]:
+                            if isinstance(c[k][key], list):
+                                if not isinstance(c[k][key][0], torch.Tensor):
+                                    continue
+                                c_in[k][key] = [torch.cat([
+                                    unconditional_conditioning[k][key][i],
+                                    c[k][key][i]]) for i in range(len(c[k][key]))]
+                            else:
+                                c_in[k][key] = torch.cat([
+                                    unconditional_conditioning[k][key],
+                                    c[k][key]])
+
+                    else:
+                        c_in[k] = torch.cat([
+                                unconditional_conditioning[k],
+                                c[k]])
+            elif isinstance(c, list):
+                c_in = list()
+                assert isinstance(unconditional_conditioning, list)
+                for i in range(len(c)):
+                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", 'not implemented'
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback: callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False, callback=None):
+
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+            if callback: callback(i)
+        return x_dec
\ No newline at end of file
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..43a3d590a0e5621afde6ed9e3899094fe2322555
--- /dev/null
+++ b/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1799 @@
+"""
+Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+from cldm.recognizer import crop_image
+import cv2
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+                         'crossattn': 'c_crossattn',
+                         'adm': 'y'}
+
+PRINT_DEBUG = False
+
+
+def print_grad(grad):
+    # print('Gradient:', grad)
+    # print(grad.shape)
+    a = grad.max()
+    b = grad.min()
+    # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}')
+    s = 255./(a-b)
+    c = 255*(-b/(a-b))
+    grad = grad * s + c
+    # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}')
+    img = grad[0].permute(1, 2, 0).detach().cpu().numpy()
+    if img.shape[0] == 512:
+        cv2.imwrite('grad-img.jpg', img)
+    elif img.shape[0] == 64:
+        cv2.imwrite('grad-latent.jpg', img)
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+    return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+    # classic DDPM with Gaussian diffusion, in image space
+    def __init__(self,
+                 unet_config,
+                 timesteps=1000,
+                 beta_schedule="linear",
+                 loss_type="l2",
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 load_only_unet=False,
+                 monitor="val/loss",
+                 use_ema=True,
+                 first_stage_key="image",
+                 image_size=256,
+                 channels=3,
+                 log_every_t=100,
+                 clip_denoised=True,
+                 linear_start=1e-4,
+                 linear_end=2e-2,
+                 cosine_s=8e-3,
+                 given_betas=None,
+                 original_elbo_weight=0.,
+                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+                 l_simple_weight=1.,
+                 conditioning_key=None,
+                 parameterization="eps",  # all assuming fixed variance schedules
+                 scheduler_config=None,
+                 use_positional_encodings=False,
+                 learn_logvar=False,
+                 logvar_init=0.,
+                 make_it_fit=False,
+                 ucg_training=None,
+                 reset_ema=False,
+                 reset_num_ema_updates=False,
+                 ):
+        super().__init__()
+        assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+        self.parameterization = parameterization
+        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+        self.cond_stage_model = None
+        self.clip_denoised = clip_denoised
+        self.log_every_t = log_every_t
+        self.first_stage_key = first_stage_key
+        self.image_size = image_size  # try conv?
+        self.channels = channels
+        self.use_positional_encodings = use_positional_encodings
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        self.use_scheduler = scheduler_config is not None
+        if self.use_scheduler:
+            self.scheduler_config = scheduler_config
+
+        self.v_posterior = v_posterior
+        self.original_elbo_weight = original_elbo_weight
+        self.l_simple_weight = l_simple_weight
+
+        if monitor is not None:
+            self.monitor = monitor
+        self.make_it_fit = make_it_fit
+        if reset_ema: assert exists(ckpt_path)
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            if reset_ema:
+                assert self.use_ema
+                print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+                self.model_ema = LitEma(self.model)
+        if reset_num_ema_updates:
+            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+            assert self.use_ema
+            self.model_ema.reset_num_updates()
+
+        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+        self.loss_type = loss_type
+
+        self.learn_logvar = learn_logvar
+        logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+        else:
+            self.register_buffer('logvar', logvar)
+
+        self.ucg_training = ucg_training or dict()
+        if self.ucg_training:
+            self.ucg_prng = np.random.RandomState()
+
+    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        if exists(given_betas):
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                       cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        # np.save('1.npy', alphas_cumprod)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+                1. - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        if self.parameterization == "eps":
+            lvlb_weights = self.betas ** 2 / (
+                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+        elif self.parameterization == "x0":
+            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+        elif self.parameterization == "v":
+            lvlb_weights = torch.ones_like(self.betas ** 2 / (
+                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+        else:
+            raise NotImplementedError("mu not supported")
+        lvlb_weights[0] = lvlb_weights[1]
+        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).all()
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.model.parameters())
+            self.model_ema.copy_to(self.model)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.model.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    @torch.no_grad()
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        if self.make_it_fit:
+            n_params = len([name for name, _ in
+                            itertools.chain(self.named_parameters(),
+                                            self.named_buffers())])
+            for name, param in tqdm(
+                    itertools.chain(self.named_parameters(),
+                                    self.named_buffers()),
+                    desc="Fitting old weights to new weights",
+                    total=n_params
+            ):
+                if not name in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape) == len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+
+                    sd[name] = new_param
+
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+            sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys:\n {missing}")
+        if len(unexpected) > 0:
+            print(f"\nUnexpected Keys:\n {unexpected}")
+
+    def q_mean_variance(self, x_start, t):
+        """
+        Get the distribution q(x_t | x_0).
+        :param x_start: the [N x C x ...] tensor of noiseless inputs.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+        """
+        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def predict_start_from_z_and_v(self, x_t, t, v):
+        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
+    def predict_eps_from_z_and_v(self, x_t, t, v):
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        model_out = self.model(x, t)
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        intermediates = [img]
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+                                clip_denoised=self.clip_denoised)
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(img)
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        image_size = self.image_size
+        channels = self.channels
+        return self.p_sample_loop((batch_size, channels, image_size, image_size),
+                                  return_intermediates=return_intermediates)
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def get_v(self, x, noise, t):
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == 'l1':
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == 'l2':
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+        else:
+            raise NotImplementedError("unknown loss type '{loss_type}'")
+
+        return loss
+
+    def p_losses(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.model(x_noisy, t)
+
+        loss_dict = {}
+        if self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+        log_prefix = 'train' if self.training else 'val'
+
+        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+        loss_simple = loss.mean() * self.l_simple_weight
+
+        loss_vlb = (self.lvlb_weights[t] * loss).mean()
+        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+        loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+        loss_dict.update({f'{log_prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def forward(self, x, *args, **kwargs):
+        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = rearrange(x, 'b h w c -> b c h w')
+        x = x.to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def shared_step(self, batch):
+        x = self.get_input(batch, self.first_stage_key)
+        loss, loss_dict = self(x)
+        return loss, loss_dict
+
+    def training_step(self, batch, batch_idx):
+        for k in self.ucg_training:
+            p = self.ucg_training[k]["p"]
+            val = self.ucg_training[k]["val"]
+            if val is None:
+                val = ""
+            for i in range(len(batch[k])):
+                if self.ucg_prng.choice(2, p=[1 - p, p]):
+                    batch[k][i] = val
+
+        loss, loss_dict = self.shared_step(batch)
+
+        self.log_dict(loss_dict, prog_bar=True,
+                      logger=True, on_step=True, on_epoch=True)
+
+        self.log("global_step", self.global_step,
+                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        if self.use_scheduler:
+            lr = self.optimizers().param_groups[0]['lr']
+            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        return loss
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        _, loss_dict_no_ema = self.shared_step(batch)
+        with self.ema_scope():
+            _, loss_dict_ema = self.shared_step(batch)
+            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self.model)
+
+    def _get_rows_from_list(self, samples):
+        n_imgs_per_row = len(samples)
+        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.first_stage_key)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        x = x.to(self.device)[:N]
+        log["inputs"] = x
+
+        # get diffusion row
+        diffusion_row = list()
+        x_start = x[:n_row]
+
+        for t in range(self.num_timesteps):
+            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                t = t.to(self.device).long()
+                noise = torch.randn_like(x_start)
+                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+                diffusion_row.append(x_noisy)
+
+        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+        if sample:
+            # get denoise row
+            with self.ema_scope("Plotting"):
+                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+            log["samples"] = samples
+            log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.learn_logvar:
+            params = params + [self.logvar]
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+
+class LatentDiffusion(DDPM):
+    """main class"""
+
+    def __init__(self,
+                 first_stage_config,
+                 cond_stage_config,
+                 num_timesteps_cond=None,
+                 cond_stage_key="image",
+                 cond_stage_trainable=False,
+                 concat_mode=True,
+                 cond_stage_forward=None,
+                 conditioning_key=None,
+                 scale_factor=1.0,
+                 scale_by_std=False,
+                 force_null_conditioning=False,
+                 *args, **kwargs):
+        self.force_null_conditioning = force_null_conditioning
+        self.num_timesteps_cond = default(num_timesteps_cond, 1)
+        self.scale_by_std = scale_by_std
+        assert self.num_timesteps_cond <= kwargs['timesteps']
+        # for backwards compatibility after implementation of DiffusionWrapper
+        if conditioning_key is None:
+            conditioning_key = 'concat' if concat_mode else 'crossattn'
+        if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
+            conditioning_key = None
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        reset_ema = kwargs.pop("reset_ema", False)
+        reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        self.concat_mode = concat_mode
+        self.cond_stage_trainable = cond_stage_trainable
+        self.cond_stage_key = cond_stage_key
+        try:
+            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except:
+            self.num_downs = 0
+        if not scale_by_std:
+            self.scale_factor = scale_factor
+        else:
+            self.register_buffer('scale_factor', torch.tensor(scale_factor))
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
+        self.cond_stage_forward = cond_stage_forward
+        self.clip_denoised = False
+        self.bbox_tokenizer = None
+
+        self.restarted_from_ckpt = False
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+            self.restarted_from_ckpt = True
+            if reset_ema:
+                assert self.use_ema
+                print(
+                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+                self.model_ema = LitEma(self.model)
+        if reset_num_ema_updates:
+            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+            assert self.use_ema
+            self.model_ema.reset_num_updates()
+
+    def make_cond_schedule(self, ):
+        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+        self.cond_ids[:self.num_timesteps_cond] = ids
+
+    @rank_zero_only
+    @torch.no_grad()
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        # only for very first batch
+        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+            # set rescale weight to 1./std of encodings
+            print("### USING STD-RESCALING ###")
+            x = super().get_input(batch, self.first_stage_key)
+            x = x.to(self.device)
+            encoder_posterior = self.encode_first_stage(x)
+            z = self.get_first_stage_encoding(encoder_posterior).detach()
+            del self.scale_factor
+            self.register_buffer('scale_factor', 1. / z.flatten().std())
+            print(f"setting self.scale_factor to {self.scale_factor}")
+            print("### USING STD-RESCALING ###")
+
+    def register_schedule(self,
+                          given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+        self.shorten_cond_schedule = self.num_timesteps_cond > 1
+        if self.shorten_cond_schedule:
+            self.make_cond_schedule()
+
+    def instantiate_first_stage(self, config):
+        model = instantiate_from_config(config)
+        self.first_stage_model = model.eval()
+        self.first_stage_model.train = disabled_train
+        for param in self.first_stage_model.parameters():
+            param.requires_grad = False
+
+    def instantiate_cond_stage(self, config):
+        if not self.cond_stage_trainable:
+            if config == "__is_first_stage__":
+                print("Using first stage also as cond stage.")
+                self.cond_stage_model = self.first_stage_model
+            elif config == "__is_unconditional__":
+                print(f"Training {self.__class__.__name__} as an unconditional model.")
+                self.cond_stage_model = None
+                # self.be_unconditional = True
+            else:
+                model = instantiate_from_config(config)
+                self.cond_stage_model = model.eval()
+                self.cond_stage_model.train = disabled_train
+                for param in self.cond_stage_model.parameters():
+                    param.requires_grad = False
+        else:
+            assert config != '__is_first_stage__'
+            assert config != '__is_unconditional__'
+            model = instantiate_from_config(config)
+            self.cond_stage_model = model
+
+    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+        denoise_row = []
+        for zd in tqdm(samples, desc=desc):
+            denoise_row.append(self.decode_first_stage(zd.to(self.device),
+                                                       force_not_quantize=force_no_decoder_quantization))
+        n_imgs_per_row = len(denoise_row)
+        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
+        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    def get_first_stage_encoding(self, encoder_posterior):
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample()
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+        return self.scale_factor * z
+
+    def get_learned_conditioning(self, c):
+        if self.cond_stage_forward is None:
+            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+                c = self.cond_stage_model.encode(c)
+                if isinstance(c, DiagonalGaussianDistribution):
+                    c = c.mode()
+            else:
+                c = self.cond_stage_model(c)
+        else:
+            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+        return c
+
+    def meshgrid(self, h, w):
+        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+        arr = torch.cat([y, x], dim=-1)
+        return arr
+
+    def delta_border(self, h, w):
+        """
+        :param h: height
+        :param w: width
+        :return: normalized distance to image border,
+         wtith min distance = 0 at border and max dist = 0.5 at image center
+        """
+        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+        arr = self.meshgrid(h, w) / lower_right_corner
+        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+        return edge_dist
+
+    def get_weighting(self, h, w, Ly, Lx, device):
+        weighting = self.delta_border(h, w)
+        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+                               self.split_input_params["clip_max_weight"], )
+        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+        if self.split_input_params["tie_braker"]:
+            L_weighting = self.delta_border(Ly, Lx)
+            L_weighting = torch.clip(L_weighting,
+                                     self.split_input_params["clip_min_tie_weight"],
+                                     self.split_input_params["clip_max_tie_weight"])
+
+            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+            weighting = weighting * L_weighting
+        return weighting
+
+    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
+        """
+        :param x: img of size (bs, c, h, w)
+        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+        """
+        bs, nc, h, w = x.shape
+
+        # number of crops in image
+        Ly = (h - kernel_size[0]) // stride[0] + 1
+        Lx = (w - kernel_size[1]) // stride[1] + 1
+
+        if uf == 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+        elif uf > 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+                                dilation=1, padding=0,
+                                stride=(stride[0] * uf, stride[1] * uf))
+            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+        elif df > 1 and uf == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+                                dilation=1, padding=0,
+                                stride=(stride[0] // df, stride[1] // df))
+            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+        else:
+            raise NotImplementedError
+
+        return fold, unfold, normalization, weighting
+
+    @torch.no_grad()
+    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+                  cond_key=None, return_original_cond=False, bs=None, return_x=False, mask_k=None):
+        x = super().get_input(batch, k)
+        if bs is not None:
+            x = x[:bs]
+        x = x.to(self.device)
+        encoder_posterior = self.encode_first_stage(x)
+        z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+        if mask_k is not None:
+            mx = super().get_input(batch, mask_k)
+            if bs is not None:
+                mx = mx[:bs]
+            mx = mx.to(self.device)
+            encoder_posterior = self.encode_first_stage(mx)
+            mx = self.get_first_stage_encoding(encoder_posterior).detach()
+
+        if self.model.conditioning_key is not None and not self.force_null_conditioning:
+            if cond_key is None:
+                cond_key = self.cond_stage_key
+            if cond_key != self.first_stage_key:
+                if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+                    xc = batch[cond_key]
+                elif cond_key in ['class_label', 'cls']:
+                    xc = batch
+                else:
+                    xc = super().get_input(batch, cond_key).to(self.device)
+            else:
+                xc = x
+            if not self.cond_stage_trainable or force_c_encode:
+                if isinstance(xc, dict) or isinstance(xc, list):
+                    c = self.get_learned_conditioning(xc)
+                else:
+                    c = self.get_learned_conditioning(xc.to(self.device))
+            else:
+                c = xc
+            if bs is not None:
+                c = c[:bs]
+
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                ckey = __conditioning_keys__[self.model.conditioning_key]
+                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+        else:
+            c = None
+            xc = None
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                c = {'pos_x': pos_x, 'pos_y': pos_y}
+        out = [z, c]
+        if return_first_stage_outputs:
+            xrec = self.decode_first_stage(z)
+            out.extend([x, xrec])
+        if return_x:
+            out.extend([x])
+        if return_original_cond:
+            out.append(xc)
+        if mask_k:
+            out.append(mx)
+        return out
+
+    @torch.no_grad()
+    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+        z = 1. / self.scale_factor * z
+        return self.first_stage_model.decode(z)
+
+    def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+        z = 1. / self.scale_factor * z
+        return self.first_stage_model.decode(z)
+
+    @torch.no_grad()
+    def encode_first_stage(self, x):
+        return self.first_stage_model.encode(x)
+
+    def shared_step(self, batch, **kwargs):
+        x, c = self.get_input(batch, self.first_stage_key)
+        loss = self(x, c)
+        return loss
+
+    def forward(self, x, c, *args, **kwargs):
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c = self.get_learned_conditioning(c)
+            if self.shorten_cond_schedule:  # TODO: drop this option
+                tc = self.cond_ids[t].to(self.device)
+                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+        return self.p_losses(x, c, t, *args, **kwargs)
+
+    def apply_model(self, x_noisy, t, cond, return_ids=False):
+        if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+            pass
+        else:
+            if not isinstance(cond, list):
+                cond = [cond]
+            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+            cond = {key: cond}
+
+        x_recon = self.model(x_noisy, t, **cond)
+
+        if isinstance(x_recon, tuple) and not return_ids:
+            return x_recon[0]
+        else:
+            return x_recon
+
+    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+    def _prior_bpd(self, x_start):
+        """
+        Get the prior KL term for the variational lower-bound, measured in
+        bits-per-dim.
+        This term can't be optimized, as it only depends on the encoder.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :return: a batch of [N] KL values (in bits), one per batch element.
+        """
+        batch_size = x_start.shape[0]
+        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+        return mean_flat(kl_prior) / np.log(2.0)
+
+    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+                        return_x0=False, score_corrector=None, corrector_kwargs=None):
+        t_in = t
+        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+        if score_corrector is not None:
+            assert self.parameterization == "eps"
+            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+        if return_codebook_ids:
+            model_out, logits = model_out
+
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        else:
+            raise NotImplementedError()
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+        if quantize_denoised:
+            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        if return_codebook_ids:
+            return model_mean, posterior_variance, posterior_log_variance, logits
+        elif return_x0:
+            return model_mean, posterior_variance, posterior_log_variance, x_recon
+        else:
+            return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+        b, *_, device = *x.shape, x.device
+        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+                                       return_codebook_ids=return_codebook_ids,
+                                       quantize_denoised=quantize_denoised,
+                                       return_x0=return_x0,
+                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+        if return_codebook_ids:
+            raise DeprecationWarning("Support dropped.")
+            model_mean, _, model_log_variance, logits = outputs
+        elif return_x0:
+            model_mean, _, model_log_variance, x0 = outputs
+        else:
+            model_mean, _, model_log_variance = outputs
+
+        noise = noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+        if return_codebook_ids:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+        if return_x0:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+        else:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+                              log_every_t=None):
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        timesteps = self.num_timesteps
+        if batch_size is not None:
+            b = batch_size if batch_size is not None else shape[0]
+            shape = [batch_size] + list(shape)
+        else:
+            b = batch_size = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=self.device)
+        else:
+            img = x_T
+        intermediates = []
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+            else:
+                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+                        total=timesteps) if verbose else reversed(
+            range(0, timesteps))
+        if type(temperature) == float:
+            temperature = [temperature] * timesteps
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != 'hybrid'
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img, x0_partial = self.p_sample(img, cond, ts,
+                                            clip_denoised=self.clip_denoised,
+                                            quantize_denoised=quantize_denoised, return_x0=True,
+                                            temperature=temperature[i], noise_dropout=noise_dropout,
+                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1. - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(x0_partial)
+            if callback: callback(i)
+            if img_callback: img_callback(img, i)
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_loop(self, cond, shape, return_intermediates=False,
+                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, start_T=None,
+                      log_every_t=None):
+
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        device = self.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        intermediates = [img]
+        if timesteps is None:
+            timesteps = self.num_timesteps
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+            range(0, timesteps))
+
+        if mask is not None:
+            assert x0 is not None
+            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != 'hybrid'
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img = self.p_sample(img, cond, ts,
+                                clip_denoised=self.clip_denoised,
+                                quantize_denoised=quantize_denoised)
+            if mask is not None:
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1. - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(img)
+            if callback: callback(i)
+            if img_callback: img_callback(img, i)
+
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+               verbose=True, timesteps=None, quantize_denoised=False,
+               mask=None, x0=None, shape=None, **kwargs):
+        if shape is None:
+            shape = (batch_size, self.channels, self.image_size, self.image_size)
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+            else:
+                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+        return self.p_sample_loop(cond,
+                                  shape,
+                                  return_intermediates=return_intermediates, x_T=x_T,
+                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+                                  mask=mask, x0=x0)
+
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        if ddim:
+            ddim_sampler = DDIMSampler(self)
+            shape = (self.channels, self.image_size, self.image_size)
+            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+                                                         shape, cond, verbose=False, **kwargs)
+
+        else:
+            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+                                                 return_intermediates=True, **kwargs)
+
+        return samples, intermediates
+
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            if self.cond_stage_key in ["class_label", "cls"]:
+                xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+                return self.get_learned_conditioning(xc)
+            else:
+                raise NotImplementedError("todo")
+        if isinstance(c, list):  # in case the encoder gives us a list
+            for i in range(len(c)):
+                c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
+        else:
+            c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+        return c
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+                                           return_first_stage_outputs=True,
+                                           force_c_encode=True,
+                                           return_original_cond=True,
+                                           bs=N)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', "cls"]:
+                try:
+                    xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                    log['conditioning'] = xc
+                except KeyError:
+                    # probably no "human_label" in batch
+                    pass
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+                    self.first_stage_model, IdentityFirstStage):
+                # also display when quantizing x0 while sampling
+                with ema_scope("Plotting Quantized Denoised"):
+                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                             ddim_steps=ddim_steps, eta=ddim_eta,
+                                                             quantize_denoised=True)
+                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+                    #                                      quantize_denoised=True)
+                x_samples = self.decode_first_stage(samples.to(self.device))
+                log["samples_x0_quantized"] = x_samples
+
+        if unconditional_guidance_scale > 1.0:
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            if self.model.conditioning_key == "crossattn-adm":
+                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        if inpaint:
+            # make a simple center square
+            b, h, w = z.shape[0], z.shape[2], z.shape[3]
+            mask = torch.ones(N, h, w).to(self.device)
+            # zeros will be filled in
+            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+            mask = mask[:, None, ...]
+            with ema_scope("Plotting Inpaint"):
+                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+            x_samples = self.decode_first_stage(samples.to(self.device))
+            log["samples_inpainting"] = x_samples
+            log["mask"] = mask
+
+            # outpaint
+            mask = 1. - mask
+            with ema_scope("Plotting Outpaint"):
+                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+            x_samples = self.decode_first_stage(samples.to(self.device))
+            log["samples_outpainting"] = x_samples
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.cond_stage_trainable:
+            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+            params = params + list(self.cond_stage_model.parameters())
+        if self.learn_logvar:
+            print('Diffusion model optimizing logvar')
+            params.append(self.logvar)
+        opt = torch.optim.AdamW(params, lr=lr)
+        if self.use_scheduler:
+            assert 'target' in self.scheduler_config
+            scheduler = instantiate_from_config(self.scheduler_config)
+
+            print("Setting up LambdaLR scheduler...")
+            scheduler = [
+                {
+                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+                    'interval': 'step',
+                    'frequency': 1
+                }]
+            return [opt], scheduler
+        return opt
+
+    @torch.no_grad()
+    def to_rgb(self, x):
+        x = x.float()
+        if not hasattr(self, "colorize"):
+            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+        x = nn.functional.conv2d(x, weight=self.colorize)
+        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+        return x
+
+
+class DiffusionWrapper(pl.LightningModule):
+    def __init__(self, diff_model_config, conditioning_key):
+        super().__init__()
+        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
+        self.diffusion_model = instantiate_from_config(diff_model_config)
+        self.conditioning_key = conditioning_key
+        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
+
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+        if self.conditioning_key is None:
+            out = self.diffusion_model(x, t)
+        elif self.conditioning_key == 'concat':
+            xc = torch.cat([x] + c_concat, dim=1)
+            out = self.diffusion_model(xc, t)
+        elif self.conditioning_key == 'crossattn':
+            if not self.sequential_cross_attn:
+                cc = torch.cat(c_crossattn, 1)
+            else:
+                cc = c_crossattn
+            out = self.diffusion_model(x, t, context=cc)
+        elif self.conditioning_key == 'hybrid':
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc)
+        elif self.conditioning_key == 'hybrid-adm':
+            assert c_adm is not None
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+        elif self.conditioning_key == 'crossattn-adm':
+            assert c_adm is not None
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(x, t, context=cc, y=c_adm)
+        elif self.conditioning_key == 'adm':
+            cc = c_crossattn[0]
+            out = self.diffusion_model(x, t, y=cc)
+        else:
+            raise NotImplementedError()
+
+        return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+    def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+        assert not self.cond_stage_trainable
+        self.instantiate_low_stage(low_scale_config)
+        self.low_scale_key = low_scale_key
+        self.noise_level_key = noise_level_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+        if not log_mode:
+            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+        else:
+            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                                  force_c_encode=True, return_original_cond=True, bs=bs)
+        x_low = batch[self.low_scale_key][:bs]
+        x_low = rearrange(x_low, 'b h w c -> b c h w')
+        x_low = x_low.to(memory_format=torch.contiguous_format).float()
+        zx, noise_level = self.low_scale_model(x_low)
+        if self.noise_level_key is not None:
+            # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+            raise NotImplementedError('TODO')
+
+        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        if log_mode:
+            # TODO: maybe disable if too expensive
+            x_low_rec = self.low_scale_model.decode(zx)
+            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+                                                                          log_mode=True)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        log["x_lr"] = x_low
+        log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', 'cls']:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            # TODO explore better "unconditional" choices for the other keys
+            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+            uc = dict()
+            for k in c:
+                if k == "c_crossattn":
+                    assert isinstance(c[k], list) and len(c[k]) == 1
+                    uc[k] = [uc_tmp]
+                elif k == "c_adm":  # todo: only run with text-based guidance?
+                    assert isinstance(c[k], torch.Tensor)
+                    #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+                    uc[k] = c[k]
+                elif isinstance(c[k], list):
+                    uc[k] = [c[k][i] for i in range(len(c[k]))]
+                else:
+                    uc[k] = c[k]
+
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+    """
+         Basis for different finetunas, such as inpainting or depth2image
+         To disable finetuning mode, set finetune_keys to None
+    """
+
+    def __init__(self,
+                 concat_keys: tuple,
+                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+                                "model_ema.diffusion_modelinput_blocks00weight"
+                                ),
+                 keep_finetune_dims=4,
+                 # if model was trained without concat mode before and we would like to keep these channels
+                 c_concat_log_start=None,  # to log reconstruction of c_concat codes
+                 c_concat_log_end=None,
+                 *args, **kwargs
+                 ):
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", list())
+        super().__init__(*args, **kwargs)
+        self.finetune_keys = finetune_keys
+        self.concat_keys = concat_keys
+        self.keep_dims = keep_finetune_dims
+        self.c_concat_log_start = c_concat_log_start
+        self.c_concat_log_end = c_concat_log_end
+        if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+        if exists(ckpt_path):
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+
+            # make it explicit, finetune by including extra input channels
+            if exists(self.finetune_keys) and k in self.finetune_keys:
+                new_entry = None
+                for name, param in self.named_parameters():
+                    if name in self.finetune_keys:
+                        print(
+                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+                        new_entry = torch.zeros_like(param)  # zero init
+                assert exists(new_entry), 'did not find matching parameter to modify'
+                new_entry[:, :self.keep_dims, ...] = sd[k]
+                sd[k] = new_entry
+
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+            sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', 'cls']:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+            log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                         batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            uc_cat = c_cat
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                 batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc_full,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+    """
+    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+    e.g. mask as concat and text via cross-attn.
+    To disable finetuning mode, set finetune_keys to None
+     """
+
+    def __init__(self,
+                 concat_keys=("mask", "masked_image"),
+                 masked_image_key="masked_image",
+                 *args, **kwargs
+                 ):
+        super().__init__(concat_keys, *args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+        log["masked_image"] = rearrange(args[0]["masked_image"],
+                                        'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+        return log
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on monocular depth estimation
+    """
+
+    def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.depth_model = instantiate_from_config(depth_stage_config)
+        self.depth_stage_key = concat_keys[0]
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            cc = self.depth_model(cc)
+            cc = torch.nn.functional.interpolate(
+                cc,
+                size=z.shape[2:],
+                mode="bicubic",
+                align_corners=False,
+            )
+
+            depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+                                                                                           keepdim=True)
+            cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        depth = self.depth_model(args[0][self.depth_stage_key])
+        depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
+                               torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+        log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
+        return log
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+    """
+        condition on low-res image (and optionally on some spatial noise augmentation)
+    """
+    def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
+                 low_scale_config=None, low_scale_key=None, *args, **kwargs):
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.reshuffle_patch_size = reshuffle_patch_size
+        self.low_scale_model = None
+        if low_scale_config is not None:
+            print("Initializing a low-scale model")
+            assert exists(low_scale_key)
+            self.instantiate_low_stage(low_scale_config)
+            self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        # optionally make spatial noise_level here
+        c_cat = list()
+        noise_level = None
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            cc = rearrange(cc, 'b h w c -> b c h w')
+            if exists(self.reshuffle_patch_size):
+                assert isinstance(self.reshuffle_patch_size, int)
+                cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+                               p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            if exists(self.low_scale_model) and ck == self.low_scale_key:
+                cc, noise_level = self.low_scale_model(cc)
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        if exists(noise_level):
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+        else:
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
+        return log
diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1154 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+    def __init__(
+            self,
+            schedule='discrete',
+            betas=None,
+            alphas_cumprod=None,
+            continuous_beta_0=0.1,
+            continuous_beta_1=20.,
+    ):
+        """Create a wrapper class for the forward SDE (VP type).
+        ***
+        Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+                We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+        ***
+        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+            log_alpha_t = self.marginal_log_mean_coeff(t)
+            sigma_t = self.marginal_std(t)
+            lambda_t = self.marginal_lambda(t)
+        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+            t = self.inverse_lambda(lambda_t)
+        ===============================================================
+        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+        1. For discrete-time DPMs:
+            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+                t_i = (i + 1) / N
+            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+            Args:
+                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+            Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+            **Important**:  Please pay special attention for the args for `alphas_cumprod`:
+                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
+                and
+                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+        2. For continuous-time DPMs:
+            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+            schedule are the default settings in DDPM and improved-DDPM:
+            Args:
+                beta_min: A `float` number. The smallest beta for the linear schedule.
+                beta_max: A `float` number. The largest beta for the linear schedule.
+                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+                T: A `float` number. The ending time of the forward process.
+        ===============================================================
+        Args:
+            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+                    'linear' or 'cosine' for continuous-time DPMs.
+        Returns:
+            A wrapper object of the forward SDE (VP type).
+
+        ===============================================================
+        Example:
+        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', betas=betas)
+        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+        # For continuous-time DPMs (VPSDE), linear schedule:
+        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+        """
+
+        if schedule not in ['discrete', 'linear', 'cosine']:
+            raise ValueError(
+                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+                    schedule))
+
+        self.schedule = schedule
+        if schedule == 'discrete':
+            if betas is not None:
+                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+            else:
+                assert alphas_cumprod is not None
+                log_alphas = 0.5 * torch.log(alphas_cumprod)
+            self.total_N = len(log_alphas)
+            self.T = 1.
+            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+            self.log_alpha_array = log_alphas.reshape((1, -1,))
+        else:
+            self.total_N = 1000
+            self.beta_0 = continuous_beta_0
+            self.beta_1 = continuous_beta_1
+            self.cosine_s = 0.008
+            self.cosine_beta_max = 999.
+            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+                        1. + self.cosine_s) / math.pi - self.cosine_s
+            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+            self.schedule = schedule
+            if schedule == 'cosine':
+                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+                self.T = 0.9946
+            else:
+                self.T = 1.
+
+    def marginal_log_mean_coeff(self, t):
+        """
+        Compute log(alpha_t) of a given continuous-time label t in [0, T].
+        """
+        if self.schedule == 'discrete':
+            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+                                  self.log_alpha_array.to(t.device)).reshape((-1))
+        elif self.schedule == 'linear':
+            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+        elif self.schedule == 'cosine':
+            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+            return log_alpha_t
+
+    def marginal_alpha(self, t):
+        """
+        Compute alpha_t of a given continuous-time label t in [0, T].
+        """
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        """
+        Compute sigma_t of a given continuous-time label t in [0, T].
+        """
+        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+        return log_mean_coeff - log_std
+
+    def inverse_lambda(self, lamb):
+        """
+        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+        """
+        if self.schedule == 'linear':
+            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            Delta = self.beta_0 ** 2 + tmp
+            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+        elif self.schedule == 'discrete':
+            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+                               torch.flip(self.t_array.to(lamb.device), [1]))
+            return t.reshape((-1,))
+        else:
+            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+                        1. + self.cosine_s) / math.pi - self.cosine_s
+            t = t_fn(log_alpha)
+            return t
+
+
+def model_wrapper(
+        model,
+        noise_schedule,
+        model_type="noise",
+        model_kwargs={},
+        guidance_type="uncond",
+        condition=None,
+        unconditional_condition=None,
+        guidance_scale=1.,
+        classifier_fn=None,
+        classifier_kwargs={},
+):
+    """Create a wrapper function for the noise prediction model.
+    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+    We support four types of the diffusion model by setting `model_type`:
+        1. "noise": noise prediction model. (Trained by predicting noise).
+        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+        3. "v": velocity prediction model. (Trained by predicting the velocity).
+            The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+                arXiv preprint arXiv:2202.00512 (2022).
+            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+                arXiv preprint arXiv:2210.02303 (2022).
+
+        4. "score": marginal score function. (Trained by denoising score matching).
+            Note that the score function and the noise prediction model follows a simple relationship:
+            ```
+                noise(x_t, t) = -sigma_t * score(x_t, t)
+            ```
+    We support three types of guided sampling by DPMs by setting `guidance_type`:
+        1. "uncond": unconditional sampling by DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+            The input `classifier_fn` has the following format:
+            ``
+                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+            ``
+            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+            ``
+            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+                arXiv preprint arXiv:2207.12598 (2022).
+
+    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+    or continuous-time labels (i.e. epsilon to T).
+    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+    ``
+        def model_fn(x, t_continuous) -> noise:
+            t_input = get_model_input_time(t_continuous)
+            return noise_pred(model, x, t_input, **model_kwargs)
+    ``
+    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+    ===============================================================
+    Args:
+        model: A diffusion model with the corresponding format described above.
+        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+        model_type: A `str`. The parameterization type of the diffusion model.
+                    "noise" or "x_start" or "v" or "score".
+        model_kwargs: A `dict`. A dict for the other inputs of the model function.
+        guidance_type: A `str`. The type of the guidance for sampling.
+                    "uncond" or "classifier" or "classifier-free".
+        condition: A pytorch tensor. The condition for the guided sampling.
+                    Only used for "classifier" or "classifier-free" guidance type.
+        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+                    Only used for "classifier-free" guidance type.
+        guidance_scale: A `float`. The scale for the guided sampling.
+        classifier_fn: A classifier function. Only used for the classifier guidance.
+        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+    Returns:
+        A noise prediction model that accepts the noised data and the continuous time as the inputs.
+    """
+
+    def get_model_input_time(t_continuous):
+        """
+        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+        For continuous-time DPMs, we just use `t_continuous`.
+        """
+        if noise_schedule.schedule == 'discrete':
+            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+        else:
+            return t_continuous
+
+    def noise_pred_fn(x, t_continuous, cond=None):
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        t_input = get_model_input_time(t_continuous)
+        if cond is None:
+            output = model(x, t_input, **model_kwargs)
+        else:
+            output = model(x, t_input, cond, **model_kwargs)
+        if model_type == "noise":
+            return output
+        elif model_type == "x_start":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+        elif model_type == "v":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+        elif model_type == "score":
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return -expand_dims(sigma_t, dims) * output
+
+    def cond_grad_fn(x, t_input):
+        """
+        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+        """
+        with torch.enable_grad():
+            x_in = x.detach().requires_grad_(True)
+            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+            return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+    def model_fn(x, t_continuous):
+        """
+        The noise predicition model function that is used for DPM-Solver.
+        """
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        if guidance_type == "uncond":
+            return noise_pred_fn(x, t_continuous)
+        elif guidance_type == "classifier":
+            assert classifier_fn is not None
+            t_input = get_model_input_time(t_continuous)
+            cond_grad = cond_grad_fn(x, t_input)
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            noise = noise_pred_fn(x, t_continuous)
+            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+        elif guidance_type == "classifier-free":
+            if guidance_scale == 1. or unconditional_condition is None:
+                return noise_pred_fn(x, t_continuous, cond=condition)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t_continuous] * 2)
+                c_in = torch.cat([unconditional_condition, condition])
+                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+                return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+    assert model_type in ["noise", "x_start", "v"]
+    assert guidance_type in ["uncond", "classifier", "classifier-free"]
+    return model_fn
+
+
+class DPM_Solver:
+    def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+        """Construct a DPM-Solver.
+        We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+        If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+        If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+            In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+            The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+        Args:
+            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+                ``
+                def model_fn(x, t_continuous):
+                    return noise
+                ``
+            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+            predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+            thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+            max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+        """
+        self.model = model_fn
+        self.noise_schedule = noise_schedule
+        self.predict_x0 = predict_x0
+        self.thresholding = thresholding
+        self.max_val = max_val
+
+    def noise_prediction_fn(self, x, t):
+        """
+        Return the noise prediction model.
+        """
+        return self.model(x, t)
+
+    def data_prediction_fn(self, x, t):
+        """
+        Return the data prediction model (with thresholding).
+        """
+        noise = self.noise_prediction_fn(x, t)
+        dims = x.dim()
+        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+        if self.thresholding:
+            p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
+            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+            x0 = torch.clamp(x0, -s, s) / s
+        return x0
+
+    def model_fn(self, x, t):
+        """
+        Convert the model to the noise prediction model or the data prediction model.
+        """
+        if self.predict_x0:
+            return self.data_prediction_fn(x, t)
+        else:
+            return self.noise_prediction_fn(x, t)
+
+    def get_time_steps(self, skip_type, t_T, t_0, N, device):
+        """Compute the intermediate time steps for sampling.
+        Args:
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            N: A `int`. The total number of the spacing of the time steps.
+            device: A torch device.
+        Returns:
+            A pytorch tensor of the time steps, with the shape (N + 1,).
+        """
+        if skip_type == 'logSNR':
+            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+            return self.noise_schedule.inverse_lambda(logSNR_steps)
+        elif skip_type == 'time_uniform':
+            return torch.linspace(t_T, t_0, N + 1).to(device)
+        elif skip_type == 'time_quadratic':
+            t_order = 2
+            t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+            return t
+        else:
+            raise ValueError(
+                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+        """
+        Get the order of each step for sampling by the singlestep DPM-Solver.
+        We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
+        Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+            - If order == 1:
+                We take `steps` of DPM-Solver-1 (i.e. DDIM).
+            - If order == 2:
+                - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+                - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+            - If order == 3:
+                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+        ============================================
+        Args:
+            order: A `int`. The max order for the solver (2 or 3).
+            steps: A `int`. The total number of function evaluations (NFE).
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            device: A torch device.
+        Returns:
+            orders: A list of the solver order of each step.
+        """
+        if order == 3:
+            K = steps // 3 + 1
+            if steps % 3 == 0:
+                orders = [3, ] * (K - 2) + [2, 1]
+            elif steps % 3 == 1:
+                orders = [3, ] * (K - 1) + [1]
+            else:
+                orders = [3, ] * (K - 1) + [2]
+        elif order == 2:
+            if steps % 2 == 0:
+                K = steps // 2
+                orders = [2, ] * K
+            else:
+                K = steps // 2 + 1
+                orders = [2, ] * (K - 1) + [1]
+        elif order == 1:
+            K = 1
+            orders = [1, ] * steps
+        else:
+            raise ValueError("'order' must be '1' or '2' or '3'.")
+        if skip_type == 'logSNR':
+            # To reproduce the results in DPM-Solver paper
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+        else:
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+                torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
+        return timesteps_outer, orders
+
+    def denoise_to_zero_fn(self, x, s):
+        """
+        Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
+        """
+        return self.data_prediction_fn(x, s)
+
+    def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+        """
+        DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_1 = torch.expm1(-h)
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_t = (
+                    expand_dims(sigma_t / sigma_s, dims) * x
+                    - expand_dims(alpha_t * phi_1, dims) * model_s
+            )
+            if return_intermediate:
+                return x_t, {'model_s': model_s}
+            else:
+                return x_t
+        else:
+            phi_1 = torch.expm1(h)
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                    - expand_dims(sigma_t * phi_1, dims) * model_s
+            )
+            if return_intermediate:
+                return x_t, {'model_s': model_s}
+            else:
+                return x_t
+
+    def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+                                            solver_type='dpm_solver'):
+        """
+        Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            r1: A `float`. The hyperparameter of the second-order solver.
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        if r1 is None:
+            r1 = 0.5
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        lambda_s1 = lambda_s + r1 * h
+        s1 = ns.inverse_lambda(lambda_s1)
+        log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+            s1), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_11 = torch.expm1(-r1 * h)
+            phi_1 = torch.expm1(-h)
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_s1 = (
+                    expand_dims(sigma_s1 / sigma_s, dims) * x
+                    - expand_dims(alpha_s1 * phi_11, dims) * model_s
+            )
+            model_s1 = self.model_fn(x_s1, s1)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+                                    model_s1 - model_s)
+                )
+        else:
+            phi_11 = torch.expm1(r1 * h)
+            phi_1 = torch.expm1(h)
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_s1 = (
+                    expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+                    - expand_dims(sigma_s1 * phi_11, dims) * model_s
+            )
+            model_s1 = self.model_fn(x_s1, s1)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+                )
+        if return_intermediate:
+            return x_t, {'model_s': model_s, 'model_s1': model_s1}
+        else:
+            return x_t
+
+    def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+                                           return_intermediate=False, solver_type='dpm_solver'):
+        """
+        Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            r1: A `float`. The hyperparameter of the third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        if r1 is None:
+            r1 = 1. / 3.
+        if r2 is None:
+            r2 = 2. / 3.
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        lambda_s1 = lambda_s + r1 * h
+        lambda_s2 = lambda_s + r2 * h
+        s1 = ns.inverse_lambda(lambda_s1)
+        s2 = ns.inverse_lambda(lambda_s2)
+        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+            s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+            s2), ns.marginal_std(t)
+        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_11 = torch.expm1(-r1 * h)
+            phi_12 = torch.expm1(-r2 * h)
+            phi_1 = torch.expm1(-h)
+            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+            phi_2 = phi_1 / h + 1.
+            phi_3 = phi_2 / h - 0.5
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            if model_s1 is None:
+                x_s1 = (
+                        expand_dims(sigma_s1 / sigma_s, dims) * x
+                        - expand_dims(alpha_s1 * phi_11, dims) * model_s
+                )
+                model_s1 = self.model_fn(x_s1, s1)
+            x_s2 = (
+                    expand_dims(sigma_s2 / sigma_s, dims) * x
+                    - expand_dims(alpha_s2 * phi_12, dims) * model_s
+                    + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+            )
+            model_s2 = self.model_fn(x_s2, s2)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+                )
+            elif solver_type == 'taylor':
+                D1_0 = (1. / r1) * (model_s1 - model_s)
+                D1_1 = (1. / r2) * (model_s2 - model_s)
+                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + expand_dims(alpha_t * phi_2, dims) * D1
+                        - expand_dims(alpha_t * phi_3, dims) * D2
+                )
+        else:
+            phi_11 = torch.expm1(r1 * h)
+            phi_12 = torch.expm1(r2 * h)
+            phi_1 = torch.expm1(h)
+            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+            phi_2 = phi_1 / h - 1.
+            phi_3 = phi_2 / h - 0.5
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            if model_s1 is None:
+                x_s1 = (
+                        expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+                        - expand_dims(sigma_s1 * phi_11, dims) * model_s
+                )
+                model_s1 = self.model_fn(x_s1, s1)
+            x_s2 = (
+                    expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+                    - expand_dims(sigma_s2 * phi_12, dims) * model_s
+                    - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+            )
+            model_s2 = self.model_fn(x_s2, s2)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+                )
+            elif solver_type == 'taylor':
+                D1_0 = (1. / r1) * (model_s1 - model_s)
+                D1_1 = (1. / r2) * (model_s2 - model_s)
+                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - expand_dims(sigma_t * phi_2, dims) * D1
+                        - expand_dims(sigma_t * phi_3, dims) * D2
+                )
+
+        if return_intermediate:
+            return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+        else:
+            return x_t
+
+    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+        """
+        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            model_prev_list: A list of pytorch tensor. The previous computed model values.
+            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        ns = self.noise_schedule
+        dims = x.dim()
+        model_prev_1, model_prev_0 = model_prev_list
+        t_prev_1, t_prev_0 = t_prev_list
+        lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+            t_prev_0), ns.marginal_lambda(t)
+        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        h_0 = lambda_prev_0 - lambda_prev_1
+        h = lambda_t - lambda_prev_0
+        r0 = h_0 / h
+        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+        if self.predict_x0:
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_prev_0, dims) * x
+                        - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                        - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(sigma_t / sigma_prev_0, dims) * x
+                        - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                        + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+                )
+        else:
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                        - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                        - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                        - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                        - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+                )
+        return x_t
+
+    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+        """
+        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            model_prev_list: A list of pytorch tensor. The previous computed model values.
+            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        ns = self.noise_schedule
+        dims = x.dim()
+        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+            t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        h_1 = lambda_prev_1 - lambda_prev_2
+        h_0 = lambda_prev_0 - lambda_prev_1
+        h = lambda_t - lambda_prev_0
+        r0, r1 = h_0 / h, h_1 / h
+        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+        D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+        D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+        if self.predict_x0:
+            x_t = (
+                    expand_dims(sigma_t / sigma_prev_0, dims) * x
+                    - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                    + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+                    - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+            )
+        else:
+            x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                    - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                    - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+                    - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+            )
+        return x_t
+
+    def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+                                     r2=None):
+        """
+        Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+            r1: A `float`. The hyperparameter of the second-order or third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+        elif order == 2:
+            return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+                                                            solver_type=solver_type, r1=r1)
+        elif order == 3:
+            return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+                                                           solver_type=solver_type, r1=r1, r2=r2)
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+        """
+        Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            model_prev_list: A list of pytorch tensor. The previous computed model values.
+            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+        elif order == 2:
+            return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+        elif order == 3:
+            return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+                            solver_type='dpm_solver'):
+        """
+        The adaptive step size solver based on singlestep DPM-Solver.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_T`.
+            order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            h_init: A `float`. The initial step size (for logSNR).
+            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
+            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
+            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+        Returns:
+            x_0: A pytorch tensor. The approximated solution at time `t_0`.
+        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+        """
+        ns = self.noise_schedule
+        s = t_T * torch.ones((x.shape[0],)).to(x)
+        lambda_s = ns.marginal_lambda(s)
+        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+        h = h_init * torch.ones_like(s).to(x)
+        x_prev = x
+        nfe = 0
+        if order == 2:
+            r1 = 0.5
+            lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+                                                                                               solver_type=solver_type,
+                                                                                               **kwargs)
+        elif order == 3:
+            r1, r2 = 1. / 3., 2. / 3.
+            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+                                                                                    return_intermediate=True,
+                                                                                    solver_type=solver_type)
+            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+                                                                                              solver_type=solver_type,
+                                                                                              **kwargs)
+        else:
+            raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+        while torch.abs((s - t_0)).mean() > t_err:
+            t = ns.inverse_lambda(lambda_s + h)
+            x_lower, lower_noise_kwargs = lower_update(x, s, t)
+            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+            delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+            norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+            E = norm_fn((x_higher - x_lower) / delta).max()
+            if torch.all(E <= 1.):
+                x = x_higher
+                s = t
+                x_prev = x_lower
+                lambda_s = ns.marginal_lambda(s)
+            h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+            nfe += order
+        print('adaptive solver nfe', nfe)
+        return x
+
+    def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+               method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+               atol=0.0078, rtol=0.05,
+               ):
+        """
+        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+        =====================================================
+        We support the following algorithms for both noise prediction model and data prediction model:
+            - 'singlestep':
+                Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+                We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+                The total number of function evaluations (NFE) == `steps`.
+                Given a fixed NFE == `steps`, the sampling procedure is:
+                    - If `order` == 1:
+                        - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+                    - If `order` == 2:
+                        - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+                        - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+                        - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                    - If `order` == 3:
+                        - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                        - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                        - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+                        - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+            - 'multistep':
+                Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+                We initialize the first `order` values by lower order multistep solvers.
+                Given a fixed NFE == `steps`, the sampling procedure is:
+                    Denote K = steps.
+                    - If `order` == 1:
+                        - We use K steps of DPM-Solver-1 (i.e. DDIM).
+                    - If `order` == 2:
+                        - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
+                    - If `order` == 3:
+                        - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
+            - 'singlestep_fixed':
+                Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+                We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+            - 'adaptive':
+                Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+                We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
+                (NFE) and the sample quality.
+                    - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+                    - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+        =====================================================
+        Some advices for choosing the algorithm:
+            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+                e.g.
+                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+                            skip_type='time_uniform', method='singlestep')
+            - For **guided sampling with large guidance scale** by DPMs:
+                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+                e.g.
+                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+                            skip_type='time_uniform', method='multistep')
+        We support three types of `skip_type`:
+            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
+            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
+            - 'time_quadratic': quadratic time for the time steps.
+        =====================================================
+        Args:
+            x: A pytorch tensor. The initial value at time `t_start`
+                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+            steps: A `int`. The total number of function evaluations (NFE).
+            t_start: A `float`. The starting time of the sampling.
+                If `T` is None, we use self.noise_schedule.T (default is 1.0).
+            t_end: A `float`. The ending time of the sampling.
+                If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+                e.g. if total_N == 1000, we have `t_end` == 1e-3.
+                For discrete-time DPMs:
+                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+                For continuous-time DPMs:
+                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+            order: A `int`. The order of DPM-Solver.
+            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+                This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+                score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
+                for diffusion models sampling by diffusion SDEs for low-resolutional images
+                (such as CIFAR-10). However, we observed that such trick does not matter for
+                high-resolutional images. As it needs an additional NFE, we do not recommend
+                it for high-resolutional images.
+            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+                Only valid for `method=multistep` and `steps < 15`. We empirically find that
+                this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+                (especially for steps <= 10). So we recommend to set it to be `True`.
+            solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+        Returns:
+            x_end: A pytorch tensor. The approximated solution at time `t_end`.
+        """
+        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+        t_T = self.noise_schedule.T if t_start is None else t_start
+        device = x.device
+        if method == 'adaptive':
+            with torch.no_grad():
+                x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+                                             solver_type=solver_type)
+        elif method == 'multistep':
+            assert steps >= order
+            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+            assert timesteps.shape[0] - 1 == steps
+            with torch.no_grad():
+                vec_t = timesteps[0].expand((x.shape[0]))
+                model_prev_list = [self.model_fn(x, vec_t)]
+                t_prev_list = [vec_t]
+                # Init the first `order` values by lower order multistep DPM-Solver.
+                for init_order in tqdm(range(1, order), desc="DPM init order"):
+                    vec_t = timesteps[init_order].expand(x.shape[0])
+                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+                                                         solver_type=solver_type)
+                    model_prev_list.append(self.model_fn(x, vec_t))
+                    t_prev_list.append(vec_t)
+                # Compute the remaining values by `order`-th order multistep DPM-Solver.
+                for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+                    vec_t = timesteps[step].expand(x.shape[0])
+                    if lower_order_final and steps < 15:
+                        step_order = min(order, steps + 1 - step)
+                    else:
+                        step_order = order
+                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+                                                         solver_type=solver_type)
+                    for i in range(order - 1):
+                        t_prev_list[i] = t_prev_list[i + 1]
+                        model_prev_list[i] = model_prev_list[i + 1]
+                    t_prev_list[-1] = vec_t
+                    # We do not need to evaluate the final model value.
+                    if step < steps:
+                        model_prev_list[-1] = self.model_fn(x, vec_t)
+        elif method in ['singlestep', 'singlestep_fixed']:
+            if method == 'singlestep':
+                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+                                                                                              skip_type=skip_type,
+                                                                                              t_T=t_T, t_0=t_0,
+                                                                                              device=device)
+            elif method == 'singlestep_fixed':
+                K = steps // order
+                orders = [order, ] * K
+                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+            for i, order in enumerate(orders):
+                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+                                                      N=order, device=device)
+                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+                h = lambda_inner[-1] - lambda_inner[0]
+                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+                x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+        if denoise_to_zero:
+            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+        return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+    """
+    A piecewise linear function y = f(x), using xp and yp as keypoints.
+    We implement f(x) in a differentiable way (i.e. applicable for autograd).
+    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+    Args:
+        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+        yp: PyTorch tensor with shape [C, K].
+    Returns:
+        The function values f(x), with shape [N, C].
+    """
+    N, K = x.shape[0], xp.shape[1]
+    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+    x_idx = torch.argmin(x_indices, dim=2)
+    cand_start_idx = x_idx - 1
+    start_idx = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(1, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+        ),
+    )
+    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+    start_idx2 = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(0, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+        ),
+    )
+    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+    return cand
+
+
+def expand_dims(v, dims):
+    """
+    Expand the tensor `v` to the dim `dims`.
+    Args:
+        `v`: a PyTorch tensor with shape [N].
+        `dim`: a `int`.
+    Returns:
+        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+    """
+    return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {
+    "eps": "noise",
+    "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.model = model
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+
+        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+        device = self.model.betas.device
+        if x_T is None:
+            img = torch.randn(size, device=device)
+        else:
+            img = x_T
+
+        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+        model_fn = model_wrapper(
+            lambda x, t, c: self.model.apply_model(x, t, c),
+            ns,
+            model_type=MODEL_TYPES[self.model.parameterization],
+            guidance_type="classifier-free",
+            condition=conditioning,
+            unconditional_condition=unconditional_conditioning,
+            guidance_scale=unconditional_guidance_scale,
+        )
+
+        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+        return x.to(device), None
\ No newline at end of file
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae
--- /dev/null
+++ b/ldm/models/diffusion/plms.py
@@ -0,0 +1,244 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        if ddim_eta != 0:
+            raise ValueError('ddim_eta must be 0 for PLMS')
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for PLMS sampling is {size}')
+
+        samples, intermediates = self.plms_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def plms_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+        old_eps = []
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      old_eps=old_eps, t_next=ts_next,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0, e_t = outs
+            old_eps.append(e_t)
+            if len(old_eps) >= 4:
+                old_eps.pop(0)
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
+        def get_model_output(x, t):
+            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+                e_t = self.model.apply_model(x, t, c)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t] * 2)
+                c_in = torch.cat([unconditional_conditioning, c])
+                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+            if score_corrector is not None:
+                assert self.model.parameterization == "eps"
+                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+            return e_t
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+        def get_x_prev_and_pred_x0(e_t, index):
+            # select parameters corresponding to the currently considered timestep
+            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+            # current prediction for x_0
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            if quantize_denoised:
+                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+            if dynamic_threshold is not None:
+                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+            # direction pointing to x_t
+            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+            if noise_dropout > 0.:
+                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+            return x_prev, pred_x0
+
+        e_t = get_model_output(x, t)
+        if len(old_eps) == 0:
+            # Pseudo Improved Euler (2nd order)
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+            e_t_next = get_model_output(x_prev, t_next)
+            e_t_prime = (e_t + e_t_next) / 2
+        elif len(old_eps) == 1:
+            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+        elif len(old_eps) == 2:
+            # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+        elif len(old_eps) >= 3:
+            # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+        return x_prev, pred_x0, e_t
diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33
--- /dev/null
+++ b/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    return x[(...,) + (None,) * dims_to_append]
+
+
+def norm_thresholding(x0, value):
+    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+    return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+    # b c h w
+    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+    return x0 * (value / s)
\ No newline at end of file
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..509cd873768f0dd75a75ab3fcdd652822b12b59f
--- /dev/null
+++ b/ldm/modules/attention.py
@@ -0,0 +1,341 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
+    return{el: True for el in arr}.keys()
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+    dim = tensor.shape[-1]
+    std = 1 / math.sqrt(dim)
+    tensor.uniform_(-std, std)
+    return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU()
+        ) if not glu else GEGLU(dim, inner_dim)
+
+        self.net = nn.Sequential(
+            project_in,
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialSelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b,c,h,w = q.shape
+        q = rearrange(q, 'b c h w -> b (h w) c')
+        k = rearrange(k, 'b c h w -> b c (h w)')
+        w_ = torch.einsum('bij,bjk->bik', q, k)
+
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = rearrange(v, 'b c h w -> b c (h w)')
+        w_ = rearrange(w_, 'b i j -> b j i')
+        h_ = torch.einsum('bij,bjk->bik', v, w_)
+        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        # force cast to fp32 to avoid overflowing
+        if _ATTN_PRECISION =="fp32":
+            with torch.autocast(enabled=False, device_type = 'cuda'):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        
+        del q, k
+    
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+    ATTENTION_MODES = {
+        "softmax": CrossAttention,  # vanilla attention
+        "softmax-xformers": MemoryEfficientCrossAttention
+    }
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
+        super().__init__()
+        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        assert attn_mode in self.ATTENTION_MODES
+        attn_cls = self.ATTENTION_MODES[attn_mode]
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+    def _forward(self, x, context=None):
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding)
+    and reshape to b, t, d.
+    Then apply standard transformer action.
+    Finally, reshape to image
+    NEW: use_linear for more efficiency instead of the 1x1 convs
+    """
+    def __init__(self, in_channels, n_heads, d_head,
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False, use_linear=False,
+                 use_checkpoint=True):
+        super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+                for d in range(depth)]
+        )
+        if not use_linear:
+            self.proj_out = zero_module(nn.Conv2d(inner_dim,
+                                                  in_channels,
+                                                  kernel_size=1,
+                                                  stride=1,
+                                                  padding=0))
+        else:
+            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+        self.use_linear = use_linear
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i])
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in
+
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,852 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+from ldm.modules.attention import MemoryEfficientCrossAttention
+
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+    print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models:
+    From Fairseq.
+    Build sinusoidal embeddings.
+    This matches the implementation in tensor2tensor, but differs slightly
+    from the description in Section 3.5 of "Attention Is All You Need".
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0,1,0,0))
+    return emb
+
+
+def nonlinearity(x):
+    # swish
+    return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=0)
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0,1,0,1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                 dropout, temb_channels=512):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(in_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels,
+                                             out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                                                     out_channels,
+                                                     kernel_size=3,
+                                                     stride=1,
+                                                     padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                                                    out_channels,
+                                                    kernel_size=1,
+                                                    stride=1,
+                                                    padding=0)
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x+h
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b,c,h,w = q.shape
+        q = q.reshape(b,c,h*w)
+        q = q.permute(0,2,1)   # b,hw,c
+        k = k.reshape(b,c,h*w) # b,c,hw
+        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b,c,h*w)
+        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b,c,h,w)
+
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+    """
+        Uses xformers efficient implementation,
+        see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+        Note: this is a single-head self-attention operation
+    """
+    #
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None):
+        b, c, h, w = x.shape
+        x = rearrange(x, 'b c h w -> b (h w) c')
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+        return x + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+        attn_type = "vanilla-xformers"
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    if attn_type == "vanilla":
+        assert attn_kwargs is None
+        return AttnBlock(in_channels)
+    elif attn_type == "vanilla-xformers":
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        return MemoryEfficientAttnBlock(in_channels)
+    elif type == "memory-efficient-cross-attn":
+        attn_kwargs["query_dim"] = in_channels
+        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+    elif attn_type == "none":
+        return nn.Identity(in_channels)
+    else:
+        raise NotImplementedError()
+
+
+class Model(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = self.ch*4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList([
+                torch.nn.Linear(self.ch,
+                                self.temb_ch),
+                torch.nn.Linear(self.temb_ch,
+                                self.temb_ch),
+            ])
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            skip_in = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch*in_ch_mult[i_level]
+                block.append(ResnetBlock(in_channels=block_in+skip_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x, t=None, context=None):
+        #assert x.shape[2] == x.shape[3] == self.resolution
+        if context is not None:
+            # assume aligned context, cat along channel axis
+            x = torch.cat((x, context), dim=1)
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+    def get_last_layer(self):
+        return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+                 **ignore_kwargs):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        2*z_channels if double_z else z_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        # timestep embedding
+        temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+                 attn_type="vanilla", **ignorekwargs):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+        self.tanh_out = tanh_out
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,)+tuple(ch_mult)
+        block_in = ch*ch_mult[self.num_resolutions-1]
+        curr_res = resolution // 2**(self.num_resolutions-1)
+        self.z_shape = (1,z_channels,curr_res,curr_res)
+        print("Working with z of shape {} = {} dimensions.".format(
+            self.z_shape, np.prod(self.z_shape)))
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(z_channels,
+                                       block_in,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, z):
+        #assert z.shape[1:] == self.z_shape[1:]
+        self.last_z_shape = z.shape
+
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](h, temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        if self.give_pre_end:
+            return h
+
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        if self.tanh_out:
+            h = torch.tanh(h)
+        return h
+
+
+class SimpleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, *args, **kwargs):
+        super().__init__()
+        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+                                     ResnetBlock(in_channels=in_channels,
+                                                 out_channels=2 * in_channels,
+                                                 temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=2 * in_channels,
+                                                out_channels=4 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=4 * in_channels,
+                                                out_channels=2 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     nn.Conv2d(2*in_channels, in_channels, 1),
+                                     Upsample(in_channels, with_conv=True)])
+        # end
+        self.norm_out = Normalize(in_channels)
+        self.conv_out = torch.nn.Conv2d(in_channels,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        for i, layer in enumerate(self.model):
+            if i in [1,2,3]:
+                x = layer(x, None)
+            else:
+                x = layer(x)
+
+        h = self.norm_out(x)
+        h = nonlinearity(h)
+        x = self.conv_out(h)
+        return x
+
+
+class UpsampleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+                 ch_mult=(2,2), dropout=0.0):
+        super().__init__()
+        # upsampling
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = in_channels
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.res_blocks = nn.ModuleList()
+        self.upsample_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            res_block = []
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                res_block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+            self.res_blocks.append(nn.ModuleList(res_block))
+            if i_level != self.num_resolutions - 1:
+                self.upsample_blocks.append(Upsample(block_in, True))
+                curr_res = curr_res * 2
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        # upsampling
+        h = x
+        for k, i_level in enumerate(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.res_blocks[i_level][i_block](h, None)
+            if i_level != self.num_resolutions - 1:
+                h = self.upsample_blocks[k](h)
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class LatentRescaler(nn.Module):
+    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+        super().__init__()
+        # residual block, interpolate, residual block
+        self.factor = factor
+        self.conv_in = nn.Conv2d(in_channels,
+                                 mid_channels,
+                                 kernel_size=3,
+                                 stride=1,
+                                 padding=1)
+        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                                     out_channels=mid_channels,
+                                                     temb_channels=0,
+                                                     dropout=0.0) for _ in range(depth)])
+        self.attn = AttnBlock(mid_channels)
+        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                                     out_channels=mid_channels,
+                                                     temb_channels=0,
+                                                     dropout=0.0) for _ in range(depth)])
+
+        self.conv_out = nn.Conv2d(mid_channels,
+                                  out_channels,
+                                  kernel_size=1,
+                                  )
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        for block in self.res_block1:
+            x = block(x, None)
+        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+        x = self.attn(x)
+        for block in self.res_block2:
+            x = block(x, None)
+        x = self.conv_out(x)
+        return x
+
+
+class MergedRescaleEncoder(nn.Module):
+    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
+                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        intermediate_chn = ch * ch_mult[-1]
+        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+                               z_channels=intermediate_chn, double_z=False, resolution=resolution,
+                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+                               out_ch=None)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.rescaler(x)
+        return x
+
+
+class MergedRescaleDecoder(nn.Module):
+    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        tmp_chn = z_channels*ch_mult[-1]
+        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+                               ch_mult=ch_mult, resolution=resolution, ch=ch)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+                                       out_channels=tmp_chn, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Upsampler(nn.Module):
+    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+        super().__init__()
+        assert out_size >= in_size
+        num_blocks = int(np.log2(out_size//in_size))+1
+        factor_up = 1.+ (out_size % in_size)
+        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+                                       out_channels=in_channels)
+        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+                               attn_resolutions=[], in_channels=None, ch=in_channels,
+                               ch_mult=[ch_mult for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Resize(nn.Module):
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+            raise NotImplementedError()
+            assert in_channels is not None
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=4,
+                                        stride=2,
+                                        padding=1)
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor==1.0:
+            return x
+        else:
+            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+        return x
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df6b5abfe8eff07f0c8e8703ba8aee90d45984b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,786 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+    checkpoint,
+    conv_nd,
+    linear,
+    avg_pool_nd,
+    zero_module,
+    normalization,
+    timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+    pass
+
+def convert_module_to_f32(x):
+    pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+    """
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+
+    def forward(self, x):
+        b, c, *_spatial = x.shape
+        x = x.reshape(b, c, -1)  # NC(HW)
+        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = self.qkv_proj(x)
+        x = self.attention(x)
+        x = self.c_proj(x)
+        return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+    """
+    Any module where forward() takes timestep embeddings as a second argument.
+    """
+
+    @abstractmethod
+    def forward(self, x, emb):
+        """
+        Apply the module to `x` given `emb` timestep embeddings.
+        """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    """
+    A sequential module that passes timestep embeddings to the children that
+    support it as an extra input.
+    """
+
+    def forward(self, x, emb, context=None):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            elif isinstance(layer, SpatialTransformer):
+                x = layer(x, context)
+            else:
+                x = layer(x)
+        return x
+
+
+class Upsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(
+                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+            )
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+class TransposedUpsample(nn.Module):
+    'Learned 2x upsampling without padding'
+    def __init__(self, channels, out_channels=None, ks=5):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+
+        self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+    def forward(self,x):
+        return self.up(x)
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
+        if use_conv:
+            self.op = conv_nd(
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+    """
+    A residual block that can optionally change the number of channels.
+    :param channels: the number of input channels.
+    :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout.
+    :param out_channels: if specified, the number of out channels.
+    :param use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the
+        channels in the skip connection.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param use_checkpoint: if True, use gradient checkpointing on this module.
+    :param up: if True, use this block for upsampling.
+    :param down: if True, use this block for downsampling.
+    """
+
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        use_conv=False,
+        use_scale_shift_norm=False,
+        dims=2,
+        use_checkpoint=False,
+        up=False,
+        down=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_checkpoint = use_checkpoint
+        self.use_scale_shift_norm = use_scale_shift_norm
+
+        self.in_layers = nn.Sequential(
+            normalization(channels),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+
+        self.updown = up or down
+
+        if up:
+            self.h_upd = Upsample(channels, False, dims)
+            self.x_upd = Upsample(channels, False, dims)
+        elif down:
+            self.h_upd = Downsample(channels, False, dims)
+            self.x_upd = Downsample(channels, False, dims)
+        else:
+            self.h_upd = self.x_upd = nn.Identity()
+
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(
+                emb_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+            ),
+        )
+        self.out_layers = nn.Sequential(
+            normalization(self.out_channels),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        elif use_conv:
+            self.skip_connection = conv_nd(
+                dims, channels, self.out_channels, 3, padding=1
+            )
+        else:
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+    def forward(self, x, emb):
+        """
+        Apply the block to a Tensor, conditioned on a timestep embedding.
+        :param x: an [N x C x ...] Tensor of features.
+        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        return checkpoint(
+            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+        )
+
+
+    def _forward(self, x, emb):
+        if self.updown:
+            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+            h = in_rest(x)
+            h = self.h_upd(h)
+            x = self.x_upd(x)
+            h = in_conv(h)
+        else:
+            h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = th.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_checkpoint=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.use_checkpoint = use_checkpoint
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, x):
+        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+        #return pt_checkpoint(self._forward, x)  # pytorch
+
+    def _forward(self, x):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1)
+        qkv = self.qkv(self.norm(x))
+        h = self.attention(qkv)
+        h = self.proj_out(h)
+        return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+    """
+    A counter for the `thop` package to count the operations in an
+    attention operation.
+    Meant to be used like:
+        macs, params = thop.profile(
+            model,
+            inputs=(inputs, timestamps),
+            custom_ops={QKVAttention: QKVAttention.count_flops},
+        )
+    """
+    b, c, *spatial = y[0].shape
+    num_spatial = int(np.prod(spatial))
+    # We perform two matmuls with the same number of ops.
+    # The first computes the weight matrix, the second computes
+    # the combination of the value vectors.
+    matmul_ops = 2 * b * (num_spatial ** 2) * c
+    model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v)
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+    """
+    The full UNet model with attention and timestep embedding.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_classes: if specified (as an int), then this model will be
+        class-conditional with `num_classes` classes.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_heads_channels: if specified, ignore num_heads and instead use
+                               a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+                               of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+                                    increased efficiency.
+    """
+
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        num_classes=None,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=False,    # custom transformer support
+        transformer_depth=1,              # custom transformer support
+        context_dim=None,                 # custom transformer support
+        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+        if context_dim is not None:
+            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            from omegaconf.listconfig import ListConfig
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            else:
+                raise ValueError()
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        #num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            #num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                            use_checkpoint=use_checkpoint
+                        ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(self.num_res_blocks[level] + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=model_channels * mult,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        #num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                if level and i == self.num_res_blocks[level]:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+        )
+        if self.predict_codebook_ids:
+            self.id_predictor = nn.Sequential(
+            normalization(ch),
+            conv_nd(dims, model_channels, n_embed, 1),
+            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
+        )
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via crossattn
+        :param y: an [N] Tensor of labels, if class-conditional.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context)
+            hs.append(h)
+        h = self.middle_block(h, emb, context)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb, context)
+        h = h.type(x.dtype)
+        if self.predict_codebook_ids:
+            return self.id_predictor(h)
+        else:
+            return self.out(h)
diff --git a/ldm/modules/diffusionmodules/upscaling.py b/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988
--- /dev/null
+++ b/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+    # for concatenating a downsampled image to the latent representation
+    def __init__(self, noise_schedule_config=None):
+        super(AbstractLowScaleModel, self).__init__()
+        if noise_schedule_config is not None:
+            self.register_schedule(**noise_schedule_config)
+
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def forward(self, x):
+        return x, None
+
+    def decode(self, x):
+        return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+    # no noise level conditioning
+    def __init__(self):
+        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+        self.max_noise_level = 0
+
+    def forward(self, x):
+        # fix to constant noise level
+        return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+        super().__init__(noise_schedule_config=noise_schedule_config)
+        self.max_noise_level = max_noise_level
+
+    def forward(self, x, noise_level=None):
+        if noise_level is None:
+            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        else:
+            assert isinstance(noise_level, torch.Tensor)
+        z = self.q_sample(x, noise_level)
+        return z, noise_level
+
+
+
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..637363dfe34799e70cfdbcd11445212df9d9ca1f
--- /dev/null
+++ b/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,270 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+    if schedule == "linear":
+        betas = (
+                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+        )
+
+    elif schedule == "cosine":
+        timesteps = (
+                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+        )
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = torch.cos(alphas).pow(2)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        betas = np.clip(betas, a_min=0, a_max=0.999)
+
+    elif schedule == "sqrt_linear":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+    elif schedule == "sqrt":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+    if ddim_discr_method == 'uniform':
+        c = num_ddpm_timesteps // num_ddim_timesteps
+        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+    elif ddim_discr_method == 'quad':
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+    else:
+        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+    # add one to get the final alpha values right (the ones from first scale to data during sampling)
+    steps_out = ddim_timesteps + 1
+    if verbose:
+        print(f'Selected timesteps for ddim sampler: {steps_out}')
+    return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+    # select alphas for computing the variance schedule
+    alphas = alphacums[ddim_timesteps]
+    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+    # according the the formula provided in https://arxiv.org/abs/2010.02502
+    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+    if verbose:
+        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+        print(f'For the chosen value of eta, which is {eta}, '
+              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+    return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+                   explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+                                   "dtype": torch.get_autocast_gpu_dtype(),
+                                   "cache_enabled": torch.is_autocast_cache_enabled()}
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with torch.enable_grad(), \
+                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = torch.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    else:
+        embedding = repeat(timesteps, 'b -> b d', d=dim)
+    return embedding
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+    """
+    Make a standard normalization layer.
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """
+    Create a linear module.
+    """
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+    def __init__(self, c_concat_config, c_crossattn_config):
+        super().__init__()
+        self.concat_conditioner = instantiate_from_config(c_concat_config)
+        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+    def forward(self, c_concat, c_crossattn):
+        c_concat = self.concat_conditioner(c_concat)
+        c_crossattn = self.crossattn_conditioner(c_crossattn)
+        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+    def sample(self):
+        raise NotImplementedError()
+
+    def mode(self):
+        raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+    def __init__(self, value):
+        self.value = value
+
+    def sample(self):
+        return self.value
+
+    def mode(self):
+        return self.value
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+    def sample(self):
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                       + self.var - 1.0 - self.logvar,
+                                       dim=[1, 2, 3])
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3])
+
+    def nll(self, sample, dims=[1,2,3]):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims)
+
+    def mode(self):
+        return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, torch.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
+    logvar1, logvar2 = [
+        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+        for x in (logvar1, logvar2)
+    ]
+
+    return 0.5 * (
+        -1.0
+        + logvar2
+        - logvar1
+        + torch.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+    )
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4
--- /dev/null
+++ b/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+    def __init__(self, model, decay=0.9999, use_num_upates=True):
+        super().__init__()
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError('Decay must be between 0 and 1')
+
+        self.m_name2s_name = {}
+        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+        else torch.tensor(-1, dtype=torch.int))
+
+        for name, p in model.named_parameters():
+            if p.requires_grad:
+                # remove as '.'-character is not allowed in buffers
+                s_name = name.replace('.', '')
+                self.m_name2s_name.update({name: s_name})
+                self.register_buffer(s_name, p.clone().detach().data)
+
+        self.collected_params = []
+
+    def reset_num_updates(self):
+        del self.num_updates
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+    def forward(self, model):
+        decay = self.decay
+
+        if self.num_updates >= 0:
+            self.num_updates += 1
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+        one_minus_decay = 1.0 - decay
+
+        with torch.no_grad():
+            m_param = dict(model.named_parameters())
+            shadow_params = dict(self.named_buffers())
+
+            for key in m_param:
+                if m_param[key].requires_grad:
+                    sname = self.m_name2s_name[key]
+                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+                else:
+                    assert not key in self.m_name2s_name
+
+    def copy_to(self, model):
+        m_param = dict(model.named_parameters())
+        shadow_params = dict(self.named_buffers())
+        for key in m_param:
+            if m_param[key].requires_grad:
+                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+            else:
+                assert not key in self.m_name2s_name
+
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f3638a8ec4062d6aa4cd7046f34be502f785a88
--- /dev/null
+++ b/ldm/modules/encoders/modules.py
@@ -0,0 +1,385 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoProcessor, CLIPVisionModelWithProjection
+
+import open_clip
+from ldm.util import count_params
+
+
+def _expand_mask(mask, dtype, tgt_len=None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+def _build_causal_attention_mask(bsz, seq_len, dtype):
+    # lazily create causal attention mask, with full attention between the vision tokens
+    # pytorch uses additive attention mask; fill with -inf
+    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
+    mask.fill_(torch.tensor(torch.finfo(dtype).min))
+    mask.triu_(1)  # zero out the lower diagonal
+    mask = mask.unsqueeze(1)  # expand mask
+    return mask
+
+class AbstractEncoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+    def encode(self, x):
+        return x
+
+
+class ClassEmbedder(nn.Module):
+    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+        super().__init__()
+        self.key = key
+        self.embedding = nn.Embedding(n_classes, embed_dim)
+        self.n_classes = n_classes
+        self.ucg_rate = ucg_rate
+
+    def forward(self, batch, key=None, disable_dropout=False):
+        if key is None:
+            key = self.key
+        # this is for use in crossattn
+        c = batch[key][:, None]
+        if self.ucg_rate > 0. and not disable_dropout:
+            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+            c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+            c = c.long()
+        c = self.embedding(c)
+        return c
+
+    def get_unconditional_conditioning(self, bs, device="cuda"):
+        uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+        uc = torch.ones((bs,), device=device) * uc_class
+        uc = {self.key: uc}
+        return uc
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+    """Uses the T5 transformer encoder for text"""
+    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+        super().__init__()
+        self.tokenizer = T5Tokenizer.from_pretrained(version)
+        self.transformer = T5EncoderModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length   # TODO: typical value?
+        if freeze:
+            self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        #self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens)
+
+        z = outputs.last_hidden_state
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from huggingface)"""
+    LAYERS = [
+        "last",
+        "pooled",
+        "hidden"
+    ]
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
+        super().__init__()
+        assert layer in self.LAYERS
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        self.layer_idx = layer_idx
+        if layer == "hidden":
+            assert layer_idx is not None
+            assert 0 <= abs(layer_idx) <= 12
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        # self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+        if self.layer == "last":
+            z = outputs.last_hidden_state
+        elif self.layer == "pooled":
+            z = outputs.pooler_output[:, None, :]
+        else:
+            z = outputs.hidden_states[self.layer_idx]
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP transformer encoder for text
+    """
+    LAYERS = [
+        # "pooled",
+        "last",
+        "penultimate"
+    ]
+
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+                 freeze=True, layer="last"):
+        super().__init__()
+        assert layer in self.LAYERS
+        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+        del model.visual
+        self.model = model
+
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == "last":
+            self.layer_idx = 0
+        elif self.layer == "penultimate":
+            self.layer_idx = 1
+        else:
+            raise NotImplementedError()
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        tokens = open_clip.tokenize(text)
+        z = self.encode_with_transformer(tokens.to(self.device))
+        return z
+
+    def encode_with_transformer(self, text):
+        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
+        x = x + self.model.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.model.ln_final(x)
+        return x
+
+    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+        for i, r in enumerate(self.model.transformer.resblocks):
+            if i == len(self.model.transformer.resblocks) - self.layer_idx:
+                break
+            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
+        return x
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+                 clip_max_length=77, t5_max_length=77):
+        super().__init__()
+        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+
+    def encode(self, text):
+        return self(text)
+
+    def forward(self, text):
+        clip_z = self.clip_encoder.encode(text)
+        t5_z = self.t5_encoder.encode(text)
+        return [clip_z, t5_z]
+
+
+class FrozenCLIPEmbedderT3(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, use_vision=False):
+        super().__init__()
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        if use_vision:
+            self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
+            self.processor = AutoProcessor.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+
+        def embedding_forward(
+            self,
+            input_ids=None,
+            position_ids=None,
+            inputs_embeds=None,
+            embedding_manager=None,
+        ):
+            seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+            if position_ids is None:
+                position_ids = self.position_ids[:, :seq_length]
+            if inputs_embeds is None:
+                inputs_embeds = self.token_embedding(input_ids)
+            if embedding_manager is not None:
+                inputs_embeds = embedding_manager(input_ids, inputs_embeds)
+            position_embeddings = self.position_embedding(position_ids)
+            embeddings = inputs_embeds + position_embeddings
+            return embeddings
+
+        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
+
+        def encoder_forward(
+            self,
+            inputs_embeds,
+            attention_mask=None,
+            causal_attention_mask=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+        ):
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = (
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            encoder_states = () if output_hidden_states else None
+            all_attentions = () if output_attentions else None
+            hidden_states = inputs_embeds
+            for idx, encoder_layer in enumerate(self.layers):
+                if output_hidden_states:
+                    encoder_states = encoder_states + (hidden_states,)
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions=output_attentions,
+                )
+                hidden_states = layer_outputs[0]
+                if output_attentions:
+                    all_attentions = all_attentions + (layer_outputs[1],)
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            return hidden_states
+
+        self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
+
+        def text_encoder_forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            position_ids=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            embedding_manager=None,
+        ):
+            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_hidden_states = (
+                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            )
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+            if input_ids is None:
+                raise ValueError("You have to specify either input_ids")
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
+            bsz, seq_len = input_shape
+            # CLIP's text model uses causal mask, prepare it here.
+            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+            causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
+                hidden_states.device
+            )
+            # expand attention_mask
+            if attention_mask is not None:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+            last_hidden_state = self.encoder(
+                inputs_embeds=hidden_states,
+                attention_mask=attention_mask,
+                causal_attention_mask=causal_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            last_hidden_state = self.final_layer_norm(last_hidden_state)
+            return last_hidden_state
+
+        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
+
+        def transformer_forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            position_ids=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            embedding_manager=None,
+        ):
+            return self.text_model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                embedding_manager=embedding_manager
+            )
+
+        self.transformer.forward = transformer_forward.__get__(self.transformer)
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text, **kwargs):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        z = self.transformer(input_ids=tokens, **kwargs)
+        return z
+
+    def encode(self, text, **kwargs):
+        return self(text, **kwargs)
diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+    '''
+    Args:
+        img: numpy image, WxH or WxHxC
+        sf: scale factor
+    Return:
+        cropped image
+    '''
+    w, h = img.shape[:2]
+    im = np.copy(img)
+    return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+    k_size = k.shape[0]
+    # Calculate the big kernels size
+    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+    # Loop over the small kernel to fill the big one
+    for r in range(k_size):
+        for c in range(k_size):
+            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+    crop = k_size // 2
+    cropped_big_k = big_k[crop:-crop, crop:-crop]
+    # Normalize to 1
+    return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+    """ generate an anisotropic Gaussian kernel
+    Args:
+        ksize : e.g., 15, kernel size
+        theta : [0,  pi], rotation angle range
+        l1    : [0.1,50], scaling of eigenvalues
+        l2    : [0.1,l1], scaling of eigenvalues
+        If l1 = l2, will get an isotropic Gaussian kernel.
+    Returns:
+        k     : kernel
+    """
+
+    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+    D = np.array([[l1, 0], [0, l2]])
+    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+    return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+    center = size / 2.0 + 0.5
+    k = np.zeros([size, size])
+    for y in range(size):
+        for x in range(size):
+            cy = y - center + 1
+            cx = x - center + 1
+            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+    k = k / np.sum(k)
+    return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+    """shift pixel for super-resolution with different scale factors
+    Args:
+        x: WxHxC or WxH
+        sf: scale factor
+        upper_left: shift direction
+    """
+    h, w = x.shape[:2]
+    shift = (sf - 1) * 0.5
+    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+    if upper_left:
+        x1 = xv + shift
+        y1 = yv + shift
+    else:
+        x1 = xv - shift
+        y1 = yv - shift
+
+    x1 = np.clip(x1, 0, w - 1)
+    y1 = np.clip(y1, 0, h - 1)
+
+    if x.ndim == 2:
+        x = interp2d(xv, yv, x)(x1, y1)
+    if x.ndim == 3:
+        for i in range(x.shape[-1]):
+            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+    return x
+
+
+def blur(x, k):
+    '''
+    x: image, NxcxHxW
+    k: kernel, Nx1xhxw
+    '''
+    n, c = x.shape[:2]
+    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+    k = k.repeat(1, c, 1, 1)
+    k = k.view(-1, 1, k.shape[2], k.shape[3])
+    x = x.view(1, -1, x.shape[2], x.shape[3])
+    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+    x = x.view(n, c, x.shape[2], x.shape[3])
+
+    return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """"
+    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+    # Kai Zhang
+    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
+    # max_var = 2.5 * sf
+    """
+    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+    theta = np.random.rand() * np.pi  # random theta
+    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+    # Set COV matrix using Lambdas and Theta
+    LAMBDA = np.diag([lambda_1, lambda_2])
+    Q = np.array([[np.cos(theta), -np.sin(theta)],
+                  [np.sin(theta), np.cos(theta)]])
+    SIGMA = Q @ LAMBDA @ Q.T
+    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+    # Set expectation position (shifting kernel for aligned image)
+    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
+    MU = MU[None, None, :, None]
+
+    # Create meshgrid for Gaussian
+    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+    Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calcualte Gaussian for every pixel of the kernel
+    ZZ = Z - MU
+    ZZ_t = ZZ.transpose(0, 1, 3, 2)
+    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+    # shift the kernel so it will be centered
+    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+    # Normalize the kernel and return
+    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+    kernel = raw_kernel / np.sum(raw_kernel)
+    return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+    hsize = [hsize, hsize]
+    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+    std = sigma
+    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+    arg = -(x * x + y * y) / (2 * std * std)
+    h = np.exp(arg)
+    h[h < scipy.finfo(float).eps * h.max()] = 0
+    sumh = h.sum()
+    if sumh != 0:
+        h = h / sumh
+    return h
+
+
+def fspecial_laplacian(alpha):
+    alpha = max([0, min([alpha, 1])])
+    h1 = alpha / (alpha + 1)
+    h2 = (1 - alpha) / (alpha + 1)
+    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+    h = np.array(h)
+    return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+    '''
+    python code from:
+    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+    '''
+    if filter_type == 'gaussian':
+        return fspecial_gaussian(*args, **kwargs)
+    if filter_type == 'laplacian':
+        return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+    '''
+    Args:
+        x: HxWxC image, [0, 1]
+        sf: down-scale factor
+    Return:
+        bicubicly downsampled LR image
+    '''
+    x = util.imresize_np(x, scale=1 / sf)
+    return x
+
+
+def srmd_degradation(x, k, sf=3):
+    ''' blur + bicubic downsampling
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2018learning,
+          title={Learning a single convolutional super-resolution network for multiple degradations},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={3262--3271},
+          year={2018}
+        }
+    '''
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+    x = bicubic_degradation(x, sf=sf)
+    return x
+
+
+def dpsr_degradation(x, k, sf=3):
+    ''' bicubic downsampling + blur
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2019deep,
+          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={1671--1681},
+          year={2019}
+        }
+    '''
+    x = bicubic_degradation(x, sf=sf)
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    return x
+
+
+def classical_degradation(x, k, sf=3):
+    ''' blur + downsampling
+    Args:
+        x: HxWxC image, [0, 1]/[0, 255]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    '''
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+    st = 0
+    return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening. borrowed from real-ESRGAN
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur mask:
+    4. Out = Mask * K + (1 - Mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 1.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int):
+    """
+    if radius % 2 == 0:
+        radius += 1
+    blur = cv2.GaussianBlur(img, (radius, radius), 0)
+    residual = img - blur
+    mask = np.abs(residual) * 255 > threshold
+    mask = mask.astype('float32')
+    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+    K = img + weight * residual
+    K = np.clip(K, 0, 1)
+    return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+    wd2 = 4.0 + sf
+    wd = 2.0 + 0.2 * sf
+    if random.random() < 0.5:
+        l1 = wd2 * random.random()
+        l2 = wd2 * random.random()
+        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+    else:
+        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+    return img
+
+
+def add_resize(img, sf=4):
+    rnum = np.random.rand()
+    if rnum > 0.8:  # up
+        sf1 = random.uniform(1, 2)
+    elif rnum < 0.7:  # down
+        sf1 = random.uniform(0.5 / sf, 1)
+    else:
+        sf1 = 1.0
+    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+    img = np.clip(img, 0.0, 1.0)
+
+    return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+#     noise_level = random.randint(noise_level1, noise_level2)
+#     rnum = np.random.rand()
+#     if rnum > 0.6:  # add color Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+#     elif rnum < 0.4:  # add grayscale Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+#     else:  # add  noise
+#         L = noise_level2 / 255.
+#         D = np.diag(np.random.rand(3))
+#         U = orth(np.random.rand(3, 3))
+#         conv = np.dot(np.dot(np.transpose(U), D), U)
+#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+#     img = np.clip(img, 0.0, 1.0)
+#     return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    rnum = np.random.rand()
+    if rnum > 0.6:  # add color Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:  # add grayscale Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:  # add  noise
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    img = np.clip(img, 0.0, 1.0)
+    rnum = random.random()
+    if rnum > 0.6:
+        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:
+        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_Poisson_noise(img):
+    img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
+    if random.random() < 0.5:
+        img = np.random.poisson(img * vals).astype(np.float32) / vals
+    else:
+        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+        img += noise_gray[:, :, np.newaxis]
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_JPEG_noise(img):
+    quality_factor = random.randint(30, 95)
+    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+    img = cv2.imdecode(encimg, 1)
+    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+    return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+    h, w = lq.shape[:2]
+    rnd_h = random.randint(0, h - lq_patchsize)
+    rnd_w = random.randint(0, w - lq_patchsize)
+    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+    return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    hq = img.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+                             interpolation=random.choice([1, 2, 3]))
+        else:
+            img = util.imresize_np(img, 1 / 2, True)
+        img = np.clip(img, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            img = add_blur(img, sf=sf)
+
+        elif i == 1:
+            img = add_blur(img, sf=sf)
+
+        elif i == 2:
+            a, b = img.shape[1], img.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+                                 interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                img = img[0::sf, 0::sf, ...]  # nearest downsampling
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                img = add_JPEG_noise(img)
+
+        elif i == 6:
+            # add processed camera sensor noise
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+    return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+    image = util.uint2single(image)
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = image.shape[:2]
+    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
+    h, w = image.shape[:2]
+
+    hq = image.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+                               interpolation=random.choice([1, 2, 3]))
+        else:
+            image = util.imresize_np(image, 1 / 2, True)
+        image = np.clip(image, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            image = add_blur(image, sf=sf)
+
+        elif i == 1:
+            image = add_blur(image, sf=sf)
+
+        elif i == 2:
+            a, b = image.shape[1], image.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+                                   interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                image = image[0::sf, 0::sf, ...]  # nearest downsampling
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                image = add_JPEG_noise(image)
+
+        # elif i == 6:
+        #     # add processed camera sensor noise
+        #     if random.random() < isp_prob and isp_model is not None:
+        #         with torch.no_grad():
+        #             img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    image = add_JPEG_noise(image)
+    image = util.single2uint(image)
+    example = {"image":image}
+    return example
+
+
+# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+    """
+    This is an extended degradation model by combining
+    the degradation models of BSRGAN and Real-ESRGAN
+    ----------
+    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+    sf: scale factor
+    use_shuffle: the degradation shuffle
+    use_sharp: sharpening the img
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    if use_sharp:
+        img = add_sharpening(img)
+    hq = img.copy()
+
+    if random.random() < shuffle_prob:
+        shuffle_order = random.sample(range(13), 13)
+    else:
+        shuffle_order = list(range(13))
+        # local shuffle for noise, JPEG is always the last one
+        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+    for i in shuffle_order:
+        if i == 0:
+            img = add_blur(img, sf=sf)
+        elif i == 1:
+            img = add_resize(img, sf=sf)
+        elif i == 2:
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+        elif i == 3:
+            if random.random() < poisson_prob:
+                img = add_Poisson_noise(img)
+        elif i == 4:
+            if random.random() < speckle_prob:
+                img = add_speckle_noise(img)
+        elif i == 5:
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+        elif i == 6:
+            img = add_JPEG_noise(img)
+        elif i == 7:
+            img = add_blur(img, sf=sf)
+        elif i == 8:
+            img = add_resize(img, sf=sf)
+        elif i == 9:
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+        elif i == 10:
+            if random.random() < poisson_prob:
+                img = add_Poisson_noise(img)
+        elif i == 11:
+            if random.random() < speckle_prob:
+                img = add_speckle_noise(img)
+        elif i == 12:
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+        else:
+            print('check the shuffle!')
+
+    # resize to desired size
+    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+                     interpolation=random.choice([1, 2, 3]))
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+    return img, hq
+
+
+if __name__ == '__main__':
+	print("hey")
+	img = util.imread_uint('utils/test.png', 3)
+	print(img)
+	img = util.uint2single(img)
+	print(img)
+	img = img[:448, :448]
+	h = img.shape[0] // 4
+	print("resizing to", h)
+	sf = 4
+	deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+	for i in range(20):
+		print(i)
+		img_lq = deg_fn(img)
+		print(img_lq)
+		img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
+		print(img_lq.shape)
+		print("bicubic", img_lq_bicubic.shape)
+		print(img_hq.shape)
+		lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+		                        interpolation=0)
+		lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+		                        interpolation=0)
+		img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+		util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+    '''
+    Args:
+        img: numpy image, WxH or WxHxC
+        sf: scale factor
+    Return:
+        cropped image
+    '''
+    w, h = img.shape[:2]
+    im = np.copy(img)
+    return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+    k_size = k.shape[0]
+    # Calculate the big kernels size
+    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+    # Loop over the small kernel to fill the big one
+    for r in range(k_size):
+        for c in range(k_size):
+            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+    crop = k_size // 2
+    cropped_big_k = big_k[crop:-crop, crop:-crop]
+    # Normalize to 1
+    return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+    """ generate an anisotropic Gaussian kernel
+    Args:
+        ksize : e.g., 15, kernel size
+        theta : [0,  pi], rotation angle range
+        l1    : [0.1,50], scaling of eigenvalues
+        l2    : [0.1,l1], scaling of eigenvalues
+        If l1 = l2, will get an isotropic Gaussian kernel.
+    Returns:
+        k     : kernel
+    """
+
+    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+    D = np.array([[l1, 0], [0, l2]])
+    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+    return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+    center = size / 2.0 + 0.5
+    k = np.zeros([size, size])
+    for y in range(size):
+        for x in range(size):
+            cy = y - center + 1
+            cx = x - center + 1
+            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+    k = k / np.sum(k)
+    return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+    """shift pixel for super-resolution with different scale factors
+    Args:
+        x: WxHxC or WxH
+        sf: scale factor
+        upper_left: shift direction
+    """
+    h, w = x.shape[:2]
+    shift = (sf - 1) * 0.5
+    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+    if upper_left:
+        x1 = xv + shift
+        y1 = yv + shift
+    else:
+        x1 = xv - shift
+        y1 = yv - shift
+
+    x1 = np.clip(x1, 0, w - 1)
+    y1 = np.clip(y1, 0, h - 1)
+
+    if x.ndim == 2:
+        x = interp2d(xv, yv, x)(x1, y1)
+    if x.ndim == 3:
+        for i in range(x.shape[-1]):
+            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+    return x
+
+
+def blur(x, k):
+    '''
+    x: image, NxcxHxW
+    k: kernel, Nx1xhxw
+    '''
+    n, c = x.shape[:2]
+    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+    k = k.repeat(1, c, 1, 1)
+    k = k.view(-1, 1, k.shape[2], k.shape[3])
+    x = x.view(1, -1, x.shape[2], x.shape[3])
+    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+    x = x.view(n, c, x.shape[2], x.shape[3])
+
+    return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """"
+    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+    # Kai Zhang
+    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
+    # max_var = 2.5 * sf
+    """
+    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+    theta = np.random.rand() * np.pi  # random theta
+    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+    # Set COV matrix using Lambdas and Theta
+    LAMBDA = np.diag([lambda_1, lambda_2])
+    Q = np.array([[np.cos(theta), -np.sin(theta)],
+                  [np.sin(theta), np.cos(theta)]])
+    SIGMA = Q @ LAMBDA @ Q.T
+    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+    # Set expectation position (shifting kernel for aligned image)
+    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
+    MU = MU[None, None, :, None]
+
+    # Create meshgrid for Gaussian
+    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+    Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calcualte Gaussian for every pixel of the kernel
+    ZZ = Z - MU
+    ZZ_t = ZZ.transpose(0, 1, 3, 2)
+    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+    # shift the kernel so it will be centered
+    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+    # Normalize the kernel and return
+    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+    kernel = raw_kernel / np.sum(raw_kernel)
+    return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+    hsize = [hsize, hsize]
+    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+    std = sigma
+    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+    arg = -(x * x + y * y) / (2 * std * std)
+    h = np.exp(arg)
+    h[h < scipy.finfo(float).eps * h.max()] = 0
+    sumh = h.sum()
+    if sumh != 0:
+        h = h / sumh
+    return h
+
+
+def fspecial_laplacian(alpha):
+    alpha = max([0, min([alpha, 1])])
+    h1 = alpha / (alpha + 1)
+    h2 = (1 - alpha) / (alpha + 1)
+    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+    h = np.array(h)
+    return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+    '''
+    python code from:
+    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+    '''
+    if filter_type == 'gaussian':
+        return fspecial_gaussian(*args, **kwargs)
+    if filter_type == 'laplacian':
+        return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+    '''
+    Args:
+        x: HxWxC image, [0, 1]
+        sf: down-scale factor
+    Return:
+        bicubicly downsampled LR image
+    '''
+    x = util.imresize_np(x, scale=1 / sf)
+    return x
+
+
+def srmd_degradation(x, k, sf=3):
+    ''' blur + bicubic downsampling
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2018learning,
+          title={Learning a single convolutional super-resolution network for multiple degradations},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={3262--3271},
+          year={2018}
+        }
+    '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+    x = bicubic_degradation(x, sf=sf)
+    return x
+
+
+def dpsr_degradation(x, k, sf=3):
+    ''' bicubic downsampling + blur
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2019deep,
+          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={1671--1681},
+          year={2019}
+        }
+    '''
+    x = bicubic_degradation(x, sf=sf)
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    return x
+
+
+def classical_degradation(x, k, sf=3):
+    ''' blur + downsampling
+    Args:
+        x: HxWxC image, [0, 1]/[0, 255]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+    st = 0
+    return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening. borrowed from real-ESRGAN
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur mask:
+    4. Out = Mask * K + (1 - Mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 1.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int):
+    """
+    if radius % 2 == 0:
+        radius += 1
+    blur = cv2.GaussianBlur(img, (radius, radius), 0)
+    residual = img - blur
+    mask = np.abs(residual) * 255 > threshold
+    mask = mask.astype('float32')
+    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+    K = img + weight * residual
+    K = np.clip(K, 0, 1)
+    return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+    wd2 = 4.0 + sf
+    wd = 2.0 + 0.2 * sf
+
+    wd2 = wd2/4
+    wd = wd/4
+
+    if random.random() < 0.5:
+        l1 = wd2 * random.random()
+        l2 = wd2 * random.random()
+        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+    else:
+        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+    return img
+
+
+def add_resize(img, sf=4):
+    rnum = np.random.rand()
+    if rnum > 0.8:  # up
+        sf1 = random.uniform(1, 2)
+    elif rnum < 0.7:  # down
+        sf1 = random.uniform(0.5 / sf, 1)
+    else:
+        sf1 = 1.0
+    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+    img = np.clip(img, 0.0, 1.0)
+
+    return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+#     noise_level = random.randint(noise_level1, noise_level2)
+#     rnum = np.random.rand()
+#     if rnum > 0.6:  # add color Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+#     elif rnum < 0.4:  # add grayscale Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+#     else:  # add  noise
+#         L = noise_level2 / 255.
+#         D = np.diag(np.random.rand(3))
+#         U = orth(np.random.rand(3, 3))
+#         conv = np.dot(np.dot(np.transpose(U), D), U)
+#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+#     img = np.clip(img, 0.0, 1.0)
+#     return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    rnum = np.random.rand()
+    if rnum > 0.6:  # add color Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:  # add grayscale Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:  # add  noise
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    img = np.clip(img, 0.0, 1.0)
+    rnum = random.random()
+    if rnum > 0.6:
+        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:
+        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_Poisson_noise(img):
+    img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
+    if random.random() < 0.5:
+        img = np.random.poisson(img * vals).astype(np.float32) / vals
+    else:
+        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+        img += noise_gray[:, :, np.newaxis]
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_JPEG_noise(img):
+    quality_factor = random.randint(80, 95)
+    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+    img = cv2.imdecode(encimg, 1)
+    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+    return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+    h, w = lq.shape[:2]
+    rnd_h = random.randint(0, h - lq_patchsize)
+    rnd_w = random.randint(0, w - lq_patchsize)
+    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+    return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    hq = img.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+                             interpolation=random.choice([1, 2, 3]))
+        else:
+            img = util.imresize_np(img, 1 / 2, True)
+        img = np.clip(img, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            img = add_blur(img, sf=sf)
+
+        elif i == 1:
+            img = add_blur(img, sf=sf)
+
+        elif i == 2:
+            a, b = img.shape[1], img.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+                                 interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                img = img[0::sf, 0::sf, ...]  # nearest downsampling
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                img = add_JPEG_noise(img)
+
+        elif i == 6:
+            # add processed camera sensor noise
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+    return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+    image = util.uint2single(image)
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = image.shape[:2]
+    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
+    h, w = image.shape[:2]
+
+    hq = image.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+                               interpolation=random.choice([1, 2, 3]))
+        else:
+            image = util.imresize_np(image, 1 / 2, True)
+        image = np.clip(image, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            image = add_blur(image, sf=sf)
+
+        # elif i == 1:
+        #     image = add_blur(image, sf=sf)
+
+        if i == 0:
+            pass
+
+        elif i == 2:
+            a, b = image.shape[1], image.shape[0]
+            # downsample2
+            if random.random() < 0.8:
+                sf1 = random.uniform(1, 2 * sf)
+                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+                                   interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                image = image[0::sf, 0::sf, ...]  # nearest downsampling
+
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                image = add_JPEG_noise(image)
+        #
+        # elif i == 6:
+        #     # add processed camera sensor noise
+        #     if random.random() < isp_prob and isp_model is not None:
+        #         with torch.no_grad():
+        #             img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    image = add_JPEG_noise(image)
+    image = util.single2uint(image)
+    if up:
+        image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC)  # todo: random, as above? want to condition on it then
+    example = {"image": image}
+    return example
+
+
+
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ
diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt   # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+    return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+    plt.figure(figsize=figsize)
+    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+    if title:
+        plt.title(title)
+    if cbar:
+        plt.colorbar()
+    plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+    plt.figure(figsize=figsize)
+    ax3 = plt.axes(projection='3d')
+
+    w, h = Z.shape[:2]
+    xx = np.arange(0,w,1)
+    yy = np.arange(0,h,1)
+    X, Y = np.meshgrid(xx, yy)
+    ax3.plot_surface(X,Y,Z,cmap=cmap)
+    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+    plt.show()
+
+
+'''
+# --------------------------------------------
+# get image pathes
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+    paths = None  # return None if dataroot is None
+    if dataroot is not None:
+        paths = sorted(_get_paths_from_images(dataroot))
+    return paths
+
+
+def _get_paths_from_images(path):
+    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+    images = []
+    for dirpath, _, fnames in sorted(os.walk(path)):
+        for fname in sorted(fnames):
+            if is_image_file(fname):
+                img_path = os.path.join(dirpath, fname)
+                images.append(img_path)
+    assert images, '{:s} has no valid image file'.format(path)
+    return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images 
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+    w, h = img.shape[:2]
+    patches = []
+    if w > p_max and h > p_max:
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
+        w1.append(w-p_size)
+        h1.append(h-p_size)
+#        print(w1)
+#        print(h1)
+        for i in w1:
+            for j in h1:
+                patches.append(img[i:i+p_size, j:j+p_size,:])
+    else:
+        patches.append(img)
+
+    return patches
+
+
+def imssave(imgs, img_path):
+    """
+    imgs: list, N images of size WxHxC
+    """
+    img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+    for i, img in enumerate(imgs):
+        if img.ndim == 3:
+            img = img[:, :, [2, 1, 0]]
+        new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+        cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+    """
+    split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
+    and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
+    will be splitted.
+    Args:
+        original_dataroot:
+        taget_dataroot:
+        p_size: size of small images
+        p_overlap: patch size in training is a good choice
+        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
+    """
+    paths = get_image_paths(original_dataroot)
+    for img_path in paths:
+        # img_name, ext = os.path.splitext(os.path.basename(img_path))
+        img = imread_uint(img_path, n_channels=n_channels)
+        patches = patches_from_image(img, p_size, p_overlap, p_max)
+        imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+        #if original_dataroot == taget_dataroot:
+        #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def mkdirs(paths):
+    if isinstance(paths, str):
+        mkdir(paths)
+    else:
+        for path in paths:
+            mkdir(path)
+
+
+def mkdir_and_rename(path):
+    if os.path.exists(path):
+        new_name = path + '_archived_' + get_timestamp()
+        print('Path already exists. Rename it to [{:s}]'.format(new_name))
+        os.rename(path, new_name)
+    os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channles (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+    #  input: path
+    # output: HxWx3(RGB or GGG), or HxWx1 (G)
+    if n_channels == 1:
+        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
+        img = np.expand_dims(img, axis=2)  # HxWx1
+    elif n_channels == 3:
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
+        if img.ndim == 2:
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
+        else:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
+    return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+    img = np.squeeze(img)
+    if img.ndim == 3:
+        img = img[:, :, [2, 1, 0]]
+    cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+    img = np.squeeze(img)
+    if img.ndim == 3:
+        img = img[:, :, [2, 1, 0]]
+    cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channles (BGR)
+# --------------------------------------------
+def read_img(path):
+    # read image by cv2
+    # return: Numpy float32, HWC, BGR, [0,1]
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    # some images have 4 channels
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+    return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <--->  numpy(unit)
+# numpy(single) <--->  tensor
+# numpy(unit)   <--->  tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <--->  numpy(unit)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+    return np.float32(img/255.)
+
+
+def single2uint(img):
+
+    return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+    return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+    return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(unit) (HxWxC or HxW) <--->  tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <--->  tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+    img = img.data.squeeze().float().cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+
+    return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+    img = img.data.squeeze().float().cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    elif img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return img
+
+
+def single2tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+    '''
+    Converts a torch Tensor into an image Numpy array of BGR channel order
+    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+    '''
+    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
+    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
+    n_dim = tensor.dim()
+    if n_dim == 4:
+        n_img = len(tensor)
+        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 3:
+        img_np = tensor.numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 2:
+        img_np = tensor.numpy()
+    else:
+        raise TypeError(
+            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+    if out_type == np.uint8:
+        img_np = (img_np * 255.0).round()
+        # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
+    return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flipe and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augmet_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return np.flipud(np.rot90(img))
+    elif mode == 2:
+        return np.flipud(img)
+    elif mode == 3:
+        return np.rot90(img, k=3)
+    elif mode == 4:
+        return np.flipud(np.rot90(img, k=2))
+    elif mode == 5:
+        return np.rot90(img)
+    elif mode == 6:
+        return np.rot90(img, k=2)
+    elif mode == 7:
+        return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return img.rot90(1, [2, 3]).flip([2])
+    elif mode == 2:
+        return img.flip([2])
+    elif mode == 3:
+        return img.rot90(3, [2, 3])
+    elif mode == 4:
+        return img.rot90(2, [2, 3]).flip([2])
+    elif mode == 5:
+        return img.rot90(1, [2, 3])
+    elif mode == 6:
+        return img.rot90(2, [2, 3])
+    elif mode == 7:
+        return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    img_size = img.size()
+    img_np = img.data.cpu().numpy()
+    if len(img_size) == 3:
+        img_np = np.transpose(img_np, (1, 2, 0))
+    elif len(img_size) == 4:
+        img_np = np.transpose(img_np, (2, 3, 1, 0))
+    img_np = augment_img(img_np, mode=mode)
+    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+    if len(img_size) == 3:
+        img_tensor = img_tensor.permute(2, 0, 1)
+    elif len(img_size) == 4:
+        img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+    return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return img.transpose(1, 0, 2)
+    elif mode == 2:
+        return img[::-1, :, :]
+    elif mode == 3:
+        img = img[::-1, :, :]
+        img = img.transpose(1, 0, 2)
+        return img
+    elif mode == 4:
+        return img[:, ::-1, :]
+    elif mode == 5:
+        img = img[:, ::-1, :]
+        img = img.transpose(1, 0, 2)
+        return img
+    elif mode == 6:
+        img = img[:, ::-1, :]
+        img = img[::-1, :, :]
+        return img
+    elif mode == 7:
+        img = img[:, ::-1, :]
+        img = img[::-1, :, :]
+        img = img.transpose(1, 0, 2)
+        return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+    # horizontal flip OR rotate
+    hflip = hflip and random.random() < 0.5
+    vflip = rot and random.random() < 0.5
+    rot90 = rot and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:
+            img = img[:, ::-1, :]
+        if vflip:
+            img = img[::-1, :, :]
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+    # img_in: Numpy, HWC or HW
+    img = np.copy(img_in)
+    if img.ndim == 2:
+        H, W = img.shape
+        H_r, W_r = H % scale, W % scale
+        img = img[:H - H_r, :W - W_r]
+    elif img.ndim == 3:
+        H, W, C = img.shape
+        H_r, W_r = H % scale, W % scale
+        img = img[:H - H_r, :W - W_r, :]
+    else:
+        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+    return img
+
+
+def shave(img_in, border=0):
+    # img_in: Numpy, HWC or HW
+    img = np.copy(img_in)
+    h, w = img.shape[:2]
+    img = img[border:h-border, border:w-border]
+    return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+    '''same as matlab rgb2ycbcr
+    only_y: only return Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    if only_y:
+        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+    else:
+        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+    '''same as matlab ycbcr2rgb
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+    '''bgr version of rgb2ycbcr
+    only_y: only return Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    if only_y:
+        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+    else:
+        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+    # conversion among BGR, gray and y
+    if in_c == 3 and tar_type == 'gray':  # BGR to gray
+        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+        return [np.expand_dims(img, axis=2) for img in gray_list]
+    elif in_c == 3 and tar_type == 'y':  # BGR to y
+        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+        return [np.expand_dims(img, axis=2) for img in y_list]
+    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
+        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+    else:
+        return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+    # img1 and img2 have range [0, 255]
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+    '''calculate SSIM
+    the same outputs as MATLAB's
+    img1, img2: [0, 255]
+    '''
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    if img1.ndim == 2:
+        return ssim(img1, img2)
+    elif img1.ndim == 3:
+        if img1.shape[2] == 3:
+            ssims = []
+            for i in range(3):
+                ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+            return np.array(ssims).mean()
+        elif img1.shape[2] == 1:
+            return ssim(np.squeeze(img1), np.squeeze(img2))
+    else:
+        raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                                            (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+    absx = torch.abs(x)
+    absx2 = absx**2
+    absx3 = absx**3
+    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+    if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
+        kernel_width = kernel_width / scale
+
+    # Output-space coordinates
+    x = torch.linspace(1, out_length, out_length)
+
+    # Input-space coordinates. Calculate the inverse mapping such that 0.5
+    # in output space maps to 0.5 in input space, and 0.5+scale in output
+    # space maps to 1.5 in input space.
+    u = x / scale + 0.5 * (1 - 1 / scale)
+
+    # What is the left-most pixel that can be involved in the computation?
+    left = torch.floor(u - kernel_width / 2)
+
+    # What is the maximum number of pixels that can be involved in the
+    # computation?  Note: it's OK to use an extra pixel here; if the
+    # corresponding weights are all zero, it will be eliminated at the end
+    # of this function.
+    P = math.ceil(kernel_width) + 2
+
+    # The indices of the input pixels involved in computing the k-th output
+    # pixel are in row k of the indices matrix.
+    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+        1, P).expand(out_length, P)
+
+    # The weights used to compute the k-th output pixel are in row k of the
+    # weights matrix.
+    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+    # apply cubic kernel
+    if (scale < 1) and (antialiasing):
+        weights = scale * cubic(distance_to_center * scale)
+    else:
+        weights = cubic(distance_to_center)
+    # Normalize the weights matrix so that each row sums to 1.
+    weights_sum = torch.sum(weights, 1).view(out_length, 1)
+    weights = weights / weights_sum.expand(out_length, P)
+
+    # If a column in weights is all zero, get rid of it. only consider the first and last column.
+    weights_zero_tmp = torch.sum((weights == 0), 0)
+    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 1, P - 2)
+        weights = weights.narrow(1, 1, P - 2)
+    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 0, P - 2)
+        weights = weights.narrow(1, 0, P - 2)
+    weights = weights.contiguous()
+    indices = indices.contiguous()
+    sym_len_s = -indices.min() + 1
+    sym_len_e = indices.max() - in_length
+    indices = indices + sym_len_s - 1
+    return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+    # Now the scale should be the same for H and W
+    # input: img: pytorch tensor, CHW or HW [0,1]
+    # output: CHW or HW [0,1] w/o round
+    need_squeeze = True if img.dim() == 2 else False
+    if need_squeeze:
+        img.unsqueeze_(0)
+    in_C, in_H, in_W = img.size()
+    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # Return the desired dimension order for performing the resize.  The
+    # strategy is to perform the resize first along the dimension with the
+    # smallest scale factor.
+    # Now we do not support this.
+
+    # get weights and indices
+    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+        in_H, out_H, scale, kernel, kernel_width, antialiasing)
+    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+        in_W, out_W, scale, kernel, kernel_width, antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+    sym_patch = img[:, :sym_len_Hs, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+    sym_patch = img[:, -sym_len_He:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(in_C, out_H, in_W)
+    kernel_width = weights_H.size(1)
+    for i in range(out_H):
+        idx = int(indices_H[i][0])
+        for j in range(out_C):
+            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+    sym_patch = out_1[:, :, :sym_len_Ws]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, :, -sym_len_We:]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(in_C, out_H, out_W)
+    kernel_width = weights_W.size(1)
+    for i in range(out_W):
+        idx = int(indices_W[i][0])
+        for j in range(out_C):
+            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+    if need_squeeze:
+        out_2.squeeze_()
+    return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+    # Now the scale should be the same for H and W
+    # input: img: Numpy, HWC or HW [0,1]
+    # output: HWC or HW [0,1] w/o round
+    img = torch.from_numpy(img)
+    need_squeeze = True if img.dim() == 2 else False
+    if need_squeeze:
+        img.unsqueeze_(2)
+
+    in_H, in_W, in_C = img.size()
+    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # Return the desired dimension order for performing the resize.  The
+    # strategy is to perform the resize first along the dimension with the
+    # smallest scale factor.
+    # Now we do not support this.
+
+    # get weights and indices
+    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+        in_H, out_H, scale, kernel, kernel_width, antialiasing)
+    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+        in_W, out_W, scale, kernel, kernel_width, antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+    sym_patch = img[:sym_len_Hs, :, :]
+    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(0, inv_idx)
+    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+    sym_patch = img[-sym_len_He:, :, :]
+    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(0, inv_idx)
+    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(out_H, in_W, in_C)
+    kernel_width = weights_H.size(1)
+    for i in range(out_H):
+        idx = int(indices_H[i][0])
+        for j in range(out_C):
+            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+    sym_patch = out_1[:, :sym_len_Ws, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, -sym_len_We:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(out_H, out_W, in_C)
+    kernel_width = weights_W.size(1)
+    for i in range(out_W):
+        idx = int(indices_W[i][0])
+        for j in range(out_C):
+            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+    if need_squeeze:
+        out_2.squeeze_()
+
+    return out_2.numpy()
+
+
+if __name__ == '__main__':
+    print('---')
+#    img = imread_uint('test.bmp', 3)
+#    img = uint2single(img)
+#    img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/ldm/modules/midas/__init__.py b/ldm/modules/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/midas/api.py b/ldm/modules/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c
--- /dev/null
+++ b/ldm/modules/midas/api.py
@@ -0,0 +1,170 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
+from ldm.modules.midas.midas.midas_net import MidasNet
+from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
+from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+
+ISL_PATHS = {
+    "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
+    "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
+    "midas_v21": "",
+    "midas_v21_small": "",
+}
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def load_midas_transform(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load transform only
+    if model_type == "dpt_large":  # DPT-Large
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    elif model_type == "midas_v21_small":
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    else:
+        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return transform
+
+
+def load_model(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load network
+    model_path = ISL_PATHS[model_type]
+    if model_type == "dpt_large":  # DPT-Large
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitl16_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        model = MidasNet(model_path, non_negative=True)
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    elif model_type == "midas_v21_small":
+        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+                               non_negative=True, blocks={'expand': True})
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    else:
+        print(f"model_type '{model_type}' not implemented, use: --model_type large")
+        assert False
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+    MODEL_TYPES_TORCH_HUB = [
+        "DPT_Large",
+        "DPT_Hybrid",
+        "MiDaS_small"
+    ]
+    MODEL_TYPES_ISL = [
+        "dpt_large",
+        "dpt_hybrid",
+        "midas_v21",
+        "midas_v21_small",
+    ]
+
+    def __init__(self, model_type):
+        super().__init__()
+        assert (model_type in self.MODEL_TYPES_ISL)
+        model, _ = load_model(model_type)
+        self.model = model
+        self.model.train = disabled_train
+
+    def forward(self, x):
+        # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
+        # NOTE: we expect that the correct transform has been called during dataloading.
+        with torch.no_grad():
+            prediction = self.model(x)
+            prediction = torch.nn.functional.interpolate(
+                prediction.unsqueeze(1),
+                size=x.shape[2:],
+                mode="bicubic",
+                align_corners=False,
+            )
+        assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
+        return prediction
+
diff --git a/ldm/modules/midas/midas/__init__.py b/ldm/modules/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/midas/midas/base_model.py b/ldm/modules/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/ldm/modules/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = torch.load(path, map_location=torch.device('cpu'))
+
+        if "optimizer" in parameters:
+            parameters = parameters["model"]
+
+        self.load_state_dict(parameters)
diff --git a/ldm/modules/midas/midas/blocks.py b/ldm/modules/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/ldm/modules/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+    _make_pretrained_vitb_rn50_384,
+    _make_pretrained_vitl16_384,
+    _make_pretrained_vitb16_384,
+    forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+    if backbone == "vitl16_384":
+        pretrained = _make_pretrained_vitl16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [256, 512, 1024, 1024], features, groups=groups, expand=expand
+        )  # ViT-L/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb_rn50_384":
+        pretrained = _make_pretrained_vitb_rn50_384(
+            use_pretrained,
+            hooks=hooks,
+            use_vit_only=use_vit_only,
+            use_readout=use_readout,
+        )
+        scratch = _make_scratch(
+            [256, 512, 768, 768], features, groups=groups, expand=expand
+        )  # ViT-H/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb16_384":
+        pretrained = _make_pretrained_vitb16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [96, 192, 384, 768], features, groups=groups, expand=expand
+        )  # ViT-B/16 - 84.6% Top1 (backbone)
+    elif backbone == "resnext101_wsl":
+        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)     # efficientnet_lite3  
+    elif backbone == "efficientnet_lite3":
+        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3     
+    else:
+        print(f"Backbone '{backbone}' not implemented")
+        assert False
+        
+    return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+    scratch = nn.Module()
+
+    out_shape1 = out_shape
+    out_shape2 = out_shape
+    out_shape3 = out_shape
+    out_shape4 = out_shape
+    if expand==True:
+        out_shape1 = out_shape
+        out_shape2 = out_shape*2
+        out_shape3 = out_shape*4
+        out_shape4 = out_shape*8
+
+    scratch.layer1_rn = nn.Conv2d(
+        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer2_rn = nn.Conv2d(
+        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer3_rn = nn.Conv2d(
+        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer4_rn = nn.Conv2d(
+        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+
+    return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+    efficientnet = torch.hub.load(
+        "rwightman/gen-efficientnet-pytorch",
+        "tf_efficientnet_lite3",
+        pretrained=use_pretrained,
+        exportable=exportable
+    )
+    return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+    pretrained = nn.Module()
+
+    pretrained.layer1 = nn.Sequential(
+        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+    )
+    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+    return pretrained
+    
+
+def _make_resnet_backbone(resnet):
+    pretrained = nn.Module()
+    pretrained.layer1 = nn.Sequential(
+        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+    )
+
+    pretrained.layer2 = resnet.layer2
+    pretrained.layer3 = resnet.layer3
+    pretrained.layer4 = resnet.layer4
+
+    return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+    return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+    """Interpolation module.
+    """
+
+    def __init__(self, scale_factor, mode, align_corners=False):
+        """Init.
+
+        Args:
+            scale_factor (float): scaling
+            mode (str): interpolation mode
+        """
+        super(Interpolate, self).__init__()
+
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: interpolated data
+        """
+
+        x = self.interp(
+            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+        )
+
+        return x
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        out = self.relu(x)
+        out = self.conv1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock, self).__init__()
+
+        self.resConfUnit1 = ResidualConvUnit(features)
+        self.resConfUnit2 = ResidualConvUnit(features)
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            output += self.resConfUnit1(xs[1])
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=True
+        )
+
+        return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features, activation, bn):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.bn = bn
+
+        self.groups=1
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+        
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+
+        if self.bn==True:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+        self.activation = activation
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn==True:
+            out = self.bn1(out)
+       
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn==True:
+            out = self.bn2(out)
+
+        if self.groups > 1:
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+        # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock_custom, self).__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+
+        self.groups=1
+
+        self.expand = expand
+        out_features = features
+        if self.expand==True:
+            out_features = features//2
+        
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+        
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            res = self.resConfUnit1(xs[1])
+            output = self.skip_add.add(output, res)
+            # output += res
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+        )
+
+        output = self.out_conv(output)
+
+        return output
+
diff --git a/ldm/modules/midas/midas/dpt_depth.py b/ldm/modules/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/ldm/modules/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+    FeatureFusionBlock,
+    FeatureFusionBlock_custom,
+    Interpolate,
+    _make_encoder,
+    forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+    return FeatureFusionBlock_custom(
+        features,
+        nn.ReLU(False),
+        deconv=False,
+        bn=use_bn,
+        expand=False,
+        align_corners=True,
+    )
+
+
+class DPT(BaseModel):
+    def __init__(
+        self,
+        head,
+        features=256,
+        backbone="vitb_rn50_384",
+        readout="project",
+        channels_last=False,
+        use_bn=False,
+    ):
+
+        super(DPT, self).__init__()
+
+        self.channels_last = channels_last
+
+        hooks = {
+            "vitb_rn50_384": [0, 1, 8, 11],
+            "vitb16_384": [2, 5, 8, 11],
+            "vitl16_384": [5, 11, 17, 23],
+        }
+
+        # Instantiate backbone and reassemble blocks
+        self.pretrained, self.scratch = _make_encoder(
+            backbone,
+            features,
+            False, # Set to true of you want to train from scratch, uses ImageNet weights
+            groups=1,
+            expand=False,
+            exportable=False,
+            hooks=hooks[backbone],
+            use_readout=readout,
+        )
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        self.scratch.output_conv = head
+
+
+    def forward(self, x):
+        if self.channels_last == True:
+            x.contiguous(memory_format=torch.channels_last)
+
+        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return out
+
+
+class DPTDepthModel(DPT):
+    def __init__(self, path=None, non_negative=True, **kwargs):
+        features = kwargs["features"] if "features" in kwargs else 256
+
+        head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+
+        super().__init__(head, **kwargs)
+
+        if path is not None:
+           self.load(path)
+
+    def forward(self, x):
+        return super().forward(x).squeeze(dim=1)
+
diff --git a/ldm/modules/midas/midas/midas_net.py b/ldm/modules/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/ldm/modules/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=256, non_negative=True):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet, self).__init__()
+
+        use_pretrained = False if path is None else True
+
+        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
diff --git a/ldm/modules/midas/midas/midas_net_custom.py b/ldm/modules/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/ldm/modules/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+        blocks={'expand': True}):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet_small, self).__init__()
+
+        use_pretrained = False if path else True
+                
+        self.channels_last = channels_last
+        self.blocks = blocks
+        self.backbone = backbone
+
+        self.groups = 1
+
+        features1=features
+        features2=features
+        features3=features
+        features4=features
+        self.expand = False
+        if "expand" in self.blocks and self.blocks['expand'] == True:
+            self.expand = True
+            features1=features
+            features2=features*2
+            features3=features*4
+            features4=features*8
+
+        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+  
+        self.scratch.activation = nn.ReLU(False)    
+
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+        
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+            self.scratch.activation,
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+        
+        if path:
+            self.load(path)
+
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+        if self.channels_last==True:
+            print("self.channels_last = ", self.channels_last)
+            x.contiguous(memory_format=torch.channels_last)
+
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+        
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+        
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+    prev_previous_type = nn.Identity()
+    prev_previous_name = ''
+    previous_type = nn.Identity()
+    previous_name = ''
+    for name, module in m.named_modules():
+        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+            # print("FUSED ", prev_previous_name, previous_name, name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+            # print("FUSED ", prev_previous_name, previous_name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+        #    print("FUSED ", previous_name, name)
+        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+        prev_previous_type = previous_type
+        prev_previous_name = previous_name
+        previous_type = type(module)
+        previous_name = name
\ No newline at end of file
diff --git a/ldm/modules/midas/midas/transforms.py b/ldm/modules/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/ldm/modules/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+    Args:
+        sample (dict): sample
+        size (tuple): image size
+
+    Returns:
+        tuple: new size
+    """
+    shape = list(sample["disparity"].shape)
+
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        return sample
+
+    scale = [0, 0]
+    scale[0] = size[0] / shape[0]
+    scale[1] = size[1] / shape[1]
+
+    scale = max(scale)
+
+    shape[0] = math.ceil(scale * shape[0])
+    shape[1] = math.ceil(scale * shape[1])
+
+    # resize
+    sample["image"] = cv2.resize(
+        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+    )
+
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+    )
+    sample["mask"] = cv2.resize(
+        sample["mask"].astype(np.float32),
+        tuple(shape[::-1]),
+        interpolation=cv2.INTER_NEAREST,
+    )
+    sample["mask"] = sample["mask"].astype(bool)
+
+    return tuple(shape)
+
+
+class Resize(object):
+    """Resize sample to given size (width, height).
+    """
+
+    def __init__(
+        self,
+        width,
+        height,
+        resize_target=True,
+        keep_aspect_ratio=False,
+        ensure_multiple_of=1,
+        resize_method="lower_bound",
+        image_interpolation_method=cv2.INTER_AREA,
+    ):
+        """Init.
+
+        Args:
+            width (int): desired output width
+            height (int): desired output height
+            resize_target (bool, optional):
+                True: Resize the full sample (image, mask, target).
+                False: Resize image only.
+                Defaults to True.
+            keep_aspect_ratio (bool, optional):
+                True: Keep the aspect ratio of the input sample.
+                Output sample might not have the given width and height, and
+                resize behaviour depends on the parameter 'resize_method'.
+                Defaults to False.
+            ensure_multiple_of (int, optional):
+                Output width and height is constrained to be multiple of this parameter.
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as least as possible.  (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+        """
+        self.__width = width
+        self.__height = height
+
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if y < min_val:
+            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        return y
+
+    def get_size(self, width, height):
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+
+        if self.__keep_aspect_ratio:
+            if self.__resize_method == "lower_bound":
+                # scale such that output size is lower bound
+                if scale_width > scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "upper_bound":
+                # scale such that output size is upper bound
+                if scale_width < scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "minimal":
+                # scale as least as possbile
+                if abs(1 - scale_width) < abs(1 - scale_height):
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            else:
+                raise ValueError(
+                    f"resize_method {self.__resize_method} not implemented"
+                )
+
+        if self.__resize_method == "lower_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, min_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, min_val=self.__width
+            )
+        elif self.__resize_method == "upper_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, max_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, max_val=self.__width
+            )
+        elif self.__resize_method == "minimal":
+            new_height = self.constrain_to_multiple_of(scale_height * height)
+            new_width = self.constrain_to_multiple_of(scale_width * width)
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        return (new_width, new_height)
+
+    def __call__(self, sample):
+        width, height = self.get_size(
+            sample["image"].shape[1], sample["image"].shape[0]
+        )
+
+        # resize sample
+        sample["image"] = cv2.resize(
+            sample["image"],
+            (width, height),
+            interpolation=self.__image_interpolation_method,
+        )
+
+        if self.__resize_target:
+            if "disparity" in sample:
+                sample["disparity"] = cv2.resize(
+                    sample["disparity"],
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+
+            if "depth" in sample:
+                sample["depth"] = cv2.resize(
+                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+                )
+
+            sample["mask"] = cv2.resize(
+                sample["mask"].astype(np.float32),
+                (width, height),
+                interpolation=cv2.INTER_NEAREST,
+            )
+            sample["mask"] = sample["mask"].astype(bool)
+
+        return sample
+
+
+class NormalizeImage(object):
+    """Normlize image by given mean and std.
+    """
+
+    def __init__(self, mean, std):
+        self.__mean = mean
+        self.__std = std
+
+    def __call__(self, sample):
+        sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+        return sample
+
+
+class PrepareForNet(object):
+    """Prepare sample for usage as network input.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        image = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+        if "mask" in sample:
+            sample["mask"] = sample["mask"].astype(np.float32)
+            sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+        if "disparity" in sample:
+            disparity = sample["disparity"].astype(np.float32)
+            sample["disparity"] = np.ascontiguousarray(disparity)
+
+        if "depth" in sample:
+            depth = sample["depth"].astype(np.float32)
+            sample["depth"] = np.ascontiguousarray(depth)
+
+        return sample
diff --git a/ldm/modules/midas/midas/vit.py b/ldm/modules/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/ldm/modules/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+    def __init__(self, start_index=1):
+        super(Slice, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+    def __init__(self, start_index=1):
+        super(AddReadout, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        if self.start_index == 2:
+            readout = (x[:, 0] + x[:, 1]) / 2
+        else:
+            readout = x[:, 0]
+        return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+    def __init__(self, in_features, start_index=1):
+        super(ProjectReadout, self).__init__()
+        self.start_index = start_index
+
+        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+    def forward(self, x):
+        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+        features = torch.cat((x[:, self.start_index :], readout), -1)
+
+        return self.project(features)
+
+
+class Transpose(nn.Module):
+    def __init__(self, dim0, dim1):
+        super(Transpose, self).__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x):
+        x = x.transpose(self.dim0, self.dim1)
+        return x
+
+
+def forward_vit(pretrained, x):
+    b, c, h, w = x.shape
+
+    glob = pretrained.model.forward_flex(x)
+
+    layer_1 = pretrained.activations["1"]
+    layer_2 = pretrained.activations["2"]
+    layer_3 = pretrained.activations["3"]
+    layer_4 = pretrained.activations["4"]
+
+    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+    layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+    unflatten = nn.Sequential(
+        nn.Unflatten(
+            2,
+            torch.Size(
+                [
+                    h // pretrained.model.patch_size[1],
+                    w // pretrained.model.patch_size[0],
+                ]
+            ),
+        )
+    )
+
+    if layer_1.ndim == 3:
+        layer_1 = unflatten(layer_1)
+    if layer_2.ndim == 3:
+        layer_2 = unflatten(layer_2)
+    if layer_3.ndim == 3:
+        layer_3 = unflatten(layer_3)
+    if layer_4.ndim == 3:
+        layer_4 = unflatten(layer_4)
+
+    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+    return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+    posemb_tok, posemb_grid = (
+        posemb[:, : self.start_index],
+        posemb[0, self.start_index :],
+    )
+
+    gs_old = int(math.sqrt(len(posemb_grid)))
+
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+    return posemb
+
+
+def forward_flex(self, x):
+    b, c, h, w = x.shape
+
+    pos_embed = self._resize_pos_embed(
+        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+    )
+
+    B = x.shape[0]
+
+    if hasattr(self.patch_embed, "backbone"):
+        x = self.patch_embed.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+
+    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+    if getattr(self, "dist_token", None) is not None:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+    else:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+
+    x = x + pos_embed
+    x = self.pos_drop(x)
+
+    for blk in self.blocks:
+        x = blk(x)
+
+    x = self.norm(x)
+
+    return x
+
+
+activations = {}
+
+
+def get_activation(name):
+    def hook(model, input, output):
+        activations[name] = output
+
+    return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+    if use_readout == "ignore":
+        readout_oper = [Slice(start_index)] * len(features)
+    elif use_readout == "add":
+        readout_oper = [AddReadout(start_index)] * len(features)
+    elif use_readout == "project":
+        readout_oper = [
+            ProjectReadout(vit_features, start_index) for out_feat in features
+        ]
+    else:
+        assert (
+            False
+        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+    return readout_oper
+
+
+def _make_vit_b16_backbone(
+    model,
+    features=[96, 192, 384, 768],
+    size=[384, 384],
+    hooks=[2, 5, 8, 11],
+    vit_features=768,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    # 32, 48, 136, 384
+    pretrained.act_postprocess1 = nn.Sequential(
+        readout_oper[0],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[0],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[0],
+            out_channels=features[0],
+            kernel_size=4,
+            stride=4,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+
+    pretrained.act_postprocess2 = nn.Sequential(
+        readout_oper[1],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[1],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[1],
+            out_channels=features[1],
+            kernel_size=2,
+            stride=2,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+    hooks = [5, 11, 17, 23] if hooks == None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[256, 512, 1024, 1024],
+        hooks=hooks,
+        vit_features=1024,
+        use_readout=use_readout,
+    )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks == None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks == None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model(
+        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+    )
+
+    hooks = [2, 5, 8, 11] if hooks == None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[96, 192, 384, 768],
+        hooks=hooks,
+        use_readout=use_readout,
+        start_index=2,
+    )
+
+
+def _make_vit_b_rn50_backbone(
+    model,
+    features=[256, 512, 768, 768],
+    size=[384, 384],
+    hooks=[0, 1, 8, 11],
+    vit_features=768,
+    use_vit_only=False,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+
+    if use_vit_only == True:
+        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    else:
+        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+            get_activation("1")
+        )
+        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+            get_activation("2")
+        )
+
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    if use_vit_only == True:
+        pretrained.act_postprocess1 = nn.Sequential(
+            readout_oper[0],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[0],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ConvTranspose2d(
+                in_channels=features[0],
+                out_channels=features[0],
+                kernel_size=4,
+                stride=4,
+                padding=0,
+                bias=True,
+                dilation=1,
+                groups=1,
+            ),
+        )
+
+        pretrained.act_postprocess2 = nn.Sequential(
+            readout_oper[1],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[1],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ConvTranspose2d(
+                in_channels=features[1],
+                out_channels=features[1],
+                kernel_size=2,
+                stride=2,
+                padding=0,
+                bias=True,
+                dilation=1,
+                groups=1,
+            ),
+        )
+    else:
+        pretrained.act_postprocess1 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+        pretrained.act_postprocess2 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+    hooks = [0, 1, 8, 11] if hooks == None else hooks
+    return _make_vit_b_rn50_backbone(
+        model,
+        features=[256, 512, 768, 768],
+        size=[384, 384],
+        hooks=hooks,
+        use_vit_only=use_vit_only,
+        use_readout=use_readout,
+    )
diff --git a/ldm/modules/midas/utils.py b/ldm/modules/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/ldm/modules/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+    """Read pfm file.
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        tuple: (data, scale)
+    """
+    with open(path, "rb") as file:
+
+        color = None
+        width = None
+        height = None
+        scale = None
+        endian = None
+
+        header = file.readline().rstrip()
+        if header.decode("ascii") == "PF":
+            color = True
+        elif header.decode("ascii") == "Pf":
+            color = False
+        else:
+            raise Exception("Not a PFM file: " + path)
+
+        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+        if dim_match:
+            width, height = list(map(int, dim_match.groups()))
+        else:
+            raise Exception("Malformed PFM header.")
+
+        scale = float(file.readline().decode("ascii").rstrip())
+        if scale < 0:
+            # little-endian
+            endian = "<"
+            scale = -scale
+        else:
+            # big-endian
+            endian = ">"
+
+        data = np.fromfile(file, endian + "f")
+        shape = (height, width, 3) if color else (height, width)
+
+        data = np.reshape(data, shape)
+        data = np.flipud(data)
+
+        return data, scale
+
+
+def write_pfm(path, image, scale=1):
+    """Write pfm file.
+
+    Args:
+        path (str): pathto file
+        image (array): data
+        scale (int, optional): Scale. Defaults to 1.
+    """
+
+    with open(path, "wb") as file:
+        color = None
+
+        if image.dtype.name != "float32":
+            raise Exception("Image dtype must be float32.")
+
+        image = np.flipud(image)
+
+        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
+            color = True
+        elif (
+            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+        ):  # greyscale
+            color = False
+        else:
+            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+        file.write("PF\n" if color else "Pf\n".encode())
+        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+        endian = image.dtype.byteorder
+
+        if endian == "<" or endian == "=" and sys.byteorder == "little":
+            scale = -scale
+
+        file.write("%f\n".encode() % scale)
+
+        image.tofile(file)
+
+
+def read_image(path):
+    """Read image and output RGB image (0-1).
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        array: RGB image (0-1)
+    """
+    img = cv2.imread(path)
+
+    if img.ndim == 2:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+    return img
+
+
+def resize_image(img):
+    """Resize image and make it fit for network.
+
+    Args:
+        img (array): image
+
+    Returns:
+        tensor: data ready for network
+    """
+    height_orig = img.shape[0]
+    width_orig = img.shape[1]
+
+    if width_orig > height_orig:
+        scale = width_orig / 384
+    else:
+        scale = height_orig / 384
+
+    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+    img_resized = (
+        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+    )
+    img_resized = img_resized.unsqueeze(0)
+
+    return img_resized
+
+
+def resize_depth(depth, width, height):
+    """Resize depth map and bring to CPU (numpy).
+
+    Args:
+        depth (tensor): depth
+        width (int): image width
+        height (int): image height
+
+    Returns:
+        array: processed depth
+    """
+    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+    depth_resized = cv2.resize(
+        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+    )
+
+    return depth_resized
+
+def write_depth(path, depth, bits=1):
+    """Write depth map to pfm and png file.
+
+    Args:
+        path (str): filepath without extension
+        depth (array): depth
+    """
+    write_pfm(path + ".pfm", depth.astype(np.float32))
+
+    depth_min = depth.min()
+    depth_max = depth.max()
+
+    max_val = (2**(8*bits))-1
+
+    if depth_max - depth_min > np.finfo("float").eps:
+        out = max_val * (depth - depth_min) / (depth_max - depth_min)
+    else:
+        out = np.zeros(depth.shape, dtype=depth.type)
+
+    if bits == 1:
+        cv2.imwrite(path + ".png", out.astype("uint8"))
+    elif bits == 2:
+        cv2.imwrite(path + ".png", out.astype("uint16"))
+
+    return
diff --git a/ldm/util.py b/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d456a86a60968788abf9f5235a41fa826ba578dc
--- /dev/null
+++ b/ldm/util.py
@@ -0,0 +1,197 @@
+import importlib
+
+import torch
+from torch import optim
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+    # wh a tuple of (width, height)
+    # xc a list of captions to plot
+    b = len(xc)
+    txts = list()
+    for bi in range(b):
+        txt = Image.new("RGB", wh, color="white")
+        draw = ImageDraw.Draw(txt)
+        font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size)
+        nc = int(32 * (wh[0] / 256))
+        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+        try:
+            draw.text((0, 0), lines, fill="black", font=font)
+        except UnicodeEncodeError:
+            print("Cant encode string for logging. Skipping.")
+
+        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+        txts.append(txt)
+    txts = np.stack(txts)
+    txts = torch.tensor(txts)
+    return txts
+
+
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x,torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+    """
+    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+    return total_params
+
+
+def instantiate_from_config(config, **kwargs):
+    if "target" not in config:
+        if config == '__is_first_stage__':
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
+                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,   # ema decay to match previous code
+                 ema_power=1., param_names=()):
+        """AdamW that saves EMA versions of the parameters."""
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= ema_decay <= 1.0:
+            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+                        ema_power=ema_power, param_names=param_names)
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            ema_params_with_grad = []
+            state_sums = []
+            max_exp_avg_sqs = []
+            state_steps = []
+            amsgrad = group['amsgrad']
+            beta1, beta2 = group['betas']
+            ema_decay = group['ema_decay']
+            ema_power = group['ema_power']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
+                    raise RuntimeError('AdamW does not support sparse gradients')
+                grads.append(p.grad)
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of parameter values
+                    state['param_exp_avg'] = p.detach().float().clone()
+
+                exp_avgs.append(state['exp_avg'])
+                exp_avg_sqs.append(state['exp_avg_sq'])
+                ema_params_with_grad.append(state['param_exp_avg'])
+
+                if amsgrad:
+                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+                # update the steps for each param group update
+                state['step'] += 1
+                # record the step after step update
+                state_steps.append(state['step'])
+
+            optim._functional.adamw(params_with_grad,
+                    grads,
+                    exp_avgs,
+                    exp_avg_sqs,
+                    max_exp_avg_sqs,
+                    state_steps,
+                    amsgrad=amsgrad,
+                    beta1=beta1,
+                    beta2=beta2,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    eps=group['eps'],
+                    maximize=False)
+
+            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+        return loss
\ No newline at end of file
diff --git a/models_yaml/anytext_sd15.yaml b/models_yaml/anytext_sd15.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3509940b54e87a7c564bf31f638e2df3f68301f9
--- /dev/null
+++ b/models_yaml/anytext_sd15.yaml
@@ -0,0 +1,99 @@
+model:
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "img"
+    cond_stage_key: "caption"
+    control_key: "hint"
+    glyph_key: "glyphs"
+    position_key: "positions"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: true  # need be true when embedding_manager is valid
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    only_mid_control: False
+    loss_alpha: 0  # perceptual loss, 0.003
+    loss_beta: 0  # ctc loss
+    latin_weight: 1.0  # latin text line may need smaller weigth
+    with_step_weight: true
+    use_vae_upsample: true
+    embedding_manager_config:
+      target: cldm.embedding_manager.EmbeddingManager
+      params:
+        valid: true  # v6
+        emb_type: ocr  # ocr, vit, conv
+        glyph_channels: 1
+        position_channels: 1
+        add_pos: false
+        placeholder_string: '*'
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        model_channels: 320
+        glyph_channels: 1
+        position_channels: 1
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
+      params:
+        version: ./models/clip-vit-large-patch14
+        use_vision: false  # v6
diff --git a/models_yaml/anytext_sd15_conv.yaml b/models_yaml/anytext_sd15_conv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..025158f3a0a864f4f390fe6d6bccb88bdbb13e3d
--- /dev/null
+++ b/models_yaml/anytext_sd15_conv.yaml
@@ -0,0 +1,99 @@
+model:
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "img"
+    cond_stage_key: "caption"
+    control_key: "hint"
+    glyph_key: "glyphs"
+    position_key: "positions"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: true  # need be true when embedding_manager is valid
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    only_mid_control: False
+    loss_alpha: 0  # perceptual loss, 0.003
+    loss_beta: 0  # ctc loss
+    latin_weight: 1.0  # latin text line may need smaller weigth
+    with_step_weight: true
+    use_vae_upsample: true
+    embedding_manager_config:
+      target: cldm.embedding_manager.EmbeddingManager
+      params:
+        valid: true  # v6
+        emb_type: conv  # ocr, vit, conv
+        glyph_channels: 1
+        position_channels: 1
+        add_pos: false
+        placeholder_string: '*'
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        model_channels: 320
+        glyph_channels: 1
+        position_channels: 1
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
+      params:
+        version: ./models/clip-vit-large-patch14
+        use_vision: false  # v6
diff --git a/models_yaml/anytext_sd15_perloss.yaml b/models_yaml/anytext_sd15_perloss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7b9a49d0b212aedb4f5d51e2abdcc019fa221491
--- /dev/null
+++ b/models_yaml/anytext_sd15_perloss.yaml
@@ -0,0 +1,99 @@
+model:
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "img"
+    cond_stage_key: "caption"
+    control_key: "hint"
+    glyph_key: "glyphs"
+    position_key: "positions"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: true  # need be true when embedding_manager is valid
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    only_mid_control: False
+    loss_alpha: 0.003  # perceptual loss, 0.003
+    loss_beta: 0  # ctc loss
+    latin_weight: 1.0  # latin text line may need smaller weigth
+    with_step_weight: true
+    use_vae_upsample: true
+    embedding_manager_config:
+      target: cldm.embedding_manager.EmbeddingManager
+      params:
+        valid: true  # v6
+        emb_type: ocr  # ocr, vit, conv
+        glyph_channels: 1
+        position_channels: 1
+        add_pos: false
+        placeholder_string: '*'
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        model_channels: 320
+        glyph_channels: 1
+        position_channels: 1
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
+      params:
+        version: ./models/clip-vit-large-patch14
+        use_vision: false  # v6
diff --git a/models_yaml/anytext_sd15_vit.yaml b/models_yaml/anytext_sd15_vit.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..831d3227ccb6182b4916ce5921abf4c2ae97d576
--- /dev/null
+++ b/models_yaml/anytext_sd15_vit.yaml
@@ -0,0 +1,99 @@
+model:
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "img"
+    cond_stage_key: "caption"
+    control_key: "hint"
+    glyph_key: "glyphs"
+    position_key: "positions"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: true  # need be true when embedding_manager is valid
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    only_mid_control: False
+    loss_alpha: 0  # perceptual loss, 0.003
+    loss_beta: 0  # ctc loss
+    latin_weight: 1.0  # latin text line may need smaller weigth
+    with_step_weight: true
+    use_vae_upsample: true
+    embedding_manager_config:
+      target: cldm.embedding_manager.EmbeddingManager
+      params:
+        valid: true  # v6
+        emb_type: vit  # ocr, vit, conv
+        glyph_channels: 1
+        position_channels: 1
+        add_pos: false
+        placeholder_string: '*'
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        model_channels: 320
+        glyph_channels: 1
+        position_channels: 1
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
+      params:
+        version: ./models/clip-vit-large-patch14
+        use_vision: true  # v6
diff --git a/ocr_recog/RNN.py b/ocr_recog/RNN.py
new file mode 100755
index 0000000000000000000000000000000000000000..cf16855b37112c34a722a9ae7d21578a82d8c6d8
--- /dev/null
+++ b/ocr_recog/RNN.py
@@ -0,0 +1,210 @@
+from torch import nn
+import torch
+from .RecSVTR import Block
+
+class Swish(nn.Module):
+    def __int__(self):
+        super(Swish, self).__int__()
+
+    def forward(self,x):
+        return x*torch.sigmoid(x)
+
+class Im2Im(nn.Module):
+    def __init__(self, in_channels, **kwargs):
+        super().__init__()
+        self.out_channels = in_channels
+
+    def forward(self, x):
+        return x
+
+class Im2Seq(nn.Module):
+    def __init__(self, in_channels, **kwargs):
+        super().__init__()
+        self.out_channels = in_channels
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # assert H == 1
+        x = x.reshape(B, C, H * W)
+        x = x.permute((0, 2, 1))
+        return x
+
+class EncoderWithRNN(nn.Module):
+    def __init__(self, in_channels,**kwargs):
+        super(EncoderWithRNN, self).__init__()
+        hidden_size = kwargs.get('hidden_size', 256)
+        self.out_channels = hidden_size * 2
+        self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True)
+
+    def forward(self, x):
+        self.lstm.flatten_parameters()
+        x, _ = self.lstm(x)
+        return x
+
+class SequenceEncoder(nn.Module):
+    def __init__(self, in_channels, encoder_type='rnn',  **kwargs):
+        super(SequenceEncoder, self).__init__()
+        self.encoder_reshape = Im2Seq(in_channels)
+        self.out_channels = self.encoder_reshape.out_channels
+        self.encoder_type = encoder_type
+        if encoder_type == 'reshape':
+            self.only_reshape = True
+        else:
+            support_encoder_dict = {
+                'reshape': Im2Seq,
+                'rnn': EncoderWithRNN,
+                'svtr': EncoderWithSVTR
+            }
+            assert encoder_type in support_encoder_dict, '{} must in {}'.format(
+                encoder_type, support_encoder_dict.keys())
+
+            self.encoder = support_encoder_dict[encoder_type](
+                self.encoder_reshape.out_channels,**kwargs)
+            self.out_channels = self.encoder.out_channels
+            self.only_reshape = False
+
+    def forward(self, x):
+        if self.encoder_type != 'svtr':
+            x = self.encoder_reshape(x)
+            if not self.only_reshape:
+                x = self.encoder(x)
+            return x
+        else:
+            x = self.encoder(x)
+            x = self.encoder_reshape(x)
+            return x
+
+class ConvBNLayer(nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=0,
+                 bias_attr=False,
+                 groups=1,
+                 act=nn.GELU):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups,
+            # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
+            bias=bias_attr)
+        self.norm = nn.BatchNorm2d(out_channels)
+        self.act = Swish()
+
+    def forward(self, inputs):
+        out = self.conv(inputs)
+        out = self.norm(out)
+        out = self.act(out)
+        return out
+
+
+class EncoderWithSVTR(nn.Module):
+    def __init__(
+            self,
+            in_channels,
+            dims=64,  # XS
+            depth=2,
+            hidden_dims=120,
+            use_guide=False,
+            num_heads=8,
+            qkv_bias=True,
+            mlp_ratio=2.0,
+            drop_rate=0.1,
+            attn_drop_rate=0.1,
+            drop_path=0.,
+            qk_scale=None):
+        super(EncoderWithSVTR, self).__init__()
+        self.depth = depth
+        self.use_guide = use_guide
+        self.conv1 = ConvBNLayer(
+            in_channels, in_channels // 8, padding=1, act='swish')
+        self.conv2 = ConvBNLayer(
+            in_channels // 8, hidden_dims, kernel_size=1, act='swish')
+
+        self.svtr_block = nn.ModuleList([
+            Block(
+                dim=hidden_dims,
+                num_heads=num_heads,
+                mixer='Global',
+                HW=None,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                act_layer='swish',
+                attn_drop=attn_drop_rate,
+                drop_path=drop_path,
+                norm_layer='nn.LayerNorm',
+                epsilon=1e-05,
+                prenorm=False) for i in range(depth)
+        ])
+        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
+        self.conv3 = ConvBNLayer(
+            hidden_dims, in_channels, kernel_size=1, act='swish')
+        # last conv-nxn, the input is concat of input tensor and conv3 output tensor
+        self.conv4 = ConvBNLayer(
+            2 * in_channels, in_channels // 8, padding=1, act='swish')
+
+        self.conv1x1 = ConvBNLayer(
+            in_channels // 8, dims, kernel_size=1, act='swish')
+        self.out_channels = dims
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        # weight initialization
+        if isinstance(m, nn.Conv2d):
+            nn.init.kaiming_normal_(m.weight, mode='fan_out')
+            if m.bias is not None:
+                nn.init.zeros_(m.bias)
+        elif isinstance(m, nn.BatchNorm2d):
+            nn.init.ones_(m.weight)
+            nn.init.zeros_(m.bias)
+        elif isinstance(m, nn.Linear):
+            nn.init.normal_(m.weight, 0, 0.01)
+            if m.bias is not None:
+                nn.init.zeros_(m.bias)
+        elif isinstance(m, nn.ConvTranspose2d):
+            nn.init.kaiming_normal_(m.weight, mode='fan_out')
+            if m.bias is not None:
+                nn.init.zeros_(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.ones_(m.weight)
+            nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        # for use guide
+        if self.use_guide:
+            z = x.clone()
+            z.stop_gradient = True
+        else:
+            z = x
+        # for short cut
+        h = z
+        # reduce dim
+        z = self.conv1(z)
+        z = self.conv2(z)
+        # SVTR global block
+        B, C, H, W = z.shape
+        z = z.flatten(2).permute(0, 2, 1)
+
+        for blk in self.svtr_block:
+            z = blk(z)
+
+        z = self.norm(z)
+        # last stage
+        z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
+        z = self.conv3(z)
+        z = torch.cat((h, z), dim=1)
+        z = self.conv1x1(self.conv4(z))
+
+        return z
+
+if __name__=="__main__":
+    svtrRNN = EncoderWithSVTR(56)
+    print(svtrRNN)
\ No newline at end of file
diff --git a/ocr_recog/RecCTCHead.py b/ocr_recog/RecCTCHead.py
new file mode 100755
index 0000000000000000000000000000000000000000..867ede9916b10d49dc18e8633ae1cb3c4c87ad9d
--- /dev/null
+++ b/ocr_recog/RecCTCHead.py
@@ -0,0 +1,48 @@
+from torch import nn
+
+
+class CTCHead(nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels=6625,
+                 fc_decay=0.0004,
+                 mid_channels=None,
+                 return_feats=False,
+                 **kwargs):
+        super(CTCHead, self).__init__()
+        if mid_channels is None:
+            self.fc = nn.Linear(
+                in_channels,
+                out_channels,
+                bias=True,)
+        else:
+            self.fc1 = nn.Linear(
+                in_channels,
+                mid_channels,
+                bias=True,
+            )
+            self.fc2 = nn.Linear(
+                mid_channels,
+                out_channels,
+                bias=True,
+            )
+
+        self.out_channels = out_channels
+        self.mid_channels = mid_channels
+        self.return_feats = return_feats
+
+    def forward(self, x, labels=None):
+        if self.mid_channels is None:
+            predicts = self.fc(x)
+        else:
+            x = self.fc1(x)
+            predicts = self.fc2(x)
+
+        if self.return_feats:
+            result = dict()
+            result['ctc'] = predicts
+            result['ctc_neck'] = x
+        else:
+            result = predicts
+
+        return result
diff --git a/ocr_recog/RecModel.py b/ocr_recog/RecModel.py
new file mode 100755
index 0000000000000000000000000000000000000000..c2313bf02c952d7c5351175ccf36482cddc76cac
--- /dev/null
+++ b/ocr_recog/RecModel.py
@@ -0,0 +1,45 @@
+from torch import nn
+from .RNN import SequenceEncoder, Im2Seq, Im2Im
+from .RecMv1_enhance import MobileNetV1Enhance
+
+from .RecCTCHead import CTCHead
+
+backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance}
+neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im}
+head_dict = {'CTCHead':CTCHead}
+
+
+class RecModel(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        assert 'in_channels' in config, 'in_channels must in model config'
+        backbone_type = config.backbone.pop('type')
+        assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
+        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)
+
+        neck_type = config.neck.pop('type')
+        assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
+        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
+
+        head_type = config.head.pop('type')
+        assert head_type in head_dict, f'head.type must in {head_dict}'
+        self.head = head_dict[head_type](self.neck.out_channels, **config.head)
+
+        self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'
+
+    def load_3rd_state_dict(self, _3rd_name, _state):
+        self.backbone.load_3rd_state_dict(_3rd_name, _state)
+        self.neck.load_3rd_state_dict(_3rd_name, _state)
+        self.head.load_3rd_state_dict(_3rd_name, _state)
+
+    def forward(self, x):
+        x = self.backbone(x)
+        x = self.neck(x)
+        x = self.head(x)
+        return x
+
+    def encode(self, x):
+        x = self.backbone(x)
+        x = self.neck(x)
+        x = self.head.ctc_encoder(x)
+        return x
diff --git a/ocr_recog/RecMv1_enhance.py b/ocr_recog/RecMv1_enhance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5c848533dd35108bc963be4e6f0ea8c76564006
--- /dev/null
+++ b/ocr_recog/RecMv1_enhance.py
@@ -0,0 +1,233 @@
+import os, sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .common import Activation
+
+
+class ConvBNLayer(nn.Module):
+    def __init__(self,
+                 num_channels,
+                 filter_size,
+                 num_filters,
+                 stride,
+                 padding,
+                 channels=None,
+                 num_groups=1,
+                 act='hard_swish'):
+        super(ConvBNLayer, self).__init__()
+        self.act = act
+        self._conv = nn.Conv2d(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
+            stride=stride,
+            padding=padding,
+            groups=num_groups,
+            bias=False)
+
+        self._batch_norm = nn.BatchNorm2d(
+            num_filters,
+        )
+        if self.act is not None:
+            self._act = Activation(act_type=act, inplace=True)
+
+    def forward(self, inputs):
+        y = self._conv(inputs)
+        y = self._batch_norm(y)
+        if self.act is not None:
+            y = self._act(y)
+        return y
+
+
+class DepthwiseSeparable(nn.Module):
+    def __init__(self,
+                 num_channels,
+                 num_filters1,
+                 num_filters2,
+                 num_groups,
+                 stride,
+                 scale,
+                 dw_size=3,
+                 padding=1,
+                 use_se=False):
+        super(DepthwiseSeparable, self).__init__()
+        self.use_se = use_se
+        self._depthwise_conv = ConvBNLayer(
+            num_channels=num_channels,
+            num_filters=int(num_filters1 * scale),
+            filter_size=dw_size,
+            stride=stride,
+            padding=padding,
+            num_groups=int(num_groups * scale))
+        if use_se:
+            self._se = SEModule(int(num_filters1 * scale))
+        self._pointwise_conv = ConvBNLayer(
+            num_channels=int(num_filters1 * scale),
+            filter_size=1,
+            num_filters=int(num_filters2 * scale),
+            stride=1,
+            padding=0)
+
+    def forward(self, inputs):
+        y = self._depthwise_conv(inputs)
+        if self.use_se:
+            y = self._se(y)
+        y = self._pointwise_conv(y)
+        return y
+
+
+class MobileNetV1Enhance(nn.Module):
+    def __init__(self,
+                 in_channels=3,
+                 scale=0.5,
+                 last_conv_stride=1,
+                 last_pool_type='max',
+                 **kwargs):
+        super().__init__()
+        self.scale = scale
+        self.block_list = []
+
+        self.conv1 = ConvBNLayer(
+            num_channels=in_channels,
+            filter_size=3,
+            channels=3,
+            num_filters=int(32 * scale),
+            stride=2,
+            padding=1)
+
+        conv2_1 = DepthwiseSeparable(
+            num_channels=int(32 * scale),
+            num_filters1=32,
+            num_filters2=64,
+            num_groups=32,
+            stride=1,
+            scale=scale)
+        self.block_list.append(conv2_1)
+
+        conv2_2 = DepthwiseSeparable(
+            num_channels=int(64 * scale),
+            num_filters1=64,
+            num_filters2=128,
+            num_groups=64,
+            stride=1,
+            scale=scale)
+        self.block_list.append(conv2_2)
+
+        conv3_1 = DepthwiseSeparable(
+            num_channels=int(128 * scale),
+            num_filters1=128,
+            num_filters2=128,
+            num_groups=128,
+            stride=1,
+            scale=scale)
+        self.block_list.append(conv3_1)
+
+        conv3_2 = DepthwiseSeparable(
+            num_channels=int(128 * scale),
+            num_filters1=128,
+            num_filters2=256,
+            num_groups=128,
+            stride=(2, 1),
+            scale=scale)
+        self.block_list.append(conv3_2)
+
+        conv4_1 = DepthwiseSeparable(
+            num_channels=int(256 * scale),
+            num_filters1=256,
+            num_filters2=256,
+            num_groups=256,
+            stride=1,
+            scale=scale)
+        self.block_list.append(conv4_1)
+
+        conv4_2 = DepthwiseSeparable(
+            num_channels=int(256 * scale),
+            num_filters1=256,
+            num_filters2=512,
+            num_groups=256,
+            stride=(2, 1),
+            scale=scale)
+        self.block_list.append(conv4_2)
+
+        for _ in range(5):
+            conv5 = DepthwiseSeparable(
+                num_channels=int(512 * scale),
+                num_filters1=512,
+                num_filters2=512,
+                num_groups=512,
+                stride=1,
+                dw_size=5,
+                padding=2,
+                scale=scale,
+                use_se=False)
+            self.block_list.append(conv5)
+
+        conv5_6 = DepthwiseSeparable(
+            num_channels=int(512 * scale),
+            num_filters1=512,
+            num_filters2=1024,
+            num_groups=512,
+            stride=(2, 1),
+            dw_size=5,
+            padding=2,
+            scale=scale,
+            use_se=True)
+        self.block_list.append(conv5_6)
+
+        conv6 = DepthwiseSeparable(
+            num_channels=int(1024 * scale),
+            num_filters1=1024,
+            num_filters2=1024,
+            num_groups=1024,
+            stride=last_conv_stride,
+            dw_size=5,
+            padding=2,
+            use_se=True,
+            scale=scale)
+        self.block_list.append(conv6)
+
+        self.block_list = nn.Sequential(*self.block_list)
+        if last_pool_type == 'avg':
+            self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
+        else:
+            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+        self.out_channels = int(1024 * scale)
+
+    def forward(self, inputs):
+        y = self.conv1(inputs)
+        y = self.block_list(y)
+        y = self.pool(y)
+        return y
+
+def hardsigmoid(x):
+    return F.relu6(x + 3., inplace=True) / 6.
+
+class SEModule(nn.Module):
+    def __init__(self, channel, reduction=4):
+        super(SEModule, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.conv1 = nn.Conv2d(
+            in_channels=channel,
+            out_channels=channel // reduction,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True)
+        self.conv2 = nn.Conv2d(
+            in_channels=channel // reduction,
+            out_channels=channel,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True)
+
+    def forward(self, inputs):
+        outputs = self.avg_pool(inputs)
+        outputs = self.conv1(outputs)
+        outputs = F.relu(outputs)
+        outputs = self.conv2(outputs)
+        outputs = hardsigmoid(outputs)
+        x = torch.mul(inputs, outputs)
+
+        return x
diff --git a/ocr_recog/RecSVTR.py b/ocr_recog/RecSVTR.py
new file mode 100644
index 0000000000000000000000000000000000000000..484b3df991255590616f552d9f942d5a6d6973c9
--- /dev/null
+++ b/ocr_recog/RecSVTR.py
@@ -0,0 +1,591 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.nn.init import trunc_normal_, zeros_, ones_
+from torch.nn import functional
+
+
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = torch.tensor(1 - drop_prob)
+    shape = (x.size()[0], ) + (1, ) * (x.ndim - 1)
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
+    random_tensor = torch.floor(random_tensor)  # binarize
+    output = x.divide(keep_prob) * random_tensor
+    return output
+
+
+class Swish(nn.Module):
+    def __int__(self):
+        super(Swish, self).__int__()
+
+    def forward(self,x):
+        return x*torch.sigmoid(x)
+
+
+class ConvBNLayer(nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=0,
+                 bias_attr=False,
+                 groups=1,
+                 act=nn.GELU):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups,
+            # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
+            bias=bias_attr)
+        self.norm = nn.BatchNorm2d(out_channels)
+        self.act = act()
+
+    def forward(self, inputs):
+        out = self.conv(inputs)
+        out = self.norm(out)
+        out = self.act(out)
+        return out
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Identity(nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, input):
+        return input
+
+
+class Mlp(nn.Module):
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_layer=nn.GELU,
+                 drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        if isinstance(act_layer, str):
+            self.act = Swish()
+        else:
+            self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class ConvMixer(nn.Module):
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            HW=(8, 25),
+            local_k=(3, 3), ):
+        super().__init__()
+        self.HW = HW
+        self.dim = dim
+        self.local_mixer = nn.Conv2d(
+            dim,
+            dim,
+            local_k,
+            1, (local_k[0] // 2, local_k[1] // 2),
+            groups=num_heads,
+            # weight_attr=ParamAttr(initializer=KaimingNormal())
+        )
+
+    def forward(self, x):
+        h = self.HW[0]
+        w = self.HW[1]
+        x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
+        x = self.local_mixer(x)
+        x = x.flatten(2).transpose([0, 2, 1])
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 mixer='Global',
+                 HW=(8, 25),
+                 local_k=(7, 11),
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.HW = HW
+        if HW is not None:
+            H = HW[0]
+            W = HW[1]
+            self.N = H * W
+            self.C = dim
+        if mixer == 'Local' and HW is not None:
+            hk = local_k[0]
+            wk = local_k[1]
+            mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
+            for h in range(0, H):
+                for w in range(0, W):
+                    mask[h * W + w, h:h + hk, w:w + wk] = 0.
+            mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
+                               2].flatten(1)
+            mask_inf = torch.full([H * W, H * W],fill_value=float('-inf'))
+            mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
+            self.mask = mask[None,None,:]
+            # self.mask = mask.unsqueeze([0, 1])
+        self.mixer = mixer
+
+    def forward(self, x):
+        if self.HW is not None:
+            N = self.N
+            C = self.C
+        else:
+            _, N, C = x.shape
+        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4))
+        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+
+        attn = (q.matmul(k.permute((0, 1, 3, 2))))
+        if self.mixer == 'Local':
+            attn += self.mask
+        attn = functional.softmax(attn, dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mixer='Global',
+                 local_mixer=(7, 11),
+                 HW=(8, 25),
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer='nn.LayerNorm',
+                 epsilon=1e-6,
+                 prenorm=True):
+        super().__init__()
+        if isinstance(norm_layer, str):
+            self.norm1 = eval(norm_layer)(dim, eps=epsilon)
+        else:
+            self.norm1 = norm_layer(dim)
+        if mixer == 'Global' or mixer == 'Local':
+
+            self.mixer = Attention(
+                dim,
+                num_heads=num_heads,
+                mixer=mixer,
+                HW=HW,
+                local_k=local_mixer,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                attn_drop=attn_drop,
+                proj_drop=drop)
+        elif mixer == 'Conv':
+            self.mixer = ConvMixer(
+                dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
+        else:
+            raise TypeError("The mixer must be one of [Global, Local, Conv]")
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+        if isinstance(norm_layer, str):
+            self.norm2 = eval(norm_layer)(dim, eps=epsilon)
+        else:
+            self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp_ratio = mlp_ratio
+        self.mlp = Mlp(in_features=dim,
+                       hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer,
+                       drop=drop)
+        self.prenorm = prenorm
+
+    def forward(self, x):
+        if self.prenorm:
+            x = self.norm1(x + self.drop_path(self.mixer(x)))
+            x = self.norm2(x + self.drop_path(self.mlp(x)))
+        else:
+            x = x + self.drop_path(self.mixer(self.norm1(x)))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self,
+                 img_size=(32, 100),
+                 in_channels=3,
+                 embed_dim=768,
+                 sub_num=2):
+        super().__init__()
+        num_patches = (img_size[1] // (2 ** sub_num)) * \
+                      (img_size[0] // (2 ** sub_num))
+        self.img_size = img_size
+        self.num_patches = num_patches
+        self.embed_dim = embed_dim
+        self.norm = None
+        if sub_num == 2:
+            self.proj = nn.Sequential(
+                ConvBNLayer(
+                    in_channels=in_channels,
+                    out_channels=embed_dim // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    act=nn.GELU,
+                    bias_attr=False),
+                ConvBNLayer(
+                    in_channels=embed_dim // 2,
+                    out_channels=embed_dim,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    act=nn.GELU,
+                    bias_attr=False))
+        if sub_num == 3:
+            self.proj = nn.Sequential(
+                ConvBNLayer(
+                    in_channels=in_channels,
+                    out_channels=embed_dim // 4,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    act=nn.GELU,
+                    bias_attr=False),
+                ConvBNLayer(
+                    in_channels=embed_dim // 4,
+                    out_channels=embed_dim // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    act=nn.GELU,
+                    bias_attr=False),
+                ConvBNLayer(
+                    in_channels=embed_dim // 2,
+                    out_channels=embed_dim,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    act=nn.GELU,
+                    bias_attr=False))
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).permute(0, 2, 1)
+        return x
+
+
+class SubSample(nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 types='Pool',
+                 stride=(2, 1),
+                 sub_norm='nn.LayerNorm',
+                 act=None):
+        super().__init__()
+        self.types = types
+        if types == 'Pool':
+            self.avgpool = nn.AvgPool2d(
+                kernel_size=(3, 5), stride=stride, padding=(1, 2))
+            self.maxpool = nn.MaxPool2d(
+                kernel_size=(3, 5), stride=stride, padding=(1, 2))
+            self.proj = nn.Linear(in_channels, out_channels)
+        else:
+            self.conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=3,
+                stride=stride,
+                padding=1,
+                # weight_attr=ParamAttr(initializer=KaimingNormal())
+            )
+        self.norm = eval(sub_norm)(out_channels)
+        if act is not None:
+            self.act = act()
+        else:
+            self.act = None
+
+    def forward(self, x):
+
+        if self.types == 'Pool':
+            x1 = self.avgpool(x)
+            x2 = self.maxpool(x)
+            x = (x1 + x2) * 0.5
+            out = self.proj(x.flatten(2).permute((0, 2, 1)))
+        else:
+            x = self.conv(x)
+            out = x.flatten(2).permute((0, 2, 1))
+        out = self.norm(out)
+        if self.act is not None:
+            out = self.act(out)
+
+        return out
+
+
+class SVTRNet(nn.Module):
+    def __init__(
+            self,
+            img_size=[48, 100],
+            in_channels=3,
+            embed_dim=[64, 128, 256],
+            depth=[3, 6, 3],
+            num_heads=[2, 4, 8],
+            mixer=['Local'] * 6 + ['Global'] *
+            6,  # Local atten, Global atten, Conv
+            local_mixer=[[7, 11], [7, 11], [7, 11]],
+            patch_merging='Conv',  # Conv, Pool, None
+            mlp_ratio=4,
+            qkv_bias=True,
+            qk_scale=None,
+            drop_rate=0.,
+            last_drop=0.1,
+            attn_drop_rate=0.,
+            drop_path_rate=0.1,
+            norm_layer='nn.LayerNorm',
+            sub_norm='nn.LayerNorm',
+            epsilon=1e-6,
+            out_channels=192,
+            out_char_num=25,
+            block_unit='Block',
+            act='nn.GELU',
+            last_stage=True,
+            sub_num=2,
+            prenorm=True,
+            use_lenhead=False,
+            **kwargs):
+        super().__init__()
+        self.img_size = img_size
+        self.embed_dim = embed_dim
+        self.out_channels = out_channels
+        self.prenorm = prenorm
+        patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            in_channels=in_channels,
+            embed_dim=embed_dim[0],
+            sub_num=sub_num)
+        num_patches = self.patch_embed.num_patches
+        self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
+        # self.pos_embed = self.create_parameter(
+        #     shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
+
+        # self.add_parameter("pos_embed", self.pos_embed)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+        Block_unit = eval(block_unit)
+
+        dpr = np.linspace(0, drop_path_rate, sum(depth))
+        self.blocks1 = nn.ModuleList(
+            [
+            Block_unit(
+                dim=embed_dim[0],
+                num_heads=num_heads[0],
+                mixer=mixer[0:depth[0]][i],
+                HW=self.HW,
+                local_mixer=local_mixer[0],
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                act_layer=eval(act),
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[0:depth[0]][i],
+                norm_layer=norm_layer,
+                epsilon=epsilon,
+                prenorm=prenorm) for i in range(depth[0])
+        ]
+        )
+        if patch_merging is not None:
+            self.sub_sample1 = SubSample(
+                embed_dim[0],
+                embed_dim[1],
+                sub_norm=sub_norm,
+                stride=[2, 1],
+                types=patch_merging)
+            HW = [self.HW[0] // 2, self.HW[1]]
+        else:
+            HW = self.HW
+        self.patch_merging = patch_merging
+        self.blocks2 = nn.ModuleList([
+            Block_unit(
+                dim=embed_dim[1],
+                num_heads=num_heads[1],
+                mixer=mixer[depth[0]:depth[0] + depth[1]][i],
+                HW=HW,
+                local_mixer=local_mixer[1],
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                act_layer=eval(act),
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
+                norm_layer=norm_layer,
+                epsilon=epsilon,
+                prenorm=prenorm) for i in range(depth[1])
+        ])
+        if patch_merging is not None:
+            self.sub_sample2 = SubSample(
+                embed_dim[1],
+                embed_dim[2],
+                sub_norm=sub_norm,
+                stride=[2, 1],
+                types=patch_merging)
+            HW = [self.HW[0] // 4, self.HW[1]]
+        else:
+            HW = self.HW
+        self.blocks3 = nn.ModuleList([
+            Block_unit(
+                dim=embed_dim[2],
+                num_heads=num_heads[2],
+                mixer=mixer[depth[0] + depth[1]:][i],
+                HW=HW,
+                local_mixer=local_mixer[2],
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                act_layer=eval(act),
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[depth[0] + depth[1]:][i],
+                norm_layer=norm_layer,
+                epsilon=epsilon,
+                prenorm=prenorm) for i in range(depth[2])
+        ])
+        self.last_stage = last_stage
+        if last_stage:
+            self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
+            self.last_conv = nn.Conv2d(
+                in_channels=embed_dim[2],
+                out_channels=self.out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False)
+            self.hardswish = nn.Hardswish()
+            self.dropout = nn.Dropout(p=last_drop)
+        if not prenorm:
+            self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
+        self.use_lenhead = use_lenhead
+        if use_lenhead:
+            self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
+            self.hardswish_len = nn.Hardswish()
+            self.dropout_len = nn.Dropout(
+                p=last_drop)
+
+        trunc_normal_(self.pos_embed,std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight,std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                zeros_(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            zeros_(m.bias)
+            ones_(m.weight)
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+        for blk in self.blocks1:
+            x = blk(x)
+        if self.patch_merging is not None:
+            x = self.sub_sample1(
+                x.permute([0, 2, 1]).reshape(
+                    [-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
+        for blk in self.blocks2:
+            x = blk(x)
+        if self.patch_merging is not None:
+            x = self.sub_sample2(
+                x.permute([0, 2, 1]).reshape(
+                    [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
+        for blk in self.blocks3:
+            x = blk(x)
+        if not self.prenorm:
+            x = self.norm(x)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        if self.use_lenhead:
+            len_x = self.len_conv(x.mean(1))
+            len_x = self.dropout_len(self.hardswish_len(len_x))
+        if self.last_stage:
+            if self.patch_merging is not None:
+                h = self.HW[0] // 4
+            else:
+                h = self.HW[0]
+            x = self.avg_pool(
+                x.permute([0, 2, 1]).reshape(
+                    [-1, self.embed_dim[2], h, self.HW[1]]))
+            x = self.last_conv(x)
+            x = self.hardswish(x)
+            x = self.dropout(x)
+        if self.use_lenhead:
+            return x, len_x
+        return x
+
+
+if __name__=="__main__":
+    a = torch.rand(1,3,48,100)
+    svtr = SVTRNet()
+
+    out = svtr(a)
+    print(svtr)
+    print(out.size())
\ No newline at end of file
diff --git a/ocr_recog/common.py b/ocr_recog/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..a328bb034a37934b7437893b5c2e42cd3504c17f
--- /dev/null
+++ b/ocr_recog/common.py
@@ -0,0 +1,74 @@
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Hswish(nn.Module):
+    def __init__(self, inplace=True):
+        super(Hswish, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return x * F.relu6(x + 3., inplace=self.inplace) / 6.
+
+# out = max(0, min(1, slop*x+offset))
+# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
+class Hsigmoid(nn.Module):
+    def __init__(self, inplace=True):
+        super(Hsigmoid, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        # torch: F.relu6(x + 3., inplace=self.inplace) / 6.
+        # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
+        return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
+
+class GELU(nn.Module):
+    def __init__(self, inplace=True):
+        super(GELU, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return torch.nn.functional.gelu(x)
+
+
+class Swish(nn.Module):
+    def __init__(self, inplace=True):
+        super(Swish, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        if self.inplace:
+            x.mul_(torch.sigmoid(x))
+            return x
+        else:
+            return x*torch.sigmoid(x)
+
+
+class Activation(nn.Module):
+    def __init__(self, act_type, inplace=True):
+        super(Activation, self).__init__()
+        act_type = act_type.lower()
+        if act_type == 'relu':
+            self.act = nn.ReLU(inplace=inplace)
+        elif act_type == 'relu6':
+            self.act = nn.ReLU6(inplace=inplace)
+        elif act_type == 'sigmoid':
+            raise NotImplementedError
+        elif act_type == 'hard_sigmoid':
+            self.act = Hsigmoid(inplace)
+        elif act_type == 'hard_swish':
+            self.act = Hswish(inplace=inplace)
+        elif act_type == 'leakyrelu':
+            self.act = nn.LeakyReLU(inplace=inplace)
+        elif act_type == 'gelu':
+            self.act = GELU(inplace=inplace)
+        elif act_type == 'swish':
+            self.act = Swish(inplace=inplace)
+        else:
+            raise NotImplementedError
+
+    def forward(self, inputs):
+        return self.act(inputs)
\ No newline at end of file
diff --git a/ocr_recog/en_dict.txt b/ocr_recog/en_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7677d31b9d3f08eef2823c2cf051beeab1f0470b
--- /dev/null
+++ b/ocr_recog/en_dict.txt
@@ -0,0 +1,95 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+ 
diff --git a/ocr_recog/ppocr_keys_v1.txt b/ocr_recog/ppocr_keys_v1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..84b885d8352226e49b1d5d791b8f43a663e246aa
--- /dev/null
+++ b/ocr_recog/ppocr_keys_v1.txt
@@ -0,0 +1,6623 @@
+'
+疗
+绚
+诚
+娇
+溜
+题
+贿
+者
+廖
+更
+纳
+加
+奉
+公
+一
+就
+汴
+计
+与
+路
+房
+原
+妇
+2
+0
+8
+-
+7
+其
+>
+:
+]
+,
+,
+骑
+刈
+全
+消
+昏
+傈
+安
+久
+钟
+嗅
+不
+影
+处
+驽
+蜿
+资
+关
+椤
+地
+瘸
+专
+问
+忖
+票
+嫉
+炎
+韵
+要
+月
+田
+节
+陂
+鄙
+捌
+备
+拳
+伺
+眼
+网
+盎
+大
+傍
+心
+东
+愉
+汇
+蹿
+科
+每
+业
+里
+航
+晏
+字
+平
+录
+先
+1
+3
+彤
+鲶
+产
+稍
+督
+腴
+有
+象
+岳
+注
+绍
+在
+泺
+文
+定
+核
+名
+水
+过
+理
+让
+偷
+率
+等
+这
+发
+”
+为
+含
+肥
+酉
+相
+鄱
+七
+编
+猥
+锛
+日
+镀
+蒂
+掰
+倒
+辆
+栾
+栗
+综
+涩
+州
+雌
+滑
+馀
+了
+机
+块
+司
+宰
+甙
+兴
+矽
+抚
+保
+用
+沧
+秩
+如
+收
+息
+滥
+页
+疑
+埠
+!
+!
+姥
+异
+橹
+钇
+向
+下
+跄
+的
+椴
+沫
+国
+绥
+獠
+报
+开
+民
+蜇
+何
+分
+凇
+长
+讥
+藏
+掏
+施
+羽
+中
+讲
+派
+嘟
+人
+提
+浼
+间
+世
+而
+古
+多
+倪
+唇
+饯
+控
+庚
+首
+赛
+蜓
+味
+断
+制
+觉
+技
+替
+艰
+溢
+潮
+夕
+钺
+外
+摘
+枋
+动
+双
+单
+啮
+户
+枇
+确
+锦
+曜
+杜
+或
+能
+效
+霜
+盒
+然
+侗
+电
+晁
+放
+步
+鹃
+新
+杖
+蜂
+吒
+濂
+瞬
+评
+总
+隍
+对
+独
+合
+也
+是
+府
+青
+天
+诲
+墙
+组
+滴
+级
+邀
+帘
+示
+已
+时
+骸
+仄
+泅
+和
+遨
+店
+雇
+疫
+持
+巍
+踮
+境
+只
+亨
+目
+鉴
+崤
+闲
+体
+泄
+杂
+作
+般
+轰
+化
+解
+迂
+诿
+蛭
+璀
+腾
+告
+版
+服
+省
+师
+小
+规
+程
+线
+海
+办
+引
+二
+桧
+牌
+砺
+洄
+裴
+修
+图
+痫
+胡
+许
+犊
+事
+郛
+基
+柴
+呼
+食
+研
+奶
+律
+蛋
+因
+葆
+察
+戏
+褒
+戒
+再
+李
+骁
+工
+貂
+油
+鹅
+章
+啄
+休
+场
+给
+睡
+纷
+豆
+器
+捎
+说
+敏
+学
+会
+浒
+设
+诊
+格
+廓
+查
+来
+霓
+室
+溆
+¢
+诡
+寥
+焕
+舜
+柒
+狐
+回
+戟
+砾
+厄
+实
+翩
+尿
+五
+入
+径
+惭
+喹
+股
+宇
+篝
+|
+;
+美
+期
+云
+九
+祺
+扮
+靠
+锝
+槌
+系
+企
+酰
+阊
+暂
+蚕
+忻
+豁
+本
+羹
+执
+条
+钦
+H
+獒
+限
+进
+季
+楦
+于
+芘
+玖
+铋
+茯
+未
+答
+粘
+括
+样
+精
+欠
+矢
+甥
+帷
+嵩
+扣
+令
+仔
+风
+皈
+行
+支
+部
+蓉
+刮
+站
+蜡
+救
+钊
+汗
+松
+嫌
+成
+可
+.
+鹤
+院
+从
+交
+政
+怕
+活
+调
+球
+局
+验
+髌
+第
+韫
+谗
+串
+到
+圆
+年
+米
+/
+*
+友
+忿
+检
+区
+看
+自
+敢
+刃
+个
+兹
+弄
+流
+留
+同
+没
+齿
+星
+聆
+轼
+湖
+什
+三
+建
+蛔
+儿
+椋
+汕
+震
+颧
+鲤
+跟
+力
+情
+璺
+铨
+陪
+务
+指
+族
+训
+滦
+鄣
+濮
+扒
+商
+箱
+十
+召
+慷
+辗
+所
+莞
+管
+护
+臭
+横
+硒
+嗓
+接
+侦
+六
+露
+党
+馋
+驾
+剖
+高
+侬
+妪
+幂
+猗
+绺
+骐
+央
+酐
+孝
+筝
+课
+徇
+缰
+门
+男
+西
+项
+句
+谙
+瞒
+秃
+篇
+教
+碲
+罚
+声
+呐
+景
+前
+富
+嘴
+鳌
+稀
+免
+朋
+啬
+睐
+去
+赈
+鱼
+住
+肩
+愕
+速
+旁
+波
+厅
+健
+茼
+厥
+鲟
+谅
+投
+攸
+炔
+数
+方
+击
+呋
+谈
+绩
+别
+愫
+僚
+躬
+鹧
+胪
+炳
+招
+喇
+膨
+泵
+蹦
+毛
+结
+5
+4
+谱
+识
+陕
+粽
+婚
+拟
+构
+且
+搜
+任
+潘
+比
+郢
+妨
+醪
+陀
+桔
+碘
+扎
+选
+哈
+骷
+楷
+亿
+明
+缆
+脯
+监
+睫
+逻
+婵
+共
+赴
+淝
+凡
+惦
+及
+达
+揖
+谩
+澹
+减
+焰
+蛹
+番
+祁
+柏
+员
+禄
+怡
+峤
+龙
+白
+叽
+生
+闯
+起
+细
+装
+谕
+竟
+聚
+钙
+上
+导
+渊
+按
+艾
+辘
+挡
+耒
+盹
+饪
+臀
+记
+邮
+蕙
+受
+各
+医
+搂
+普
+滇
+朗
+茸
+带
+翻
+酚
+(
+光
+堤
+墟
+蔷
+万
+幻
+〓
+瑙
+辈
+昧
+盏
+亘
+蛀
+吉
+铰
+请
+子
+假
+闻
+税
+井
+诩
+哨
+嫂
+好
+面
+琐
+校
+馊
+鬣
+缂
+营
+访
+炖
+占
+农
+缀
+否
+经
+钚
+棵
+趟
+张
+亟
+吏
+茶
+谨
+捻
+论
+迸
+堂
+玉
+信
+吧
+瞠
+乡
+姬
+寺
+咬
+溏
+苄
+皿
+意
+赉
+宝
+尔
+钰
+艺
+特
+唳
+踉
+都
+荣
+倚
+登
+荐
+丧
+奇
+涵
+批
+炭
+近
+符
+傩
+感
+道
+着
+菊
+虹
+仲
+众
+懈
+濯
+颞
+眺
+南
+释
+北
+缝
+标
+既
+茗
+整
+撼
+迤
+贲
+挎
+耱
+拒
+某
+妍
+卫
+哇
+英
+矶
+藩
+治
+他
+元
+领
+膜
+遮
+穗
+蛾
+飞
+荒
+棺
+劫
+么
+市
+火
+温
+拈
+棚
+洼
+转
+果
+奕
+卸
+迪
+伸
+泳
+斗
+邡
+侄
+涨
+屯
+萋
+胭
+氡
+崮
+枞
+惧
+冒
+彩
+斜
+手
+豚
+随
+旭
+淑
+妞
+形
+菌
+吲
+沱
+争
+驯
+歹
+挟
+兆
+柱
+传
+至
+包
+内
+响
+临
+红
+功
+弩
+衡
+寂
+禁
+老
+棍
+耆
+渍
+织
+害
+氵
+渑
+布
+载
+靥
+嗬
+虽
+苹
+咨
+娄
+库
+雉
+榜
+帜
+嘲
+套
+瑚
+亲
+簸
+欧
+边
+6
+腿
+旮
+抛
+吹
+瞳
+得
+镓
+梗
+厨
+继
+漾
+愣
+憨
+士
+策
+窑
+抑
+躯
+襟
+脏
+参
+贸
+言
+干
+绸
+鳄
+穷
+藜
+音
+折
+详
+)
+举
+悍
+甸
+癌
+黎
+谴
+死
+罩
+迁
+寒
+驷
+袖
+媒
+蒋
+掘
+模
+纠
+恣
+观
+祖
+蛆
+碍
+位
+稿
+主
+澧
+跌
+筏
+京
+锏
+帝
+贴
+证
+糠
+才
+黄
+鲸
+略
+炯
+饱
+四
+出
+园
+犀
+牧
+容
+汉
+杆
+浈
+汰
+瑷
+造
+虫
+瘩
+怪
+驴
+济
+应
+花
+沣
+谔
+夙
+旅
+价
+矿
+以
+考
+s
+u
+呦
+晒
+巡
+茅
+准
+肟
+瓴
+詹
+仟
+褂
+译
+桌
+混
+宁
+怦
+郑
+抿
+些
+余
+鄂
+饴
+攒
+珑
+群
+阖
+岔
+琨
+藓
+预
+环
+洮
+岌
+宀
+杲
+瀵
+最
+常
+囡
+周
+踊
+女
+鼓
+袭
+喉
+简
+范
+薯
+遐
+疏
+粱
+黜
+禧
+法
+箔
+斤
+遥
+汝
+奥
+直
+贞
+撑
+置
+绱
+集
+她
+馅
+逗
+钧
+橱
+魉
+[
+恙
+躁
+唤
+9
+旺
+膘
+待
+脾
+惫
+购
+吗
+依
+盲
+度
+瘿
+蠖
+俾
+之
+镗
+拇
+鲵
+厝
+簧
+续
+款
+展
+啃
+表
+剔
+品
+钻
+腭
+损
+清
+锶
+统
+涌
+寸
+滨
+贪
+链
+吠
+冈
+伎
+迥
+咏
+吁
+览
+防
+迅
+失
+汾
+阔
+逵
+绀
+蔑
+列
+川
+凭
+努
+熨
+揪
+利
+俱
+绉
+抢
+鸨
+我
+即
+责
+膦
+易
+毓
+鹊
+刹
+玷
+岿
+空
+嘞
+绊
+排
+术
+估
+锷
+违
+们
+苟
+铜
+播
+肘
+件
+烫
+审
+鲂
+广
+像
+铌
+惰
+铟
+巳
+胍
+鲍
+康
+憧
+色
+恢
+想
+拷
+尤
+疳
+知
+S
+Y
+F
+D
+A
+峄
+裕
+帮
+握
+搔
+氐
+氘
+难
+墒
+沮
+雨
+叁
+缥
+悴
+藐
+湫
+娟
+苑
+稠
+颛
+簇
+后
+阕
+闭
+蕤
+缚
+怎
+佞
+码
+嘤
+蔡
+痊
+舱
+螯
+帕
+赫
+昵
+升
+烬
+岫
+、
+疵
+蜻
+髁
+蕨
+隶
+烛
+械
+丑
+盂
+梁
+强
+鲛
+由
+拘
+揉
+劭
+龟
+撤
+钩
+呕
+孛
+费
+妻
+漂
+求
+阑
+崖
+秤
+甘
+通
+深
+补
+赃
+坎
+床
+啪
+承
+吼
+量
+暇
+钼
+烨
+阂
+擎
+脱
+逮
+称
+P
+神
+属
+矗
+华
+届
+狍
+葑
+汹
+育
+患
+窒
+蛰
+佼
+静
+槎
+运
+鳗
+庆
+逝
+曼
+疱
+克
+代
+官
+此
+麸
+耧
+蚌
+晟
+例
+础
+榛
+副
+测
+唰
+缢
+迹
+灬
+霁
+身
+岁
+赭
+扛
+又
+菡
+乜
+雾
+板
+读
+陷
+徉
+贯
+郁
+虑
+变
+钓
+菜
+圾
+现
+琢
+式
+乐
+维
+渔
+浜
+左
+吾
+脑
+钡
+警
+T
+啵
+拴
+偌
+漱
+湿
+硕
+止
+骼
+魄
+积
+燥
+联
+踢
+玛
+则
+窿
+见
+振
+畿
+送
+班
+钽
+您
+赵
+刨
+印
+讨
+踝
+籍
+谡
+舌
+崧
+汽
+蔽
+沪
+酥
+绒
+怖
+财
+帖
+肱
+私
+莎
+勋
+羔
+霸
+励
+哼
+帐
+将
+帅
+渠
+纪
+婴
+娩
+岭
+厘
+滕
+吻
+伤
+坝
+冠
+戊
+隆
+瘁
+介
+涧
+物
+黍
+并
+姗
+奢
+蹑
+掣
+垸
+锴
+命
+箍
+捉
+病
+辖
+琰
+眭
+迩
+艘
+绌
+繁
+寅
+若
+毋
+思
+诉
+类
+诈
+燮
+轲
+酮
+狂
+重
+反
+职
+筱
+县
+委
+磕
+绣
+奖
+晋
+濉
+志
+徽
+肠
+呈
+獐
+坻
+口
+片
+碰
+几
+村
+柿
+劳
+料
+获
+亩
+惕
+晕
+厌
+号
+罢
+池
+正
+鏖
+煨
+家
+棕
+复
+尝
+懋
+蜥
+锅
+岛
+扰
+队
+坠
+瘾
+钬
+@
+卧
+疣
+镇
+譬
+冰
+彷
+频
+黯
+据
+垄
+采
+八
+缪
+瘫
+型
+熹
+砰
+楠
+襁
+箐
+但
+嘶
+绳
+啤
+拍
+盥
+穆
+傲
+洗
+盯
+塘
+怔
+筛
+丿
+台
+恒
+喂
+葛
+永
+¥
+烟
+酒
+桦
+书
+砂
+蚝
+缉
+态
+瀚
+袄
+圳
+轻
+蛛
+超
+榧
+遛
+姒
+奘
+铮
+右
+荽
+望
+偻
+卡
+丶
+氰
+附
+做
+革
+索
+戚
+坨
+桷
+唁
+垅
+榻
+岐
+偎
+坛
+莨
+山
+殊
+微
+骇
+陈
+爨
+推
+嗝
+驹
+澡
+藁
+呤
+卤
+嘻
+糅
+逛
+侵
+郓
+酌
+德
+摇
+※
+鬃
+被
+慨
+殡
+羸
+昌
+泡
+戛
+鞋
+河
+宪
+沿
+玲
+鲨
+翅
+哽
+源
+铅
+语
+照
+邯
+址
+荃
+佬
+顺
+鸳
+町
+霭
+睾
+瓢
+夸
+椁
+晓
+酿
+痈
+咔
+侏
+券
+噎
+湍
+签
+嚷
+离
+午
+尚
+社
+锤
+背
+孟
+使
+浪
+缦
+潍
+鞅
+军
+姹
+驶
+笑
+鳟
+鲁
+》
+孽
+钜
+绿
+洱
+礴
+焯
+椰
+颖
+囔
+乌
+孔
+巴
+互
+性
+椽
+哞
+聘
+昨
+早
+暮
+胶
+炀
+隧
+低
+彗
+昝
+铁
+呓
+氽
+藉
+喔
+癖
+瑗
+姨
+权
+胱
+韦
+堑
+蜜
+酋
+楝
+砝
+毁
+靓
+歙
+锲
+究
+屋
+喳
+骨
+辨
+碑
+武
+鸠
+宫
+辜
+烊
+适
+坡
+殃
+培
+佩
+供
+走
+蜈
+迟
+翼
+况
+姣
+凛
+浔
+吃
+飘
+债
+犟
+金
+促
+苛
+崇
+坂
+莳
+畔
+绂
+兵
+蠕
+斋
+根
+砍
+亢
+欢
+恬
+崔
+剁
+餐
+榫
+快
+扶
+‖
+濒
+缠
+鳜
+当
+彭
+驭
+浦
+篮
+昀
+锆
+秸
+钳
+弋
+娣
+瞑
+夷
+龛
+苫
+拱
+致
+%
+嵊
+障
+隐
+弑
+初
+娓
+抉
+汩
+累
+蓖
+"
+唬
+助
+苓
+昙
+押
+毙
+破
+城
+郧
+逢
+嚏
+獭
+瞻
+溱
+婿
+赊
+跨
+恼
+璧
+萃
+姻
+貉
+灵
+炉
+密
+氛
+陶
+砸
+谬
+衔
+点
+琛
+沛
+枳
+层
+岱
+诺
+脍
+榈
+埂
+征
+冷
+裁
+打
+蹴
+素
+瘘
+逞
+蛐
+聊
+激
+腱
+萘
+踵
+飒
+蓟
+吆
+取
+咙
+簋
+涓
+矩
+曝
+挺
+揣
+座
+你
+史
+舵
+焱
+尘
+苏
+笈
+脚
+溉
+榨
+诵
+樊
+邓
+焊
+义
+庶
+儋
+蟋
+蒲
+赦
+呷
+杞
+诠
+豪
+还
+试
+颓
+茉
+太
+除
+紫
+逃
+痴
+草
+充
+鳕
+珉
+祗
+墨
+渭
+烩
+蘸
+慕
+璇
+镶
+穴
+嵘
+恶
+骂
+险
+绋
+幕
+碉
+肺
+戳
+刘
+潞
+秣
+纾
+潜
+銮
+洛
+须
+罘
+销
+瘪
+汞
+兮
+屉
+r
+林
+厕
+质
+探
+划
+狸
+殚
+善
+煊
+烹
+〒
+锈
+逯
+宸
+辍
+泱
+柚
+袍
+远
+蹋
+嶙
+绝
+峥
+娥
+缍
+雀
+徵
+认
+镱
+谷
+=
+贩
+勉
+撩
+鄯
+斐
+洋
+非
+祚
+泾
+诒
+饿
+撬
+威
+晷
+搭
+芍
+锥
+笺
+蓦
+候
+琊
+档
+礁
+沼
+卵
+荠
+忑
+朝
+凹
+瑞
+头
+仪
+弧
+孵
+畏
+铆
+突
+衲
+车
+浩
+气
+茂
+悖
+厢
+枕
+酝
+戴
+湾
+邹
+飚
+攘
+锂
+写
+宵
+翁
+岷
+无
+喜
+丈
+挑
+嗟
+绛
+殉
+议
+槽
+具
+醇
+淞
+笃
+郴
+阅
+饼
+底
+壕
+砚
+弈
+询
+缕
+庹
+翟
+零
+筷
+暨
+舟
+闺
+甯
+撞
+麂
+茌
+蔼
+很
+珲
+捕
+棠
+角
+阉
+媛
+娲
+诽
+剿
+尉
+爵
+睬
+韩
+诰
+匣
+危
+糍
+镯
+立
+浏
+阳
+少
+盆
+舔
+擘
+匪
+申
+尬
+铣
+旯
+抖
+赘
+瓯
+居
+ˇ
+哮
+游
+锭
+茏
+歌
+坏
+甚
+秒
+舞
+沙
+仗
+劲
+潺
+阿
+燧
+郭
+嗖
+霏
+忠
+材
+奂
+耐
+跺
+砀
+输
+岖
+媳
+氟
+极
+摆
+灿
+今
+扔
+腻
+枝
+奎
+药
+熄
+吨
+话
+q
+额
+慑
+嘌
+协
+喀
+壳
+埭
+视
+著
+於
+愧
+陲
+翌
+峁
+颅
+佛
+腹
+聋
+侯
+咎
+叟
+秀
+颇
+存
+较
+罪
+哄
+岗
+扫
+栏
+钾
+羌
+己
+璨
+枭
+霉
+煌
+涸
+衿
+键
+镝
+益
+岢
+奏
+连
+夯
+睿
+冥
+均
+糖
+狞
+蹊
+稻
+爸
+刿
+胥
+煜
+丽
+肿
+璃
+掸
+跚
+灾
+垂
+樾
+濑
+乎
+莲
+窄
+犹
+撮
+战
+馄
+软
+络
+显
+鸢
+胸
+宾
+妲
+恕
+埔
+蝌
+份
+遇
+巧
+瞟
+粒
+恰
+剥
+桡
+博
+讯
+凯
+堇
+阶
+滤
+卖
+斌
+骚
+彬
+兑
+磺
+樱
+舷
+两
+娱
+福
+仃
+差
+找
+桁
+÷
+净
+把
+阴
+污
+戬
+雷
+碓
+蕲
+楚
+罡
+焖
+抽
+妫
+咒
+仑
+闱
+尽
+邑
+菁
+爱
+贷
+沥
+鞑
+牡
+嗉
+崴
+骤
+塌
+嗦
+订
+拮
+滓
+捡
+锻
+次
+坪
+杩
+臃
+箬
+融
+珂
+鹗
+宗
+枚
+降
+鸬
+妯
+阄
+堰
+盐
+毅
+必
+杨
+崃
+俺
+甬
+状
+莘
+货
+耸
+菱
+腼
+铸
+唏
+痤
+孚
+澳
+懒
+溅
+翘
+疙
+杷
+淼
+缙
+骰
+喊
+悉
+砻
+坷
+艇
+赁
+界
+谤
+纣
+宴
+晃
+茹
+归
+饭
+梢
+铡
+街
+抄
+肼
+鬟
+苯
+颂
+撷
+戈
+炒
+咆
+茭
+瘙
+负
+仰
+客
+琉
+铢
+封
+卑
+珥
+椿
+镧
+窨
+鬲
+寿
+御
+袤
+铃
+萎
+砖
+餮
+脒
+裳
+肪
+孕
+嫣
+馗
+嵇
+恳
+氯
+江
+石
+褶
+冢
+祸
+阻
+狈
+羞
+银
+靳
+透
+咳
+叼
+敷
+芷
+啥
+它
+瓤
+兰
+痘
+懊
+逑
+肌
+往
+捺
+坊
+甩
+呻
+〃
+沦
+忘
+膻
+祟
+菅
+剧
+崆
+智
+坯
+臧
+霍
+墅
+攻
+眯
+倘
+拢
+骠
+铐
+庭
+岙
+瓠
+′
+缺
+泥
+迢
+捶
+?
+?
+郏
+喙
+掷
+沌
+纯
+秘
+种
+听
+绘
+固
+螨
+团
+香
+盗
+妒
+埚
+蓝
+拖
+旱
+荞
+铀
+血
+遏
+汲
+辰
+叩
+拽
+幅
+硬
+惶
+桀
+漠
+措
+泼
+唑
+齐
+肾
+念
+酱
+虚
+屁
+耶
+旗
+砦
+闵
+婉
+馆
+拭
+绅
+韧
+忏
+窝
+醋
+葺
+顾
+辞
+倜
+堆
+辋
+逆
+玟
+贱
+疾
+董
+惘
+倌
+锕
+淘
+嘀
+莽
+俭
+笏
+绑
+鲷
+杈
+择
+蟀
+粥
+嗯
+驰
+逾
+案
+谪
+褓
+胫
+哩
+昕
+颚
+鲢
+绠
+躺
+鹄
+崂
+儒
+俨
+丝
+尕
+泌
+啊
+萸
+彰
+幺
+吟
+骄
+苣
+弦
+脊
+瑰
+〈
+诛
+镁
+析
+闪
+剪
+侧
+哟
+框
+螃
+守
+嬗
+燕
+狭
+铈
+缮
+概
+迳
+痧
+鲲
+俯
+售
+笼
+痣
+扉
+挖
+满
+咋
+援
+邱
+扇
+歪
+便
+玑
+绦
+峡
+蛇
+叨
+〖
+泽
+胃
+斓
+喋
+怂
+坟
+猪
+该
+蚬
+炕
+弥
+赞
+棣
+晔
+娠
+挲
+狡
+创
+疖
+铕
+镭
+稷
+挫
+弭
+啾
+翔
+粉
+履
+苘
+哦
+楼
+秕
+铂
+土
+锣
+瘟
+挣
+栉
+习
+享
+桢
+袅
+磨
+桂
+谦
+延
+坚
+蔚
+噗
+署
+谟
+猬
+钎
+恐
+嬉
+雒
+倦
+衅
+亏
+璩
+睹
+刻
+殿
+王
+算
+雕
+麻
+丘
+柯
+骆
+丸
+塍
+谚
+添
+鲈
+垓
+桎
+蚯
+芥
+予
+飕
+镦
+谌
+窗
+醚
+菀
+亮
+搪
+莺
+蒿
+羁
+足
+J
+真
+轶
+悬
+衷
+靛
+翊
+掩
+哒
+炅
+掐
+冼
+妮
+l
+谐
+稚
+荆
+擒
+犯
+陵
+虏
+浓
+崽
+刍
+陌
+傻
+孜
+千
+靖
+演
+矜
+钕
+煽
+杰
+酗
+渗
+伞
+栋
+俗
+泫
+戍
+罕
+沾
+疽
+灏
+煦
+芬
+磴
+叱
+阱
+榉
+湃
+蜀
+叉
+醒
+彪
+租
+郡
+篷
+屎
+良
+垢
+隗
+弱
+陨
+峪
+砷
+掴
+颁
+胎
+雯
+绵
+贬
+沐
+撵
+隘
+篙
+暖
+曹
+陡
+栓
+填
+臼
+彦
+瓶
+琪
+潼
+哪
+鸡
+摩
+啦
+俟
+锋
+域
+耻
+蔫
+疯
+纹
+撇
+毒
+绶
+痛
+酯
+忍
+爪
+赳
+歆
+嘹
+辕
+烈
+册
+朴
+钱
+吮
+毯
+癜
+娃
+谀
+邵
+厮
+炽
+璞
+邃
+丐
+追
+词
+瓒
+忆
+轧
+芫
+谯
+喷
+弟
+半
+冕
+裙
+掖
+墉
+绮
+寝
+苔
+势
+顷
+褥
+切
+衮
+君
+佳
+嫒
+蚩
+霞
+佚
+洙
+逊
+镖
+暹
+唛
+&
+殒
+顶
+碗
+獗
+轭
+铺
+蛊
+废
+恹
+汨
+崩
+珍
+那
+杵
+曲
+纺
+夏
+薰
+傀
+闳
+淬
+姘
+舀
+拧
+卷
+楂
+恍
+讪
+厩
+寮
+篪
+赓
+乘
+灭
+盅
+鞣
+沟
+慎
+挂
+饺
+鼾
+杳
+树
+缨
+丛
+絮
+娌
+臻
+嗳
+篡
+侩
+述
+衰
+矛
+圈
+蚜
+匕
+筹
+匿
+濞
+晨
+叶
+骋
+郝
+挚
+蚴
+滞
+增
+侍
+描
+瓣
+吖
+嫦
+蟒
+匾
+圣
+赌
+毡
+癞
+恺
+百
+曳
+需
+篓
+肮
+庖
+帏
+卿
+驿
+遗
+蹬
+鬓
+骡
+歉
+芎
+胳
+屐
+禽
+烦
+晌
+寄
+媾
+狄
+翡
+苒
+船
+廉
+终
+痞
+殇
+々
+畦
+饶
+改
+拆
+悻
+萄
+£
+瓿
+乃
+訾
+桅
+匮
+溧
+拥
+纱
+铍
+骗
+蕃
+龋
+缬
+父
+佐
+疚
+栎
+醍
+掳
+蓄
+x
+惆
+颜
+鲆
+榆
+〔
+猎
+敌
+暴
+谥
+鲫
+贾
+罗
+玻
+缄
+扦
+芪
+癣
+落
+徒
+臾
+恿
+猩
+托
+邴
+肄
+牵
+春
+陛
+耀
+刊
+拓
+蓓
+邳
+堕
+寇
+枉
+淌
+啡
+湄
+兽
+酷
+萼
+碚
+濠
+萤
+夹
+旬
+戮
+梭
+琥
+椭
+昔
+勺
+蜊
+绐
+晚
+孺
+僵
+宣
+摄
+冽
+旨
+萌
+忙
+蚤
+眉
+噼
+蟑
+付
+契
+瓜
+悼
+颡
+壁
+曾
+窕
+颢
+澎
+仿
+俑
+浑
+嵌
+浣
+乍
+碌
+褪
+乱
+蔟
+隙
+玩
+剐
+葫
+箫
+纲
+围
+伐
+决
+伙
+漩
+瑟
+刑
+肓
+镳
+缓
+蹭
+氨
+皓
+典
+畲
+坍
+铑
+檐
+塑
+洞
+倬
+储
+胴
+淳
+戾
+吐
+灼
+惺
+妙
+毕
+珐
+缈
+虱
+盖
+羰
+鸿
+磅
+谓
+髅
+娴
+苴
+唷
+蚣
+霹
+抨
+贤
+唠
+犬
+誓
+逍
+庠
+逼
+麓
+籼
+釉
+呜
+碧
+秧
+氩
+摔
+霄
+穸
+纨
+辟
+妈
+映
+完
+牛
+缴
+嗷
+炊
+恩
+荔
+茆
+掉
+紊
+慌
+莓
+羟
+阙
+萁
+磐
+另
+蕹
+辱
+鳐
+湮
+吡
+吩
+唐
+睦
+垠
+舒
+圜
+冗
+瞿
+溺
+芾
+囱
+匠
+僳
+汐
+菩
+饬
+漓
+黑
+霰
+浸
+濡
+窥
+毂
+蒡
+兢
+驻
+鹉
+芮
+诙
+迫
+雳
+厂
+忐
+臆
+猴
+鸣
+蚪
+栈
+箕
+羡
+渐
+莆
+捍
+眈
+哓
+趴
+蹼
+埕
+嚣
+骛
+宏
+淄
+斑
+噜
+严
+瑛
+垃
+椎
+诱
+压
+庾
+绞
+焘
+廿
+抡
+迄
+棘
+夫
+纬
+锹
+眨
+瞌
+侠
+脐
+竞
+瀑
+孳
+骧
+遁
+姜
+颦
+荪
+滚
+萦
+伪
+逸
+粳
+爬
+锁
+矣
+役
+趣
+洒
+颔
+诏
+逐
+奸
+甭
+惠
+攀
+蹄
+泛
+尼
+拼
+阮
+鹰
+亚
+颈
+惑
+勒
+〉
+际
+肛
+爷
+刚
+钨
+丰
+养
+冶
+鲽
+辉
+蔻
+画
+覆
+皴
+妊
+麦
+返
+醉
+皂
+擀
+〗
+酶
+凑
+粹
+悟
+诀
+硖
+港
+卜
+z
+杀
+涕
+±
+舍
+铠
+抵
+弛
+段
+敝
+镐
+奠
+拂
+轴
+跛
+袱
+e
+t
+沉
+菇
+俎
+薪
+峦
+秭
+蟹
+历
+盟
+菠
+寡
+液
+肢
+喻
+染
+裱
+悱
+抱
+氙
+赤
+捅
+猛
+跑
+氮
+谣
+仁
+尺
+辊
+窍
+烙
+衍
+架
+擦
+倏
+璐
+瑁
+币
+楞
+胖
+夔
+趸
+邛
+惴
+饕
+虔
+蝎
+§
+哉
+贝
+宽
+辫
+炮
+扩
+饲
+籽
+魏
+菟
+锰
+伍
+猝
+末
+琳
+哚
+蛎
+邂
+呀
+姿
+鄞
+却
+歧
+仙
+恸
+椐
+森
+牒
+寤
+袒
+婆
+虢
+雅
+钉
+朵
+贼
+欲
+苞
+寰
+故
+龚
+坭
+嘘
+咫
+礼
+硷
+兀
+睢
+汶
+’
+铲
+烧
+绕
+诃
+浃
+钿
+哺
+柜
+讼
+颊
+璁
+腔
+洽
+咐
+脲
+簌
+筠
+镣
+玮
+鞠
+谁
+兼
+姆
+挥
+梯
+蝴
+谘
+漕
+刷
+躏
+宦
+弼
+b
+垌
+劈
+麟
+莉
+揭
+笙
+渎
+仕
+嗤
+仓
+配
+怏
+抬
+错
+泯
+镊
+孰
+猿
+邪
+仍
+秋
+鼬
+壹
+歇
+吵
+炼
+<
+尧
+射
+柬
+廷
+胧
+霾
+凳
+隋
+肚
+浮
+梦
+祥
+株
+堵
+退
+L
+鹫
+跎
+凶
+毽
+荟
+炫
+栩
+玳
+甜
+沂
+鹿
+顽
+伯
+爹
+赔
+蛴
+徐
+匡
+欣
+狰
+缸
+雹
+蟆
+疤
+默
+沤
+啜
+痂
+衣
+禅
+w
+i
+h
+辽
+葳
+黝
+钗
+停
+沽
+棒
+馨
+颌
+肉
+吴
+硫
+悯
+劾
+娈
+马
+啧
+吊
+悌
+镑
+峭
+帆
+瀣
+涉
+咸
+疸
+滋
+泣
+翦
+拙
+癸
+钥
+蜒
++
+尾
+庄
+凝
+泉
+婢
+渴
+谊
+乞
+陆
+锉
+糊
+鸦
+淮
+I
+B
+N
+晦
+弗
+乔
+庥
+葡
+尻
+席
+橡
+傣
+渣
+拿
+惩
+麋
+斛
+缃
+矮
+蛏
+岘
+鸽
+姐
+膏
+催
+奔
+镒
+喱
+蠡
+摧
+钯
+胤
+柠
+拐
+璋
+鸥
+卢
+荡
+倾
+^
+_
+珀
+逄
+萧
+塾
+掇
+贮
+笆
+聂
+圃
+冲
+嵬
+M
+滔
+笕
+值
+炙
+偶
+蜱
+搐
+梆
+汪
+蔬
+腑
+鸯
+蹇
+敞
+绯
+仨
+祯
+谆
+梧
+糗
+鑫
+啸
+豺
+囹
+猾
+巢
+柄
+瀛
+筑
+踌
+沭
+暗
+苁
+鱿
+蹉
+脂
+蘖
+牢
+热
+木
+吸
+溃
+宠
+序
+泞
+偿
+拜
+檩
+厚
+朐
+毗
+螳
+吞
+媚
+朽
+担
+蝗
+橘
+畴
+祈
+糟
+盱
+隼
+郜
+惜
+珠
+裨
+铵
+焙
+琚
+唯
+咚
+噪
+骊
+丫
+滢
+勤
+棉
+呸
+咣
+淀
+隔
+蕾
+窈
+饨
+挨
+煅
+短
+匙
+粕
+镜
+赣
+撕
+墩
+酬
+馁
+豌
+颐
+抗
+酣
+氓
+佑
+搁
+哭
+递
+耷
+涡
+桃
+贻
+碣
+截
+瘦
+昭
+镌
+蔓
+氚
+甲
+猕
+蕴
+蓬
+散
+拾
+纛
+狼
+猷
+铎
+埋
+旖
+矾
+讳
+囊
+糜
+迈
+粟
+蚂
+紧
+鲳
+瘢
+栽
+稼
+羊
+锄
+斟
+睁
+桥
+瓮
+蹙
+祉
+醺
+鼻
+昱
+剃
+跳
+篱
+跷
+蒜
+翎
+宅
+晖
+嗑
+壑
+峻
+癫
+屏
+狠
+陋
+袜
+途
+憎
+祀
+莹
+滟
+佶
+溥
+臣
+约
+盛
+峰
+磁
+慵
+婪
+拦
+莅
+朕
+鹦
+粲
+裤
+哎
+疡
+嫖
+琵
+窟
+堪
+谛
+嘉
+儡
+鳝
+斩
+郾
+驸
+酊
+妄
+胜
+贺
+徙
+傅
+噌
+钢
+栅
+庇
+恋
+匝
+巯
+邈
+尸
+锚
+粗
+佟
+蛟
+薹
+纵
+蚊
+郅
+绢
+锐
+苗
+俞
+篆
+淆
+膀
+鲜
+煎
+诶
+秽
+寻
+涮
+刺
+怀
+噶
+巨
+褰
+魅
+灶
+灌
+桉
+藕
+谜
+舸
+薄
+搀
+恽
+借
+牯
+痉
+渥
+愿
+亓
+耘
+杠
+柩
+锔
+蚶
+钣
+珈
+喘
+蹒
+幽
+赐
+稗
+晤
+莱
+泔
+扯
+肯
+菪
+裆
+腩
+豉
+疆
+骜
+腐
+倭
+珏
+唔
+粮
+亡
+润
+慰
+伽
+橄
+玄
+誉
+醐
+胆
+龊
+粼
+塬
+陇
+彼
+削
+嗣
+绾
+芽
+妗
+垭
+瘴
+爽
+薏
+寨
+龈
+泠
+弹
+赢
+漪
+猫
+嘧
+涂
+恤
+圭
+茧
+烽
+屑
+痕
+巾
+赖
+荸
+凰
+腮
+畈
+亵
+蹲
+偃
+苇
+澜
+艮
+换
+骺
+烘
+苕
+梓
+颉
+肇
+哗
+悄
+氤
+涠
+葬
+屠
+鹭
+植
+竺
+佯
+诣
+鲇
+瘀
+鲅
+邦
+移
+滁
+冯
+耕
+癔
+戌
+茬
+沁
+巩
+悠
+湘
+洪
+痹
+锟
+循
+谋
+腕
+鳃
+钠
+捞
+焉
+迎
+碱
+伫
+急
+榷
+奈
+邝
+卯
+辄
+皲
+卟
+醛
+畹
+忧
+稳
+雄
+昼
+缩
+阈
+睑
+扌
+耗
+曦
+涅
+捏
+瞧
+邕
+淖
+漉
+铝
+耦
+禹
+湛
+喽
+莼
+琅
+诸
+苎
+纂
+硅
+始
+嗨
+傥
+燃
+臂
+赅
+嘈
+呆
+贵
+屹
+壮
+肋
+亍
+蚀
+卅
+豹
+腆
+邬
+迭
+浊
+}
+童
+螂
+捐
+圩
+勐
+触
+寞
+汊
+壤
+荫
+膺
+渌
+芳
+懿
+遴
+螈
+泰
+蓼
+蛤
+茜
+舅
+枫
+朔
+膝
+眙
+避
+梅
+判
+鹜
+璜
+牍
+缅
+垫
+藻
+黔
+侥
+惚
+懂
+踩
+腰
+腈
+札
+丞
+唾
+慈
+顿
+摹
+荻
+琬
+~
+斧
+沈
+滂
+胁
+胀
+幄
+莜
+Z
+匀
+鄄
+掌
+绰
+茎
+焚
+赋
+萱
+谑
+汁
+铒
+瞎
+夺
+蜗
+野
+娆
+冀
+弯
+篁
+懵
+灞
+隽
+芡
+脘
+俐
+辩
+芯
+掺
+喏
+膈
+蝈
+觐
+悚
+踹
+蔗
+熠
+鼠
+呵
+抓
+橼
+峨
+畜
+缔
+禾
+崭
+弃
+熊
+摒
+凸
+拗
+穹
+蒙
+抒
+祛
+劝
+闫
+扳
+阵
+醌
+踪
+喵
+侣
+搬
+仅
+荧
+赎
+蝾
+琦
+买
+婧
+瞄
+寓
+皎
+冻
+赝
+箩
+莫
+瞰
+郊
+笫
+姝
+筒
+枪
+遣
+煸
+袋
+舆
+痱
+涛
+母
+〇
+启
+践
+耙
+绲
+盘
+遂
+昊
+搞
+槿
+诬
+纰
+泓
+惨
+檬
+亻
+越
+C
+o
+憩
+熵
+祷
+钒
+暧
+塔
+阗
+胰
+咄
+娶
+魔
+琶
+钞
+邻
+扬
+杉
+殴
+咽
+弓
+〆
+髻
+】
+吭
+揽
+霆
+拄
+殖
+脆
+彻
+岩
+芝
+勃
+辣
+剌
+钝
+嘎
+甄
+佘
+皖
+伦
+授
+徕
+憔
+挪
+皇
+庞
+稔
+芜
+踏
+溴
+兖
+卒
+擢
+饥
+鳞
+煲
+‰
+账
+颗
+叻
+斯
+捧
+鳍
+琮
+讹
+蛙
+纽
+谭
+酸
+兔
+莒
+睇
+伟
+觑
+羲
+嗜
+宜
+褐
+旎
+辛
+卦
+诘
+筋
+鎏
+溪
+挛
+熔
+阜
+晰
+鳅
+丢
+奚
+灸
+呱
+献
+陉
+黛
+鸪
+甾
+萨
+疮
+拯
+洲
+疹
+辑
+叙
+恻
+谒
+允
+柔
+烂
+氏
+逅
+漆
+拎
+惋
+扈
+湟
+纭
+啕
+掬
+擞
+哥
+忽
+涤
+鸵
+靡
+郗
+瓷
+扁
+廊
+怨
+雏
+钮
+敦
+E
+懦
+憋
+汀
+拚
+啉
+腌
+岸
+f
+痼
+瞅
+尊
+咀
+眩
+飙
+忌
+仝
+迦
+熬
+毫
+胯
+篑
+茄
+腺
+凄
+舛
+碴
+锵
+诧
+羯
+後
+漏
+汤
+宓
+仞
+蚁
+壶
+谰
+皑
+铄
+棰
+罔
+辅
+晶
+苦
+牟
+闽
+\
+烃
+饮
+聿
+丙
+蛳
+朱
+煤
+涔
+鳖
+犁
+罐
+荼
+砒
+淦
+妤
+黏
+戎
+孑
+婕
+瑾
+戢
+钵
+枣
+捋
+砥
+衩
+狙
+桠
+稣
+阎
+肃
+梏
+诫
+孪
+昶
+婊
+衫
+嗔
+侃
+塞
+蜃
+樵
+峒
+貌
+屿
+欺
+缫
+阐
+栖
+诟
+珞
+荭
+吝
+萍
+嗽
+恂
+啻
+蜴
+磬
+峋
+俸
+豫
+谎
+徊
+镍
+韬
+魇
+晴
+U
+囟
+猜
+蛮
+坐
+囿
+伴
+亭
+肝
+佗
+蝠
+妃
+胞
+滩
+榴
+氖
+垩
+苋
+砣
+扪
+馏
+姓
+轩
+厉
+夥
+侈
+禀
+垒
+岑
+赏
+钛
+辐
+痔
+披
+纸
+碳
+“
+坞
+蠓
+挤
+荥
+沅
+悔
+铧
+帼
+蒌
+蝇
+a
+p
+y
+n
+g
+哀
+浆
+瑶
+凿
+桶
+馈
+皮
+奴
+苜
+佤
+伶
+晗
+铱
+炬
+优
+弊
+氢
+恃
+甫
+攥
+端
+锌
+灰
+稹
+炝
+曙
+邋
+亥
+眶
+碾
+拉
+萝
+绔
+捷
+浍
+腋
+姑
+菖
+凌
+涞
+麽
+锢
+桨
+潢
+绎
+镰
+殆
+锑
+渝
+铬
+困
+绽
+觎
+匈
+糙
+暑
+裹
+鸟
+盔
+肽
+迷
+綦
+『
+亳
+佝
+俘
+钴
+觇
+骥
+仆
+疝
+跪
+婶
+郯
+瀹
+唉
+脖
+踞
+针
+晾
+忒
+扼
+瞩
+叛
+椒
+疟
+嗡
+邗
+肆
+跆
+玫
+忡
+捣
+咧
+唆
+艄
+蘑
+潦
+笛
+阚
+沸
+泻
+掊
+菽
+贫
+斥
+髂
+孢
+镂
+赂
+麝
+鸾
+屡
+衬
+苷
+恪
+叠
+希
+粤
+爻
+喝
+茫
+惬
+郸
+绻
+庸
+撅
+碟
+宄
+妹
+膛
+叮
+饵
+崛
+嗲
+椅
+冤
+搅
+咕
+敛
+尹
+垦
+闷
+蝉
+霎
+勰
+败
+蓑
+泸
+肤
+鹌
+幌
+焦
+浠
+鞍
+刁
+舰
+乙
+竿
+裔
+。
+茵
+函
+伊
+兄
+丨
+娜
+匍
+謇
+莪
+宥
+似
+蝽
+翳
+酪
+翠
+粑
+薇
+祢
+骏
+赠
+叫
+Q
+噤
+噻
+竖
+芗
+莠
+潭
+俊
+羿
+耜
+O
+郫
+趁
+嗪
+囚
+蹶
+芒
+洁
+笋
+鹑
+敲
+硝
+啶
+堡
+渲
+揩
+』
+携
+宿
+遒
+颍
+扭
+棱
+割
+萜
+蔸
+葵
+琴
+捂
+饰
+衙
+耿
+掠
+募
+岂
+窖
+涟
+蔺
+瘤
+柞
+瞪
+怜
+匹
+距
+楔
+炜
+哆
+秦
+缎
+幼
+茁
+绪
+痨
+恨
+楸
+娅
+瓦
+桩
+雪
+嬴
+伏
+榔
+妥
+铿
+拌
+眠
+雍
+缇
+‘
+卓
+搓
+哌
+觞
+噩
+屈
+哧
+髓
+咦
+巅
+娑
+侑
+淫
+膳
+祝
+勾
+姊
+莴
+胄
+疃
+薛
+蜷
+胛
+巷
+芙
+芋
+熙
+闰
+勿
+窃
+狱
+剩
+钏
+幢
+陟
+铛
+慧
+靴
+耍
+k
+浙
+浇
+飨
+惟
+绗
+祜
+澈
+啼
+咪
+磷
+摞
+诅
+郦
+抹
+跃
+壬
+吕
+肖
+琏
+颤
+尴
+剡
+抠
+凋
+赚
+泊
+津
+宕
+殷
+倔
+氲
+漫
+邺
+涎
+怠
+$
+垮
+荬
+遵
+俏
+叹
+噢
+饽
+蜘
+孙
+筵
+疼
+鞭
+羧
+牦
+箭
+潴
+c
+眸
+祭
+髯
+啖
+坳
+愁
+芩
+驮
+倡
+巽
+穰
+沃
+胚
+怒
+凤
+槛
+剂
+趵
+嫁
+v
+邢
+灯
+鄢
+桐
+睽
+檗
+锯
+槟
+婷
+嵋
+圻
+诗
+蕈
+颠
+遭
+痢
+芸
+怯
+馥
+竭
+锗
+徜
+恭
+遍
+籁
+剑
+嘱
+苡
+龄
+僧
+桑
+潸
+弘
+澶
+楹
+悲
+讫
+愤
+腥
+悸
+谍
+椹
+呢
+桓
+葭
+攫
+阀
+翰
+躲
+敖
+柑
+郎
+笨
+橇
+呃
+魁
+燎
+脓
+葩
+磋
+垛
+玺
+狮
+沓
+砜
+蕊
+锺
+罹
+蕉
+翱
+虐
+闾
+巫
+旦
+茱
+嬷
+枯
+鹏
+贡
+芹
+汛
+矫
+绁
+拣
+禺
+佃
+讣
+舫
+惯
+乳
+趋
+疲
+挽
+岚
+虾
+衾
+蠹
+蹂
+飓
+氦
+铖
+孩
+稞
+瑜
+壅
+掀
+勘
+妓
+畅
+髋
+W
+庐
+牲
+蓿
+榕
+练
+垣
+唱
+邸
+菲
+昆
+婺
+穿
+绡
+麒
+蚱
+掂
+愚
+泷
+涪
+漳
+妩
+娉
+榄
+讷
+觅
+旧
+藤
+煮
+呛
+柳
+腓
+叭
+庵
+烷
+阡
+罂
+蜕
+擂
+猖
+咿
+媲
+脉
+【
+沏
+貅
+黠
+熏
+哲
+烁
+坦
+酵
+兜
+×
+潇
+撒
+剽
+珩
+圹
+乾
+摸
+樟
+帽
+嗒
+襄
+魂
+轿
+憬
+锡
+〕
+喃
+皆
+咖
+隅
+脸
+残
+泮
+袂
+鹂
+珊
+囤
+捆
+咤
+误
+徨
+闹
+淙
+芊
+淋
+怆
+囗
+拨
+梳
+渤
+R
+G
+绨
+蚓
+婀
+幡
+狩
+麾
+谢
+唢
+裸
+旌
+伉
+纶
+裂
+驳
+砼
+咛
+澄
+樨
+蹈
+宙
+澍
+倍
+貔
+操
+勇
+蟠
+摈
+砧
+虬
+够
+缁
+悦
+藿
+撸
+艹
+摁
+淹
+豇
+虎
+榭
+ˉ
+吱
+d
+°
+喧
+荀
+踱
+侮
+奋
+偕
+饷
+犍
+惮
+坑
+璎
+徘
+宛
+妆
+袈
+倩
+窦
+昂
+荏
+乖
+K
+怅
+撰
+鳙
+牙
+袁
+酞
+X
+痿
+琼
+闸
+雁
+趾
+荚
+虻
+涝
+《
+杏
+韭
+偈
+烤
+绫
+鞘
+卉
+症
+遢
+蓥
+诋
+杭
+荨
+匆
+竣
+簪
+辙
+敕
+虞
+丹
+缭
+咩
+黟
+m
+淤
+瑕
+咂
+铉
+硼
+茨
+嶂
+痒
+畸
+敬
+涿
+粪
+窘
+熟
+叔
+嫔
+盾
+忱
+裘
+憾
+梵
+赡
+珙
+咯
+娘
+庙
+溯
+胺
+葱
+痪
+摊
+荷
+卞
+乒
+髦
+寐
+铭
+坩
+胗
+枷
+爆
+溟
+嚼
+羚
+砬
+轨
+惊
+挠
+罄
+竽
+菏
+氧
+浅
+楣
+盼
+枢
+炸
+阆
+杯
+谏
+噬
+淇
+渺
+俪
+秆
+墓
+泪
+跻
+砌
+痰
+垡
+渡
+耽
+釜
+讶
+鳎
+煞
+呗
+韶
+舶
+绷
+鹳
+缜
+旷
+铊
+皱
+龌
+檀
+霖
+奄
+槐
+艳
+蝶
+旋
+哝
+赶
+骞
+蚧
+腊
+盈
+丁
+`
+蜚
+矸
+蝙
+睨
+嚓
+僻
+鬼
+醴
+夜
+彝
+磊
+笔
+拔
+栀
+糕
+厦
+邰
+纫
+逭
+纤
+眦
+膊
+馍
+躇
+烯
+蘼
+冬
+诤
+暄
+骶
+哑
+瘠
+」
+臊
+丕
+愈
+咱
+螺
+擅
+跋
+搏
+硪
+谄
+笠
+淡
+嘿
+骅
+谧
+鼎
+皋
+姚
+歼
+蠢
+驼
+耳
+胬
+挝
+涯
+狗
+蒽
+孓
+犷
+凉
+芦
+箴
+铤
+孤
+嘛
+坤
+V
+茴
+朦
+挞
+尖
+橙
+诞
+搴
+碇
+洵
+浚
+帚
+蜍
+漯
+柘
+嚎
+讽
+芭
+荤
+咻
+祠
+秉
+跖
+埃
+吓
+糯
+眷
+馒
+惹
+娼
+鲑
+嫩
+讴
+轮
+瞥
+靶
+褚
+乏
+缤
+宋
+帧
+删
+驱
+碎
+扑
+俩
+俄
+偏
+涣
+竹
+噱
+皙
+佰
+渚
+唧
+斡
+#
+镉
+刀
+崎
+筐
+佣
+夭
+贰
+肴
+峙
+哔
+艿
+匐
+牺
+镛
+缘
+仡
+嫡
+劣
+枸
+堀
+梨
+簿
+鸭
+蒸
+亦
+稽
+浴
+{
+衢
+束
+槲
+j
+阁
+揍
+疥
+棋
+潋
+聪
+窜
+乓
+睛
+插
+冉
+阪
+苍
+搽
+「
+蟾
+螟
+幸
+仇
+樽
+撂
+慢
+跤
+幔
+俚
+淅
+覃
+觊
+溶
+妖
+帛
+侨
+曰
+妾
+泗
+·
+:
+瀘
+風
+Ë
+(
+)
+∶
+紅
+紗
+瑭
+雲
+頭
+鶏
+財
+許
+•
+¥
+樂
+焗
+麗
+—
+;
+滙
+東
+榮
+繪
+興
+…
+門
+業
+π
+楊
+國
+顧
+é
+盤
+寳
+Λ
+龍
+鳳
+島
+誌
+緣
+結
+銭
+萬
+勝
+祎
+璟
+優
+歡
+臨
+時
+購
+=
+★
+藍
+昇
+鐵
+觀
+勅
+農
+聲
+畫
+兿
+術
+發
+劉
+記
+專
+耑
+園
+書
+壴
+種
+Ο
+●
+褀
+號
+銀
+匯
+敟
+锘
+葉
+橪
+廣
+進
+蒄
+鑽
+阝
+祙
+貢
+鍋
+豊
+夬
+喆
+團
+閣
+開
+燁
+賓
+館
+酡
+沔
+順
++
+硚
+劵
+饸
+陽
+車
+湓
+復
+萊
+氣
+軒
+華
+堃
+迮
+纟
+戶
+馬
+學
+裡
+電
+嶽
+獨
+マ
+シ
+サ
+ジ
+燘
+袪
+環
+❤
+臺
+灣
+専
+賣
+孖
+聖
+攝
+線
+▪
+α
+傢
+俬
+夢
+達
+莊
+喬
+貝
+薩
+劍
+羅
+壓
+棛
+饦
+尃
+璈
+囍
+醫
+G
+I
+A
+#
+N
+鷄
+髙
+嬰
+啓
+約
+隹
+潔
+賴
+藝
+~
+寶
+籣
+麺
+ 
+嶺
+√
+義
+網
+峩
+長
+∧
+魚
+機
+構
+②
+鳯
+偉
+L
+B
+㙟
+畵
+鴿
+'
+詩
+溝
+嚞
+屌
+藔
+佧
+玥
+蘭
+織
+1
+3
+9
+0
+7
+點
+砭
+鴨
+鋪
+銘
+廳
+弍
+‧
+創
+湯
+坶
+℃
+卩
+骝
+&
+烜
+荘
+當
+潤
+扞
+係
+懷
+碶
+钅
+蚨
+讠
+☆
+叢
+爲
+埗
+涫
+塗
+→
+楽
+現
+鯨
+愛
+瑪
+鈺
+忄
+悶
+藥
+飾
+樓
+視
+孬
+ㆍ
+燚
+苪
+師
+①
+丼
+锽
+│
+韓
+標
+è
+兒
+閏
+匋
+張
+漢
+Ü
+髪
+會
+閑
+檔
+習
+裝
+の
+峯
+菘
+輝
+И
+雞
+釣
+億
+浐
+K
+O
+R
+8
+H
+E
+P
+T
+W
+D
+S
+C
+M
+F
+姌
+饹
+»
+晞
+廰
+ä
+嵯
+鷹
+負
+飲
+絲
+冚
+楗
+澤
+綫
+區
+❋
+←
+質
+靑
+揚
+③
+滬
+統
+産
+協
+﹑
+乸
+畐
+經
+運
+際
+洺
+岽
+為
+粵
+諾
+崋
+豐
+碁
+ɔ
+V
+2
+6
+齋
+誠
+訂
+´
+勑
+雙
+陳
+無
+í
+泩
+媄
+夌
+刂
+i
+c
+t
+o
+r
+a
+嘢
+耄
+燴
+暃
+壽
+媽
+靈
+抻
+體
+唻
+É
+冮
+甹
+鎮
+錦
+ʌ
+蜛
+蠄
+尓
+駕
+戀
+飬
+逹
+倫
+貴
+極
+Я
+Й
+寬
+磚
+嶪
+郎
+職
+|
+間
+n
+d
+剎
+伈
+課
+飛
+橋
+瘊
+№
+譜
+骓
+圗
+滘
+縣
+粿
+咅
+養
+濤
+彳
+®
+%
+Ⅱ
+啰
+㴪
+見
+矞
+薬
+糁
+邨
+鲮
+顔
+罱
+З
+選
+話
+贏
+氪
+俵
+競
+瑩
+繡
+枱
+β
+綉
+á
+獅
+爾
+™
+麵
+戋
+淩
+徳
+個
+劇
+場
+務
+簡
+寵
+h
+實
+膠
+轱
+圖
+築
+嘣
+樹
+㸃
+營
+耵
+孫
+饃
+鄺
+飯
+麯
+遠
+輸
+坫
+孃
+乚
+閃
+鏢
+㎡
+題
+廠
+關
+↑
+爺
+將
+軍
+連
+篦
+覌
+參
+箸
+-
+窠
+棽
+寕
+夀
+爰
+歐
+呙
+閥
+頡
+熱
+雎
+垟
+裟
+凬
+勁
+帑
+馕
+夆
+疌
+枼
+馮
+貨
+蒤
+樸
+彧
+旸
+靜
+龢
+暢
+㐱
+鳥
+珺
+鏡
+灡
+爭
+堷
+廚
+Ó
+騰
+診
+┅
+蘇
+褔
+凱
+頂
+豕
+亞
+帥
+嘬
+⊥
+仺
+桖
+複
+饣
+絡
+穂
+顏
+棟
+納
+▏
+濟
+親
+設
+計
+攵
+埌
+烺
+ò
+頤
+燦
+蓮
+撻
+節
+講
+濱
+濃
+娽
+洳
+朿
+燈
+鈴
+護
+膚
+铔
+過
+補
+Z
+U
+5
+4
+坋
+闿
+䖝
+餘
+缐
+铞
+貿
+铪
+桼
+趙
+鍊
+[
+㐂
+垚
+菓
+揸
+捲
+鐘
+滏
+𣇉
+爍
+輪
+燜
+鴻
+鮮
+動
+鹞
+鷗
+丄
+慶
+鉌
+翥
+飮
+腸
+⇋
+漁
+覺
+來
+熘
+昴
+翏
+鲱
+圧
+鄉
+萭
+頔
+爐
+嫚
+г
+貭
+類
+聯
+幛
+輕
+訓
+鑒
+夋
+锨
+芃
+珣
+䝉
+扙
+嵐
+銷
+處
+ㄱ
+語
+誘
+苝
+歸
+儀
+燒
+楿
+內
+粢
+葒
+奧
+麥
+礻
+滿
+蠔
+穵
+瞭
+態
+鱬
+榞
+硂
+鄭
+黃
+煙
+祐
+奓
+逺
+*
+瑄
+獲
+聞
+薦
+讀
+這
+樣
+決
+問
+啟
+們
+執
+説
+轉
+單
+隨
+唘
+帶
+倉
+庫
+還
+贈
+尙
+皺
+■
+餅
+產
+○
+∈
+報
+狀
+楓
+賠
+琯
+嗮
+禮
+`
+傳
+>
+≤
+嗞
+Φ
+≥
+換
+咭
+∣
+↓
+曬
+ε
+応
+寫
+″
+終
+様
+純
+費
+療
+聨
+凍
+壐
+郵
+ü
+黒
+∫
+製
+塊
+調
+軽
+確
+撃
+級
+馴
+Ⅲ
+涇
+繹
+數
+碼
+證
+狒
+処
+劑
+<
+晧
+賀
+衆
+]
+櫥
+兩
+陰
+絶
+對
+鯉
+憶
+◎
+p
+e
+Y
+蕒
+煖
+頓
+測
+試
+鼽
+僑
+碩
+妝
+帯
+≈
+鐡
+舖
+權
+喫
+倆
+ˋ
+該
+悅
+ā
+俫
+.
+f
+s
+b
+m
+k
+g
+u
+j
+貼
+淨
+濕
+針
+適
+備
+l
+/
+給
+謢
+強
+觸
+衛
+與
+⊙
+$
+緯
+變
+⑴
+⑵
+⑶
+㎏
+殺
+∩
+幚
+─
+價
+▲
+離
+ú
+ó
+飄
+烏
+関
+閟
+﹝
+﹞
+邏
+輯
+鍵
+驗
+訣
+導
+歷
+屆
+層
+▼
+儱
+錄
+熳
+ē
+艦
+吋
+錶
+辧
+飼
+顯
+④
+禦
+販
+気
+対
+枰
+閩
+紀
+幹
+瞓
+貊
+淚
+△
+眞
+墊
+Ω
+獻
+褲
+縫
+緑
+亜
+鉅
+餠
+{
+}
+◆
+蘆
+薈
+█
+◇
+溫
+彈
+晳
+粧
+犸
+穩
+訊
+崬
+凖
+熥
+П
+舊
+條
+紋
+圍
+Ⅳ
+筆
+尷
+難
+雜
+錯
+綁
+識
+頰
+鎖
+艶
+□
+殁
+殼
+⑧
+├
+▕
+鵬
+ǐ
+ō
+ǒ
+糝
+綱
+▎
+μ
+盜
+饅
+醬
+籤
+蓋
+釀
+鹽
+據
+à
+ɡ
+辦
+◥
+彐
+┌
+婦
+獸
+鲩
+伱
+ī
+蒟
+蒻
+齊
+袆
+腦
+寧
+凈
+妳
+煥
+詢
+偽
+謹
+啫
+鯽
+騷
+鱸
+損
+傷
+鎻
+髮
+買
+冏
+儥
+両
+﹢
+∞
+載
+喰
+z
+羙
+悵
+燙
+曉
+員
+組
+徹
+艷
+痠
+鋼
+鼙
+縮
+細
+嚒
+爯
+≠
+維
+"
+鱻
+壇
+厍
+帰
+浥
+犇
+薡
+軎
+²
+應
+醜
+刪
+緻
+鶴
+賜
+噁
+軌
+尨
+镔
+鷺
+槗
+彌
+葚
+濛
+請
+溇
+緹
+賢
+訪
+獴
+瑅
+資
+縤
+陣
+蕟
+栢
+韻
+祼
+恁
+伢
+謝
+劃
+涑
+總
+衖
+踺
+砋
+凉
+籃
+駿
+苼
+瘋
+昽
+紡
+驊
+腎
+﹗
+響
+杋
+剛
+嚴
+禪
+歓
+槍
+傘
+檸
+檫
+炣
+勢
+鏜
+鎢
+銑
+尐
+減
+奪
+惡
+θ
+僮
+婭
+臘
+ū
+ì
+殻
+鉄
+∑
+蛲
+焼
+緖
+續
+紹
+懮
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2f78d9ab775346d2688a980a4184859d5c94e5b6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+Pillow==9.5.0
+gradio==3.50.0
+xformers==0.0.20
\ No newline at end of file
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..b86d2b64169ed66088f12dbf14d030af0125c600
--- /dev/null
+++ b/style.css
@@ -0,0 +1,32 @@
+#banner {
+    max-width: 400px;
+    margin: auto;
+    box-shadow: 0 2px 20px rgba(0, 0, 0, 0.5) !important;
+    border-radius: 20px;
+}
+
+.run {
+    background-color: #624AFF !important;
+    color: #FFFFFF !important;
+    border-radius: 2px !important;
+    box-shadow: 0 3px 5px rgba(0, 0, 0, 0.5) !important;
+}
+.run:active {
+    background-color: #d96565 !important; 
+}
+.run:hover {
+    background-color: #a079f5 !important; 
+}
+/* tab button style */
+button.svelte-kqij2n {
+    margin-bottom: -1px;
+    border: 1px solid transparent;
+    border-color: transparent;
+    border-bottom: none;
+    color: #9CA3AF !important;
+    font-size: 16px;
+}
+button.selected.svelte-kqij2n {
+    background: #ddd8f9 !important;
+    color: rgb(62, 7, 240) !important;
+}
\ No newline at end of file
diff --git a/t3_dataset.py b/t3_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..669b4598baf12654f8badc9bbbd5556ba33b8239
--- /dev/null
+++ b/t3_dataset.py
@@ -0,0 +1,447 @@
+import os
+import numpy as np
+import cv2
+import random
+import math
+from PIL import Image, ImageDraw, ImageFont
+from torch.utils.data import Dataset, DataLoader
+from dataset_util import load, show_bbox_on_image
+
+
+phrase_list = [
+    ', content and position of the texts are ',
+    ', textual material depicted in the image are ',
+    ', texts that says ',
+    ', captions shown in the snapshot are ',
+    ', with the words of ',
+    ', that reads ',
+    ', the written materials on the picture: ',
+    ', these texts are written on it: ',
+    ', captions are ',
+    ', content of the text in the graphic is '
+]
+
+
+def insert_spaces(string, nSpace):
+    if nSpace == 0:
+        return string
+    new_string = ""
+    for char in string:
+        new_string += char + " " * nSpace
+    return new_string[:-nSpace]
+
+
+def draw_glyph(font, text):
+    g_size = 50
+    W, H = (512, 80)
+    new_font = font.font_variant(size=g_size)
+    img = Image.new(mode='1', size=(W, H), color=0)
+    draw = ImageDraw.Draw(img)
+    left, top, right, bottom = new_font.getbbox(text)
+    text_width = max(right-left, 5)
+    text_height = max(bottom - top, 5)
+    ratio = min(W*0.9/text_width, H*0.9/text_height)
+    new_font = font.font_variant(size=int(g_size*ratio))
+
+    text_width, text_height = new_font.getsize(text)
+    offset_x, offset_y = new_font.getoffset(text)
+    x = (img.width - text_width) // 2
+    y = (img.height - text_height) // 2 - offset_y//2
+    draw.text((x, y), text, font=new_font, fill='white')
+    img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
+    return img
+
+
+def draw_glyph2(font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True):
+    enlarge_polygon = polygon*scale
+    rect = cv2.minAreaRect(enlarge_polygon)
+    box = cv2.boxPoints(rect)
+    box = np.int0(box)
+    w, h = rect[1]
+    angle = rect[2]
+    if angle < -45:
+        angle += 90
+    angle = -angle
+    if w < h:
+        angle += 90
+
+    vert = False
+    if (abs(angle) % 90 < vertAng or abs(90-abs(angle) % 90) % 90 < vertAng):
+        _w = max(box[:, 0]) - min(box[:, 0])
+        _h = max(box[:, 1]) - min(box[:, 1])
+        if _h >= _w:
+            vert = True
+            angle = 0
+
+    img = np.zeros((height*scale, width*scale, 3), np.uint8)
+    img = Image.fromarray(img)
+
+    # infer font size
+    image4ratio = Image.new("RGB", img.size, "white")
+    draw = ImageDraw.Draw(image4ratio)
+    _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
+    text_w = min(w, h) * (_tw / _th)
+    if text_w <= max(w, h):
+        # add space
+        if len(text) > 1 and not vert and add_space:
+            for i in range(1, 100):
+                text_space = insert_spaces(text, i)
+                _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
+                if min(w, h) * (_tw2 / _th2) > max(w, h):
+                    break
+            text = insert_spaces(text, i-1)
+        font_size = min(w, h)*0.80
+    else:
+        shrink = 0.75 if vert else 0.85
+        font_size = min(w, h) / (text_w/max(w, h)) * shrink
+    new_font = font.font_variant(size=int(font_size))
+
+    left, top, right, bottom = new_font.getbbox(text)
+    text_width = right-left
+    text_height = bottom - top
+
+    layer = Image.new('RGBA', img.size, (0, 0, 0, 0))
+    draw = ImageDraw.Draw(layer)
+    if not vert:
+        draw.text((rect[0][0]-text_width//2, rect[0][1]-text_height//2-top), text, font=new_font, fill=(255, 255, 255, 255))
+    else:
+        x_s = min(box[:, 0]) + _w//2 - text_height//2
+        y_s = min(box[:, 1])
+        for c in text:
+            draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
+            _, _t, _, _b = new_font.getbbox(c)
+            y_s += _b
+
+    rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
+
+    x_offset = int((img.width - rotated_layer.width) / 2)
+    y_offset = int((img.height - rotated_layer.height) / 2)
+    img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
+    img = np.expand_dims(np.array(img.convert('1')), axis=2).astype(np.float64)
+    return img
+
+
+def get_caption_pos(ori_caption, pos_idxs, prob=1.0, place_holder='*'):
+    idx2pos = {
+        0: " top left",
+        1: " top",
+        2: " top right",
+        3: " left",
+        4: random.choice([" middle", " center"]),
+        5: " right",
+        6: " bottom left",
+        7: " bottom",
+        8: " bottom right"
+    }
+    new_caption = ori_caption + random.choice(phrase_list)
+    pos = ''
+    for i in range(len(pos_idxs)):
+        if random.random() < prob and pos_idxs[i] > 0:
+            pos += place_holder + random.choice([' located', ' placed', ' positioned', '']) + random.choice([' at', ' in', ' on']) + idx2pos[pos_idxs[i]] + ', '
+        else:
+            pos += place_holder + ' , '
+    pos = pos[:-2] + '.'
+    new_caption += pos
+    return new_caption
+
+
+def generate_random_rectangles(w, h, box_num):
+    rectangles = []
+    for i in range(box_num):
+        x = random.randint(0, w)
+        y = random.randint(0, h)
+        w = random.randint(16, 256)
+        h = random.randint(16, 96)
+        angle = random.randint(-45, 45)
+        p1 = (x, y)
+        p2 = (x + w, y)
+        p3 = (x + w, y + h)
+        p4 = (x, y + h)
+        center = ((x + x + w) / 2, (y + y + h) / 2)
+        p1 = rotate_point(p1, center, angle)
+        p2 = rotate_point(p2, center, angle)
+        p3 = rotate_point(p3, center, angle)
+        p4 = rotate_point(p4, center, angle)
+        rectangles.append((p1, p2, p3, p4))
+    return rectangles
+
+
+def rotate_point(point, center, angle):
+    # rotation
+    angle = math.radians(angle)
+    x = point[0] - center[0]
+    y = point[1] - center[1]
+    x1 = x * math.cos(angle) - y * math.sin(angle)
+    y1 = x * math.sin(angle) + y * math.cos(angle)
+    x1 += center[0]
+    y1 += center[1]
+    return int(x1), int(y1)
+
+
+class T3DataSet(Dataset):
+    def __init__(
+            self,
+            json_path,
+            max_lines=5,
+            max_chars=20,
+            place_holder='*',
+            font_path='./font/Arial_Unicode.ttf',
+            caption_pos_prob=1.0,
+            mask_pos_prob=1.0,
+            mask_img_prob=0.5,
+            for_show=False,
+            using_dlc=False,
+            glyph_scale=1,
+            percent=1.0,
+            debug=False,
+            wm_thresh=1.0,
+            ):
+        assert isinstance(json_path, (str, list))
+        if isinstance(json_path, str):
+            json_path = [json_path]
+        data_list = []
+        self.using_dlc = using_dlc
+        self.max_lines = max_lines
+        self.max_chars = max_chars
+        self.place_holder = place_holder
+        self.font = ImageFont.truetype(font_path, size=60)
+        self.caption_pos_porb = caption_pos_prob
+        self.mask_pos_prob = mask_pos_prob
+        self.mask_img_prob = mask_img_prob
+        self.for_show = for_show
+        self.glyph_scale = glyph_scale
+        self.wm_thresh = wm_thresh
+        for jp in json_path:
+            data_list += self.load_data(jp, percent)
+        self.data_list = data_list
+        print(f'All dataset loaded, imgs={len(self.data_list)}')
+        self.debug = debug
+        if self.debug:
+            self.tmp_items = [i for i in range(100)]
+
+    def load_data(self, json_path, percent):
+        content = load(json_path)
+        d = []
+        count = 0
+        wm_skip = 0
+        max_img = len(content['data_list']) * percent
+        for gt in content['data_list']:
+            if len(d) > max_img:
+                break
+            if 'wm_score' in gt and gt['wm_score'] > self.wm_thresh:  # wm_score > thresh will be skiped as an img with watermark
+                wm_skip += 1
+                continue
+            data_root = content['data_root']
+            if self.using_dlc:
+                data_root = data_root.replace('/data/vdb', '/mnt/data', 1)
+            img_path = os.path.join(data_root, gt['img_name'])
+            info = {}
+            info['img_path'] = img_path
+            info['caption'] = gt['caption'] if 'caption' in gt else ''
+            if self.place_holder in info['caption']:
+                count += 1
+                info['caption'] = info['caption'].replace(self.place_holder, " ")
+            if 'annotations' in gt:
+                polygons = []
+                invalid_polygons = []
+                texts = []
+                languages = []
+                pos = []
+                for annotation in gt['annotations']:
+                    if len(annotation['polygon']) == 0:
+                        continue
+                    if 'valid' in annotation and annotation['valid'] is False:
+                        invalid_polygons.append(annotation['polygon'])
+                        continue
+                    polygons.append(annotation['polygon'])
+                    texts.append(annotation['text'])
+                    languages.append(annotation['language'])
+                    if 'pos' in annotation:
+                        pos.append(annotation['pos'])
+                info['polygons'] = [np.array(i) for i in polygons]
+                info['invalid_polygons'] = [np.array(i) for i in invalid_polygons]
+                info['texts'] = texts
+                info['language'] = languages
+                info['pos'] = pos
+            d.append(info)
+        print(f'{json_path} loaded, imgs={len(d)}, wm_skip={wm_skip}')
+        if count > 0:
+            print(f"Found {count} image's caption contain placeholder: {self.place_holder}, change to ' '...")
+        return d
+
+    def __getitem__(self, item):
+        item_dict = {}
+        if self.debug:  # sample fixed items
+            item = self.tmp_items.pop()
+            print(f'item = {item}')
+        cur_item = self.data_list[item]
+        # img
+        target = np.array(Image.open(cur_item['img_path']).convert('RGB'))
+        if target.shape[0] != 512 or target.shape[1] != 512:
+            target = cv2.resize(target, (512, 512))
+        target = (target.astype(np.float32) / 127.5) - 1.0
+        item_dict['img'] = target
+        # caption
+        item_dict['caption'] = cur_item['caption']
+        item_dict['glyphs'] = []
+        item_dict['gly_line'] = []
+        item_dict['positions'] = []
+        item_dict['texts'] = []
+        item_dict['language'] = []
+        item_dict['inv_mask'] = []
+        texts = cur_item.get('texts', [])
+        if len(texts) > 0:
+            idxs = [i for i in range(len(texts))]
+            if len(texts) > self.max_lines:
+                sel_idxs = random.sample(idxs, self.max_lines)
+                unsel_idxs = [i for i in idxs if i not in sel_idxs]
+            else:
+                sel_idxs = idxs
+                unsel_idxs = []
+            if len(cur_item['pos']) > 0:
+                pos_idxs = [cur_item['pos'][i] for i in sel_idxs]
+            else:
+                pos_idxs = [-1 for i in sel_idxs]
+            item_dict['caption'] = get_caption_pos(item_dict['caption'], pos_idxs, self.caption_pos_porb, self.place_holder)
+            item_dict['polygons'] = [cur_item['polygons'][i] for i in sel_idxs]
+            item_dict['texts'] = [cur_item['texts'][i][:self.max_chars] for i in sel_idxs]
+            item_dict['language'] = [cur_item['language'][i] for i in sel_idxs]
+            # glyphs
+            for idx, text in enumerate(item_dict['texts']):
+                gly_line = draw_glyph(self.font, text)
+                glyphs = draw_glyph2(self.font, text, item_dict['polygons'][idx], scale=self.glyph_scale)
+                item_dict['glyphs'] += [glyphs]
+                item_dict['gly_line'] += [gly_line]
+            # mask_pos
+            for polygon in item_dict['polygons']:
+                item_dict['positions'] += [self.draw_pos(polygon, self.mask_pos_prob)]
+        # inv_mask
+        invalid_polygons = cur_item['invalid_polygons'] if 'invalid_polygons' in cur_item else []
+        if len(texts) > 0:
+            invalid_polygons += [cur_item['polygons'][i] for i in unsel_idxs]
+        item_dict['inv_mask'] = self.draw_inv_mask(invalid_polygons)
+        item_dict['hint'] = self.get_hint(item_dict['positions'])
+        if random.random() < self.mask_img_prob:
+            # randomly generate 0~3 masks
+            box_num = random.randint(0, 3)
+            boxes = generate_random_rectangles(512, 512, box_num)
+            boxes = np.array(boxes)
+            pos_list = item_dict['positions'].copy()
+            for i in range(box_num):
+                pos_list += [self.draw_pos(boxes[i], self.mask_pos_prob)]
+            mask = self.get_hint(pos_list)
+            masked_img = target*(1-mask)
+        else:
+            masked_img = np.zeros_like(target)
+        item_dict['masked_img'] = masked_img
+
+        if self.for_show:
+            item_dict['img_name'] = os.path.split(cur_item['img_path'])[-1]
+            return item_dict
+        if len(texts) > 0:
+            del item_dict['polygons']
+        # padding
+        n_lines = min(len(texts), self.max_lines)
+        item_dict['n_lines'] = n_lines
+        n_pad = self.max_lines - n_lines
+        if n_pad > 0:
+            item_dict['glyphs'] += [np.zeros((512*self.glyph_scale, 512*self.glyph_scale, 1))] * n_pad
+            item_dict['gly_line'] += [np.zeros((80, 512, 1))] * n_pad
+            item_dict['positions'] += [np.zeros((512, 512, 1))] * n_pad
+            item_dict['texts'] += [' '] * n_pad
+            item_dict['language'] += [' '] * n_pad
+
+        return item_dict
+
+    def __len__(self):
+        return len(self.data_list)
+
+    def draw_inv_mask(self, polygons):
+        img = np.zeros((512, 512))
+        for p in polygons:
+            pts = p.reshape((-1, 1, 2))
+            cv2.fillPoly(img, [pts], color=255)
+        img = img[..., None]
+        return img/255.
+
+    def draw_pos(self, ploygon, prob=1.0):
+        img = np.zeros((512, 512))
+        rect = cv2.minAreaRect(ploygon)
+        w, h = rect[1]
+        small = False
+        if w < 20 or h < 20:
+            small = True
+        if random.random() < prob:
+            pts = ploygon.reshape((-1, 1, 2))
+            cv2.fillPoly(img, [pts], color=255)
+            # 10% dilate / 10% erode / 5% dilatex2  5% erodex2
+            random_value = random.random()
+            kernel = np.ones((3, 3), dtype=np.uint8)
+            if random_value < 0.7:
+                pass
+            elif random_value < 0.8:
+                img = cv2.dilate(img.astype(np.uint8), kernel, iterations=1)
+            elif random_value < 0.9 and not small:
+                img = cv2.erode(img.astype(np.uint8), kernel, iterations=1)
+            elif random_value < 0.95:
+                img = cv2.dilate(img.astype(np.uint8), kernel, iterations=2)
+            elif random_value < 1.0 and not small:
+                img = cv2.erode(img.astype(np.uint8), kernel, iterations=2)
+        img = img[..., None]
+        return img/255.
+
+    def get_hint(self, positions):
+        if len(positions) == 0:
+            return np.zeros((512, 512, 1))
+        return np.sum(positions, axis=0).clip(0, 1)
+
+
+if __name__ == '__main__':
+    '''
+    Run this script to show details of your dataset, such as ocr annotations, glyphs, prompts, etc.
+    '''
+    from tqdm import tqdm
+    from matplotlib import pyplot as plt
+    import shutil
+
+    show_imgs_dir = 'show_results'
+    show_count = 50
+    if os.path.exists(show_imgs_dir):
+        shutil.rmtree(show_imgs_dir)
+    os.makedirs(show_imgs_dir)
+    plt.rcParams['axes.unicode_minus'] = False
+    json_paths = [
+        '/path/of/your/dataset/data1.json',
+        '/path/of/your/dataset/data2.json',
+        # ...
+    ]
+
+    dataset = T3DataSet(json_paths, for_show=True, max_lines=20, glyph_scale=2, mask_img_prob=1.0, caption_pos_prob=0.0)
+    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+    pbar = tqdm(total=show_count)
+    for i, data in enumerate(train_loader):
+        if i == show_count:
+            break
+        img = ((data['img'][0].numpy() + 1.0) / 2.0 * 255).astype(np.uint8)
+        masked_img = ((data['masked_img'][0].numpy() + 1.0) / 2.0 * 255)[..., ::-1].astype(np.uint8)
+        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_masked.jpg'), masked_img)
+        if 'texts' in data and len(data['texts']) > 0:
+            texts = [x[0] for x in data['texts']]
+            img = show_bbox_on_image(Image.fromarray(img), data['polygons'], texts)
+        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}.jpg'),  np.array(img)[..., ::-1])
+        with open(os.path.join(show_imgs_dir, f'plots_{i}.txt'), 'w') as fin:
+            fin.writelines([data['caption'][0]])
+        all_glyphs = []
+        for k, glyphs in enumerate(data['glyphs']):
+            cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_glyph_{k}.jpg'), glyphs[0].numpy().astype(np.int32)*255)
+            all_glyphs += [glyphs[0].numpy().astype(np.int32)*255]
+        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_allglyphs.jpg'), np.sum(all_glyphs, axis=0))
+        for k, gly_line in enumerate(data['gly_line']):
+            cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_gly_line_{k}.jpg'), gly_line[0].numpy().astype(np.int32)*255)
+        for k, position in enumerate(data['positions']):
+            if position is not None:
+                cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_pos_{k}.jpg'), position[0].numpy().astype(np.int32)*255)
+        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_hint.jpg'), data['hint'][0].numpy().astype(np.int32)*255)
+        cv2.imwrite(os.path.join(show_imgs_dir, f'plots_{i}_inv_mask.jpg'), np.array(img)[..., ::-1]*(1-data['inv_mask'][0].numpy().astype(np.int32)))
+        pbar.update(1)
+    pbar.close()
diff --git a/util.py b/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8b953f0d18a29f846763346150d7af086207d7
--- /dev/null
+++ b/util.py
@@ -0,0 +1,43 @@
+import datetime
+import os
+import cv2
+
+
+def save_images(img_list, folder):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    now = datetime.datetime.now()
+    date_str = now.strftime("%Y-%m-%d")
+    folder_path = os.path.join(folder, date_str)
+    if not os.path.exists(folder_path):
+        os.makedirs(folder_path)
+    time_str = now.strftime("%H_%M_%S")
+    for idx, img in enumerate(img_list):
+        image_number = idx + 1
+        filename = f"{time_str}_{image_number}.jpg"
+        save_path = os.path.join(folder_path, filename)
+        cv2.imwrite(save_path, img[..., ::-1])
+
+
+def check_channels(image):
+    channels = image.shape[2] if len(image.shape) == 3 else 1
+    if channels == 1:
+        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+    elif channels > 3:
+        image = image[:, :, :3]
+    return image
+
+
+def resize_image(img, max_length=768):
+    height, width = img.shape[:2]
+    max_dimension = max(height, width)
+
+    if max_dimension > max_length:
+        scale_factor = max_length / max_dimension
+        new_width = int(round(width * scale_factor))
+        new_height = int(round(height * scale_factor))
+        new_size = (new_width, new_height)
+        img = cv2.resize(img, new_size)
+    height, width = img.shape[:2]
+    img = cv2.resize(img, (width-(width % 64), height-(height % 64)))
+    return img