'''
@File       :   utils.py
@Time       :   2023/04/05 19:18:00
@Author     :   Jiazheng Xu
@Contact    :   xjz22@mails.tsinghua.edu.cn
* Based on CLIP code base
* https://github.com/openai/CLIP
* Checkpoints of CLIP/BLIP/Aesthetic are from:
* https://github.com/openai/CLIP
* https://github.com/salesforce/BLIP
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''

import os
import urllib
from typing import Union, List
import pathlib

import torch
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from .ImageReward import ImageReward
from .models.CLIPScore import CLIPScore
from .models.BLIPScore import BLIPScore
from .models.AestheticScore import AestheticScore

# Model names accepted by `load`, mapped to the URL of the matching checkpoint.
# `ImageReward_download` only uses the basename of the URL as the file name to
# fetch from the "THUDM/ImageReward" Hugging Face repository.
_MODELS = {
    "ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt",
}


def available_models() -> List[str]:
    """List the model names registered in `_MODELS`, in declaration order"""
    return [*_MODELS]


def ImageReward_download(url: str, root: str):
    """Download one file of the "THUDM/ImageReward" Hugging Face repo into `root`

    Parameters
    ----------
    url: str
        URL of the file; only its last path component is used, as the file
        name requested from the repository (the repo id itself is fixed)
    root: str
        Target directory, created when it does not exist yet

    Returns
    -------
    str
        `root` joined with the file name
    """
    os.makedirs(root, exist_ok=True)
    checkpoint_name = os.path.basename(url)
    hf_hub_download(
        repo_id="THUDM/ImageReward",
        filename=checkpoint_name,
        local_dir=root,
    )
    return os.path.join(root, checkpoint_name)


def load(name: str = "ImageReward-v1.0",
         device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
         download_root: str = None,
         med_config_path: str = None):
    """Load a ImageReward model

    Parameters
    ----------
    name: str
        A model name listed by `ImageReward.available_models()`, or the path to a model checkpoint containing the state_dict
    device: Union[str, torch.device]
        The device to put the loaded model
    download_root: str
        path to download the model files; by default, it uses "~/.cache/ImageReward".
        A leading "~" is expanded to the user's home directory
    med_config_path: str
        path to the `med_config.json` handed to `ImageReward`; when omitted, the file is
        looked up in the download root and downloaded there if it is missing

    Returns
    -------
    model : torch.nn.Module
        The ImageReward model

    Raises
    ------
    RuntimeError
        If `name` is neither a registered model name nor an existing file
    """
    # pathlib does not expand "~" on its own; without expanduser() the default
    # would create a literal "~" directory inside the current working directory.
    cache_root = pathlib.Path(download_root or "~/.cache/ImageReward").expanduser()

    if name in _MODELS:
        # ImageReward_download stores the file under the URL's basename, so the
        # cache lookup has to use that same name.
        model_path = cache_root / os.path.basename(_MODELS[name])

        if not model_path.exists():
            model_path = ImageReward_download(_MODELS[name], root=cache_root.as_posix())
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    print('-> load ImageReward model from %s' % model_path)
    # NOTE: torch.load may unpickle arbitrary objects; only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location='cpu')

    # med_config
    if med_config_path is None:
        med_config_path = cache_root / 'med_config.json'

        if not med_config_path.exists():
            med_config_path = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
                                                   root=cache_root.as_posix())
        print('-> load ImageReward med_config from %s' % med_config_path)

    model = ImageReward(device=device, med_config=med_config_path).to(device)
    # strict=False: missing or unexpected keys in the checkpoint do not raise.
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    return model


# Score names accepted by `load_score`, mapped to the checkpoint URL that
# `_download` fetches (the local file name is the URL's basename).
_SCORES = {
    "CLIP": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "BLIP": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
    "Aesthetic": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
}


def available_scores() -> List[str]:
    """List the score names registered in `_SCORES`, in declaration order"""
    return [*_SCORES]


def _download(url: str, root: str):
    """Download `url` into the directory `root` unless it is already cached

    Parameters
    ----------
    url: str
        URL of the file; its basename is used as the local file name
    root: str
        Target directory, created when it does not exist yet

    Returns
    -------
    str
        Path of the local file (`root` joined with the URL's basename)

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file
    """
    # `import urllib` alone does not guarantee that the `request` submodule is
    # loaded, so import it explicitly where it is used.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        return download_target

    # Write to a side file and move it into place only once the transfer is
    # complete; otherwise an interrupted download would leave a truncated file
    # that the cache check above accepts on the next call.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            # Content-Length is optional; tqdm accepts total=None for unknown sizes.
            try:
                total = int(source.info().get("Content-Length"))
            except (TypeError, ValueError):
                total = None

            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                      unit_divisor=1024) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))

        os.replace(partial_target, download_target)
    finally:
        # After a successful os.replace the side file is gone; this only
        # cleans up after a failed or interrupted transfer.
        if os.path.exists(partial_target):
            os.remove(partial_target)

    return download_target


def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
               download_root: str = None):
    """Load a score model (CLIP, BLIP or Aesthetic)

    Parameters
    ----------
    name : str
        A score name listed by `available_scores()`

    device : Union[str, torch.device]
        The device to put the loaded model

    download_root: str
        path to download the model files; by default, it uses "~/.cache/ImageReward".
        A leading "~" is expanded to the user's home directory

    Returns
    -------
    model : torch.nn.Module
        The score model, switched to eval mode

    Raises
    ------
    RuntimeError
        If `name` is not a registered score
    """
    if name not in _SCORES:
        raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")

    # Expand "~" for caller-supplied roots as well as for the default, so a
    # literal "~" directory is never created in the working directory.
    model_download_root = os.path.expanduser(download_root or "~/.cache/ImageReward")
    model_path = _download(_SCORES[name], model_download_root)

    print('load checkpoint from %s' % model_path)
    # NOTE: torch.load may unpickle arbitrary objects; only load checkpoints from trusted sources.
    if name == "BLIP":
        state_dict = torch.load(model_path, map_location='cpu')
        med_config = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
                                          model_download_root)
        model = BLIPScore(med_config=med_config, device=device).to(device)
        # The BLIP checkpoint keeps its weights under the 'model' key.
        model.blip.load_state_dict(state_dict['model'], strict=False)
    elif name == "CLIP":
        # CLIPScore receives only the download root, not model_path.
        model = CLIPScore(download_root=model_download_root, device=device).to(device)
    elif name == "Aesthetic":
        state_dict = torch.load(model_path, map_location='cpu')
        model = AestheticScore(download_root=model_download_root, device=device).to(device)
        model.mlp.load_state_dict(state_dict, strict=False)
    else:
        # Defensive guard: a name registered in _SCORES without a loader branch above.
        raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")

    print("checkpoint loaded")
    model.eval()

    return model