# -*- coding: utf-8 -*-
# @Organization  : Alibaba XR-Lab
# @Author        : Lingteng Qiu
# @Email         : 220019047@link.cuhk.edu.cn
# @Time          : 2025-03-03 10:28:35
# @Function      : Easy to use PSNR metric
import os
import sys

sys.path.append("./")

import math
import pdb

import cv2
import numpy as np
import skimage
import torch
from PIL import Image
from tqdm import tqdm


def write_json(path, x):
    """Dump ``x`` to the file at ``path`` as JSON indented by two spaces.

    The file is opened in text mode ("w"), so an existing file is overwritten.

    Args:
        path (str): destination file.
        x (dict): JSON-serializable object to store.
    """
    from json import dump as _dump_json

    with open(path, "w") as out_file:
        _dump_json(x, out_file, indent=2)


def img_center_padding(img_np, pad_ratio=0.2, background=1):
    """Place an image in the center of a larger canvas.

    Both spatial sides (axis 0 = rows, axis 1 = columns) are enlarged by
    ``pad_ratio`` (the new size is ``round((1 + pad_ratio) * side)``), and the
    input is copied into the middle of the canvas.

    Args:
        img_np (np.ndarray): image of shape (H, W) or (H, W, C, ...). Any
            trailing dimensions are kept as-is on the canvas.
        pad_ratio (float): fraction by which each spatial side grows.
        background (int): 1 fills the border with ones; any other value fills
            it with zeros.

    Returns:
        tuple: ``(padded_image, offset_w, offset_h)``. ``offset_w`` is the
        column offset and ``offset_h`` is the row offset of the original
        image's top-left corner inside the padded canvas.
    """
    ori_h, ori_w = img_np.shape[:2]

    new_h = round((1 + pad_ratio) * ori_h)
    new_w = round((1 + pad_ratio) * ori_w)

    # Keep the input's trailing (channel) dimensions instead of forcing 3
    # channels, so grayscale and 4-channel (e.g. BGRA) images also work.
    canvas_shape = (new_h, new_w) + tuple(img_np.shape[2:])
    fill_value = 1 if background == 1 else 0
    img_pad_np = np.full(canvas_shape, fill_value, dtype=img_np.dtype)

    offset_h = (new_h - ori_h) // 2
    offset_w = (new_w - ori_w) // 2
    img_pad_np[offset_h : offset_h + ori_h, offset_w : offset_w + ori_w] = img_np

    return img_pad_np, offset_w, offset_h


def compute_psnr(src, tar, data_range=1):
    """Return the peak signal-to-noise ratio (dB) of ``src`` against ``tar``.

    Args:
        src (np.ndarray): image under test.
        tar (np.ndarray): reference (ground-truth) image with the same shape
            as ``src``.
        data_range (float): span of valid pixel values. The default of 1
            matches images scaled to [0, 1].

    Returns:
        float: ``10 * log10(data_range**2 / MSE)``. Identical images give an
        infinite value.
    """
    # Import the submodule explicitly: a bare ``import skimage`` does not
    # load ``skimage.metrics`` on scikit-image releases without lazy
    # submodule loading, which would raise AttributeError here.
    from skimage.metrics import peak_signal_noise_ratio

    return peak_signal_noise_ratio(tar, src, data_range=data_range)


def get_parse(argv=None):
    """Build the command-line parser for this script and parse arguments.

    Args:
        argv (list[str] | None): arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: with attributes ``folder1``, ``folder2``,
        ``mask``, ``pre``, ``debug`` and ``pad``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute PSNR between ground-truth and predicted image folders."
    )
    parser.add_argument(
        "-f1",
        "--folder1",
        required=True,
        help="ground-truth root folder; must contain front_view.txt listing the items",
    )
    parser.add_argument(
        "-f2",
        "--folder2",
        required=True,
        help="prediction root folder; images are read from <folder2>/<item>/rgb",
    )
    parser.add_argument(
        "-m",
        "--mask",
        default=None,
        help="optional mask root folder; masks are read from <mask>/<item>",
    )
    parser.add_argument(
        "--pre",
        default="",
        help="suffix appended to the ./exps/metrics output directory name",
    )
    parser.add_argument(
        "--debug", action="store_true", help="stop after the first evaluated item"
    )
    parser.add_argument(
        "--pad",
        action="store_true",
        help="center-pad ground-truth images by 20%% before comparing",
    )
    return parser.parse_args(argv)


def get_image_paths_current_dir(folder_path):
    """List image files located directly inside ``folder_path``.

    Only the top level of the folder is scanned (no recursion). An entry is
    kept when its extension, compared case-insensitively, is one of the
    raster formats below.

    Returns:
        list[str]: the kept entries joined onto ``folder_path``, sorted.
    """
    allowed_suffixes = frozenset(
        (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp", ".jfif")
    )

    picked = []
    for name in os.listdir(folder_path):
        _, suffix = os.path.splitext(name)
        if suffix.lower() not in allowed_suffixes:
            continue
        picked.append(os.path.join(folder_path, name))

    picked.sort()
    return picked


def _read_image_float(path):
    """Read ``path`` with cv2 (channels unchanged) and scale it by 1/255 to float32.

    NOTE(review): the 1/255 scaling assumes 8-bit images; 16-bit PNGs would
    not end up in [0, 1].

    Raises:
        OSError: if cv2 cannot read the file (``cv2.imread`` returns None).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise OSError(f"failed to read image: {path}")
    return (img / 255.0).astype(np.float32)


def _load_mask(path):
    """Load a mask as a 3-channel float32 array, center-padded with a zero border.

    For multi-channel mask images the last channel is used (the alpha channel
    of a 4-channel PNG). A single-channel mask is used as-is.
    """
    mask = _read_image_float(path)
    if mask.ndim == 3:
        mask = mask[..., -1]
    mask = np.stack([mask] * 3, axis=-1)
    # The mask is always padded by the default ratio, independently of the
    # ``pad`` flag of psnr_compute. Presumably the masks are stored unpadded
    # while the predictions are padded — TODO confirm.
    mask, _, _ = img_center_padding(mask, background=0)
    return mask


def psnr_compute(
    input_data,
    results_data,
    mask_data=None,
    pad=False,
    resize_width=512,
):
    """Compute the mean PSNR between the images of two folders.

    Images are paired by their sorted order in each folder. Both images of a
    pair are resized to ``resize_width`` wide, keeping the result image's
    aspect ratio. When a mask folder is given, pixels outside the mask are
    composited onto white, and only pixels with mask > 0.5 enter the PSNR.

    Args:
        input_data (str): folder with ground-truth images.
        results_data (str): folder with predicted images. A trailing file
            whose name contains "visualization" is ignored.
        mask_data (str | None): optional folder with mask images, paired by
            sorted order.
        pad (bool): center-pad ground-truth images (white border) before
            comparing.
        resize_width (int): common width both images are resized to before
            comparison.

    Returns:
        float: mean PSNR over all pairs, or -1 when the folders cannot be
        paired (no ground-truth images, differing counts, or fewer masks
        than images).
    """
    gt_imgs = get_image_paths_current_dir(input_data)
    result_imgs = get_image_paths_current_dir(results_data)
    mask_imgs = (
        get_image_paths_current_dir(mask_data) if mask_data is not None else None
    )

    # A trailing "visualization" image in the results folder is not a rendered
    # view. Check only the file name, so a parent directory called
    # "...visualization..." does not drop a real frame.
    if result_imgs and "visualization" in os.path.basename(result_imgs[-1]):
        result_imgs = result_imgs[:-1]

    if not gt_imgs or len(gt_imgs) != len(result_imgs):
        return -1
    if mask_imgs is not None and len(mask_imgs) < len(gt_imgs):
        return -1

    psnr_values = []
    pairs = enumerate(zip(gt_imgs, result_imgs))
    for idx, (gt_path, result_path) in tqdm(pairs, total=len(gt_imgs)):
        result_img = _read_image_float(result_path)
        gt_img = _read_image_float(gt_path)

        if pad:
            gt_img, _, _ = img_center_padding(gt_img)

        h, w = result_img.shape[:2]
        # cv2.resize takes dsize as (width, height).
        target_size = (resize_width, int(h * resize_width / w))

        gt_img = cv2.resize(gt_img, target_size, interpolation=cv2.INTER_AREA)
        result_img = cv2.resize(result_img, target_size, interpolation=cv2.INTER_AREA)

        if mask_imgs is None:
            psnr_values.append(compute_psnr(result_img, gt_img))
            continue

        # Resize the mask to the same target size as the images, so the
        # blending below broadcasts element-wise.
        mask_img = cv2.resize(
            _load_mask(mask_imgs[idx]), target_size, interpolation=cv2.INTER_AREA
        )
        gt_img = gt_img * mask_img + 1 - mask_img
        result_img = result_img * mask_img + 1 - mask_img
        inside = mask_img[..., 0] > 0.5
        psnr_values.append(compute_psnr(result_img[inside], gt_img[inside]))

    return np.mean(psnr_values)


def main():
    """Evaluate PSNR for every item in ``<folder1>/front_view.txt`` and save PSNR.json.

    For each listed item, ``<folder1>/<item>`` (ground truth) is compared with
    ``<folder2>/<item>/rgb`` (predictions), optionally masked by
    ``<mask>/<item>``. Items whose folders are missing, or which
    psnr_compute cannot pair (returns -1), are skipped. The per-item values
    plus their mean ("all_mean") are written to
    ``./exps/metrics<pre>/psnr_results/<last two parts of folder2>/PSNR.json``.
    """
    opt = get_parse()

    input_folder = opt.folder1
    # Strip every trailing slash, so the last two path components used as the
    # output key are real folder names.
    target_folder = opt.folder2.rstrip("/")
    mask_folder = opt.mask.rstrip("/") if opt.mask is not None else None

    valid_txt = os.path.join(input_folder, "front_view.txt")

    target_key = target_folder.split("/")[-2:]

    save_folder = os.path.join(f"./exps/metrics{opt.pre}", "psnr_results", *target_key)
    os.makedirs(save_folder, exist_ok=True)

    # Each non-blank line starts with an item name; anything after the first
    # whitespace is ignored. Blank lines are skipped, because they would
    # otherwise yield an empty item name that joins to the root folder itself.
    with open(valid_txt) as f:
        items = [line.split()[0] for line in f.read().splitlines() if line.strip()]

    results_dict = dict()
    psnr_list = []

    for item in items:
        input_item_folder = os.path.join(input_folder, item)
        mask_item_folder = (
            os.path.join(mask_folder, item) if mask_folder is not None else None
        )
        target_item_folder = os.path.join(target_folder, item, "rgb")

        if not (
            os.path.exists(input_item_folder) and os.path.exists(target_item_folder)
        ):
            continue

        psnr = psnr_compute(
            input_item_folder, target_item_folder, mask_item_folder, opt.pad
        )

        # -1 is psnr_compute's "could not pair the folders" sentinel.
        if psnr == -1:
            continue

        psnr_list.append(psnr)

        results_dict[item] = psnr
        if opt.debug:
            break
        print(results_dict)

    # Avoid numpy's empty-slice RuntimeWarning; NaN still signals that no
    # item was evaluated.
    results_dict["all_mean"] = np.mean(psnr_list) if psnr_list else float("nan")

    print(save_folder)

    print(results_dict)
    write_json(os.path.join(save_folder, "PSNR.json"), results_dict)


if __name__ == "__main__":
    main()