"""Utility functions."""

import base64
import os
import subprocess
import cv2
import numpy as np

import torch

from models import MODEL_ZOO
from models import build_generator
from models import parse_gan_type

# Names exported by `from <this module> import *`.
__all__ = ['postprocess', 'load_generator', 'factorize_weight',
           'HtmlPageVisualizer']

# Directory (relative path) in which `load_generator()` looks for, and stores
# downloaded, `<model_name>.pth` checkpoint files.
CHECKPOINT_DIR = 'checkpoints'


def to_tensor(array):
    """Wraps a `numpy.ndarray` as a 32-bit floating point `torch.Tensor`.

    Args:
        array: The `numpy.ndarray` to convert.

    Returns:
        A `torch.Tensor` of type `torch.FloatTensor` holding the values of
            `array`. The tensor is created from the array on the CPU and is
            not moved to any other device by this function.
    """
    assert isinstance(array, np.ndarray)
    tensor = torch.from_numpy(array)
    return tensor.float()


def postprocess(images, min_val=-1.0, max_val=1.0):
    """Converts a batch of image tensors into displayable `uint8` arrays.

    Values are linearly mapped from `[min_val, max_val]` to `[0, 255]`,
    rounded, clipped, and the channel axis is moved to the last position.

    Args:
        images: A `torch.Tensor` with shape `NCHW` to process.
        min_val: The minimum value of the input tensor. (default: -1.0)
        max_val: The maximum value of the input tensor. (default: 1.0)

    Returns:
        A `numpy.ndarray` with shape `NHWC`, dtype `uint8` and pixel range
            [0, 255].
    """
    assert isinstance(images, torch.Tensor)
    pixels = images.detach().cpu().numpy()
    value_range = max_val - min_val
    pixels = (pixels - min_val) * 255 / value_range
    # Adding 0.5 before the truncating integer cast rounds to nearest.
    pixels = np.clip(pixels + 0.5, 0, 255).astype(np.uint8)
    return pixels.transpose(0, 2, 3, 1)


def load_generator(model_name, from_hf_hub=False):
    """Loads pre-trained generator.

    Args:
        model_name: Name of the model. Should be a key in `models.MODEL_ZOO`.
        from_hf_hub: If `True` and the model config contains an
            `hf_hub_repo` entry, the weights are obtained through
            `generator.from_pretrained()` instead of the local checkpoint
            cache in `CHECKPOINT_DIR`. (default: False)

    Returns:
        A generator, which is a `torch.nn.Module`, with pre-trained weights
            loaded and switched to evaluation mode.

    Raises:
        KeyError: If the input `model_name` is not in `models.MODEL_ZOO`.
        RuntimeError: If the checkpoint has to be downloaded and the download
            command exits with a non-zero status.
    """
    if model_name not in MODEL_ZOO:
        raise KeyError(f'Unknown model name `{model_name}`!')

    # Work on a copy so that the shared zoo entry is not mutated by `pop()`.
    model_config = MODEL_ZOO[model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Build generator.
    # NOTE(review): only `url` is removed from the config, so an `hf_hub_repo`
    # entry (if present) is forwarded to `build_generator()` as well; confirm
    # that `build_generator()` accepts it.
    print(f'Building generator for model `{model_name}` ...')
    generator = build_generator(**model_config)
    print('Finish building generator.')

    if from_hf_hub and 'hf_hub_repo' in model_config:
        checkpoint = generator.from_pretrained(model_config['hf_hub_repo'])
        generator.load_state_dict(checkpoint)
        print('loaded from hf_hub')
    else:
        # Load pre-trained weights.
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth')
        print(f'Loading checkpoint from `{checkpoint_path}` ...')
        if not os.path.exists(checkpoint_path):
            print(f'  Downloading checkpoint from `{url}` ...')
            # Download to a temporary name first: `wget -O` creates the output
            # file even when the transfer fails, and a truncated file at the
            # final path would be mistaken for a valid cache on later runs.
            partial_path = checkpoint_path + '.part'
            return_code = subprocess.call(
                ['wget', '--quiet', '-O', partial_path, url])
            if return_code != 0:
                if os.path.exists(partial_path):
                    os.remove(partial_path)
                raise RuntimeError(
                    f'Failed to download checkpoint from `{url}` '
                    f'(wget exit code {return_code}).')
            os.replace(partial_path, checkpoint_path)
            print('  Finish downloading checkpoint.')
        # NOTE(review): `torch.load` unpickles the file, which can execute
        # arbitrary code; only use checkpoint URLs from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Use the `generator_smooth` weights when the checkpoint has them.
        if 'generator_smooth' in checkpoint:
            generator.load_state_dict(checkpoint['generator_smooth'])
        else:
            generator.load_state_dict(checkpoint['generator'])
    # The generator is not moved to another device here; the checkpoint is
    # loaded with `map_location='cpu'`.
    generator.eval()
    print('Finish loading checkpoint.')
    return generator


def parse_indices(obj, min_val=None, max_val=None):
    """Parses indices.

    The input can be `None`, a single integer, a list / tuple /
    `numpy.ndarray` of integers, or a string, which is a comma separated list
    of numbers 'a, b, c' and / or dash separated (inclusive) ranges 'a - c'.
    Space in the string will be ignored.

    Args:
        obj: The input object to parse indices from.
        min_val: If not `None`, this function will check that all indices are
            equal to or larger than this value. (default: None)
        max_val: If not `None`, this function will check that all indices are
            equal to or smaller than this value. (default: None)

    Returns:
        A sorted list of unique integers. `None` and the empty string give an
            empty list.

    Raises:
        ValueError: If the input type is not supported, or if a string item
            cannot be parsed as a number or a range.
        AssertionError: If an index is not an integer or violates `min_val` /
            `max_val`.
    """
    # The empty-string test is restricted to strings: comparing a
    # `numpy.ndarray` with `''` is elementwise and its truth value is
    # ambiguous.
    if obj is None or (isinstance(obj, str) and obj == ''):
        indices = []
    elif isinstance(obj, (int, np.integer)):
        indices = [int(obj)]
    elif isinstance(obj, (list, tuple, np.ndarray)):
        # NumPy integer scalars are not instances of `int`; normalize them so
        # that array input passes the integer check below.
        indices = [int(idx) if isinstance(idx, np.integer) else idx
                   for idx in obj]
    elif isinstance(obj, str):
        indices = []
        splits = obj.replace(' ', '').split(',')
        for split in splits:
            numbers = list(map(int, split.split('-')))
            if len(numbers) == 1:
                indices.append(numbers[0])
            elif len(numbers) == 2:
                # Ranges are inclusive on both ends.
                indices.extend(range(numbers[0], numbers[1] + 1))
            else:
                raise ValueError('Unable to parse the input!')
    else:
        raise ValueError(f'Invalid type of input: `{type(obj)}`!')

    assert isinstance(indices, list)
    indices = sorted(set(indices))
    for idx in indices:
        assert isinstance(idx, int)
        if min_val is not None:
            assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
        if max_val is not None:
            assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'

    return indices


def factorize_weight(generator, layer_idx='all'):
    """Factorizes the generator weight to get semantics boundaries.

    The weights of the selected layers are concatenated along axis 1, each
    column is scaled to unit L2 norm, and the matrix `W.dot(W.T)` is
    eigen-decomposed.

    Args:
        generator: Generator to factorize.
        layer_idx: Indices of layers to interpret, especially for StyleGAN and
            StyleGAN2. Either `all` or anything `parse_indices()` accepts.
            Ignored for PGGAN, where layer 0 is always used. (default: `all`)

    Returns:
        A tuple of (layers_to_interpret, semantic_boundaries, eigen_values).
            Each row of `semantic_boundaries` is one eigenvector, in the order
            returned by `numpy.linalg.eig()` (which does not sort them).

    Raises:
        ValueError: If the generator type is not supported.
    """
    # Get GAN type.
    gan_type = parse_gan_type(generator)

    # Get layers.
    if gan_type == 'pggan':
        layers = [0]
    elif gan_type in ['stylegan', 'stylegan2']:
        # Only compare against 'all' for strings; an array-like `layer_idx`
        # would otherwise be compared elementwise.
        if isinstance(layer_idx, str) and layer_idx == 'all':
            layers = list(range(generator.num_layers))
        else:
            layers = parse_indices(layer_idx,
                                   min_val=0,
                                   max_val=generator.num_layers - 1)
    else:
        # Without this branch `layers` would be unbound below.
        raise ValueError(f'Unsupported GAN type `{gan_type}`!')

    # Factorize semantics from weight.
    weights = []
    for idx in layers:
        layer_name = f'layer{idx}'
        # For StyleGAN2 the last index is read from an `output*` module
        # instead of a `layer*` module.
        if gan_type == 'stylegan2' and idx == generator.num_layers - 1:
            layer_name = f'output{idx // 2}'
        if gan_type == 'pggan':
            weight = getattr(generator, layer_name).weight
            weight = weight.flip(2, 3).permute(1, 0, 2, 3).flatten(1)
        else:
            weight = getattr(generator.synthesis, layer_name).style.weight.T
        weights.append(weight.cpu().detach().numpy())
    weight = np.concatenate(weights, axis=1).astype(np.float32)
    # Normalize every column to unit length before the decomposition.
    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))

    # `eig` returns eigenvectors as columns; transpose to one per row.
    return layers, eigen_vectors.T, eigen_values


def get_sortable_html_header(column_name_list, sort_by_ascending=False):
    """Builds the opening part of a sortable html page.

    The page holds a single table whose rows can be re-ordered by clicking a
    column head. The returned string contains the sorting script, the style
    sheet, the table head with one clickable `<th>` per column, and the
    opening `<tbody>` tag. Table rows and the footer are expected to follow.

    Example:

    column_name_list = [name_1, name_2, name_3]
    header = get_sortable_html_header(column_name_list)
    footer = get_sortable_html_footer()
    sortable_table = ...
    html_page = header + sortable_table + footer

    Args:
        column_name_list: List of column header names. Names are inserted
            into the page as is.
        sort_by_ascending: Default sorting order. If set as `True`, the html
            page will be sorted by ascending order when the header is clicked
            for the first time.

    Returns:
        A string, which represents for the header for a sortable html page.
    """
    initial_order = str(sort_by_ascending).lower()

    # Each section ends with an empty entry, which yields the separating
    # (or trailing) newline once everything is joined.
    script_lines = [
        '<script type="text/javascript">',
        'var column_idx;',
        f'var sort_by_ascending = {initial_order};',
        '',
        'function sorting(tbody, column_idx){',
        '    this.column_idx = column_idx;',
        '    Array.from(tbody.rows)',
        '             .sort(compareCells)',
        '             .forEach(function(row) { tbody.appendChild(row); })',
        '    sort_by_ascending = !sort_by_ascending;',
        '}',
        '',
        'function compareCells(row_a, row_b) {',
        '    var val_a = row_a.cells[column_idx].innerText;',
        '    var val_b = row_b.cells[column_idx].innerText;',
        '    var flag = sort_by_ascending ? 1 : -1;',
        '    return flag * (val_a > val_b ? 1 : -1);',
        '}',
        '</script>',
        '',
    ]
    style_lines = [
        '<html>',
        '',
        '<head>',
        '<style>',
        '    table {',
        '        border-spacing: 0;',
        '        border: 1px solid black;',
        '    }',
        '    th {',
        '        cursor: pointer;',
        '    }',
        '    th, td {',
        '        text-align: left;',
        '        vertical-align: middle;',
        '        border-collapse: collapse;',
        '        border: 0.5px solid black;',
        '        padding: 8px;',
        '    }',
        '    tr:nth-child(even) {',
        '        background-color: #d2d2d2;',
        '    }',
        '</style>',
        '</head>',
        '',
    ]
    table_lines = [
        '<body>',
        '',
        '<table>',
        '<thead>',
        '<tr>',
        '',
    ]
    preamble = '\n'.join(script_lines + style_lines + table_lines)

    # One clickable head cell per column; the index tells the script which
    # column to sort by.
    head_cells = ''.join(
        f'    <th onclick="sorting(tbody, {idx})">{name}</th>\n'
        for idx, name in enumerate(column_name_list))

    closing = '</tr>\n' + '</thead>\n' + '<tbody id="tbody">\n'
    return preamble + head_cells + closing


def get_sortable_html_footer():
    """Builds the closing part of a sortable html page.

    The returned string closes the `<tbody>`, `<table>`, `<body>` and `<html>`
    tags opened by `get_sortable_html_header()`.

    Returns:
        A string, which represents the footer of a sortable html page.
    """
    closing_lines = ('</tbody>', '</table>', '', '</body>', '</html>', '')
    return '\n'.join(closing_lines)


def parse_image_size(obj):
    """Parses object to a pair of image size, i.e., (width, height).

    Accepted inputs are `None` or the empty string (giving `(0, 0)`), a
    single integer (used for both width and height), a list / tuple /
    `numpy.ndarray` with at most two elements, or a comma separated string
    such as '256' or '256, 512'. Space in the string will be ignored.

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple, indicating image width and height respectively.
            Negative values are replaced by 0.

    Raises:
        ValueError: If the input type is not supported, if more than two
            numbers are given, or if a string item is not an integer.
    """
    # Step 1: normalize every supported input into a tuple of numbers. The
    # empty-string test is restricted to strings, since comparing a
    # `numpy.ndarray` with `''` is elementwise and has no single truth value.
    if obj is None or (isinstance(obj, str) and obj == ''):
        numbers = ()
    elif isinstance(obj, int):
        numbers = (obj,)
    elif isinstance(obj, np.integer):
        numbers = (int(obj),)
    elif isinstance(obj, (list, tuple, np.ndarray)):
        # Turn NumPy integer scalars into plain `int` for a uniform result.
        numbers = tuple(int(num) if isinstance(num, np.integer) else num
                        for num in obj)
    elif isinstance(obj, str):
        numbers = tuple(map(int, obj.replace(' ', '').split(',')))
    else:
        raise ValueError(f'Invalid type of input: {type(obj)}!')

    # Step 2: expand the numbers into (width, height).
    if len(numbers) > 2:
        raise ValueError('At most two elements for image size.')
    if len(numbers) == 0:
        width = height = 0
    elif len(numbers) == 1:
        width = height = numbers[0]
    else:
        width, height = numbers

    return (max(0, width), max(0, height))


def encode_image_to_html_str(image, image_size=None):
    """Encodes an image to html language.

    The image is JPEG-compressed, base64-encoded and embedded into an `<img>`
    tag as a data URI.

    NOTE: Input image is always assumed to be with `RGB` channel order.

    Args:
        image: The input image to encode, as an array with shape `HWC` and 1
            or 3 channels. Should be with `RGB` channel order. `None` gives an
            empty string.
        image_size: This field is used to resize the image before encoding.
            Anything accepted by `parse_image_size()`; a dimension of `0`
            keeps the original size along that axis, and `0` for both (or
            `None`) disables resizing. (default: None)

    Returns:
        A string which represents the encoded image.
    """
    if image is None:
        return ''

    assert image.ndim == 3 and image.shape[2] in [1, 3]

    # Change channel order to `BGR`, which is opencv-friendly. Reversing the
    # last axis yields a negative-stride view, so copy it into a contiguous
    # buffer before handing it to OpenCV.
    image = np.ascontiguousarray(image[:, :, ::-1])

    # Resize the image if needed.
    width, height = parse_image_size(image_size)
    if height or width:
        height = height or image.shape[0]
        width = width or image.shape[1]
        image = cv2.resize(image, (width, height))

    # Encode the image to html-format string. `tobytes()` replaces the
    # deprecated `tostring()` alias and returns the same bytes.
    encoded_image = cv2.imencode('.jpg', image)[1].tobytes()
    encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8')
    html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>'

    return html_str


def get_grid_shape(size, row=0, col=0, is_portrait=False):
    """Computes a (height, width) grid layout holding exactly `size` cells.

    When neither `row` nor `col` fixes the layout, the grid is made as close
    to square as possible: the shorter side is the largest divisor of `size`
    not exceeding its square root. With `is_portrait=False` the height is the
    shorter side, e.g. `size = 16` gives `(4, 4)` and `size = 15` gives
    `(3, 5)`; with `is_portrait=True` the height is the longer side.

    Args:
        size: Size (height * width) of the target grid.
        row: Requested number of rows. Used when positive and a divisor of
            `size`. (default: 0)
        col: Requested number of columns. Used when positive and a divisor of
            `size`, and `row` was not used. (default: 0)
        is_portrait: Whether to return a portrait size of a landscape size.
            Only affects the automatically derived layout. (default: False)

    Returns:
        A two-element tuple, representing height and width respectively.
    """
    assert isinstance(size, int)
    assert isinstance(row, int)
    assert isinstance(col, int)
    if size == 0:
        return (0, 0)

    # A requested `row` / `col` pair that does not multiply to `size` is
    # discarded entirely.
    if row > 0 and col > 0 and row * col != size:
        row = col = 0

    # Honor an explicit row count, then an explicit column count.
    if row > 0 and size % row == 0:
        return (row, size // row)
    if col > 0 and size % col == 0:
        return (size // col, col)

    # Search downwards from sqrt(size) for the first divisor; 1 always works.
    upper_bound = int(np.sqrt(size))
    short_side = next(candidate
                      for candidate in range(upper_bound, 0, -1)
                      if size % candidate == 0)
    long_side = size // short_side

    if is_portrait:
        return (long_side, short_side)
    return (short_side, long_side)


class HtmlPageVisualizer:
    """Collects texts and images in a table and writes it as an html page.

    The page is a sortable html table assembled with the helper functions
    `get_sortable_html_header()`, `get_sortable_html_footer()`, and
    `encode_image_to_html_str()`. To create a visualization page, it is
    enough to specify:

    (1) num_rows: Number of rows of the table (header-row exclusive).
    (2) num_cols: Number of columns of the table.
    (3) header contents (optional): Title of each column.

    NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
    automatically.

    Example:

    html = HtmlPageVisualizer(num_rows, num_cols)
    html.set_headers([...])
    for i in range(num_rows):
        for j in range(num_cols):
            html.set_cell(i, j, text=..., image=..., highlight=False)
    html.save('visualize.html')
    """

    def __init__(self,
                 num_rows=0,
                 num_cols=0,
                 grid_size=0,
                 is_portrait=True,
                 viz_size=None):
        """Creates an empty table.

        Args:
            num_rows: Number of rows of the table. (default: 0)
            num_cols: Number of columns of the table. (default: 0)
            grid_size: If positive, the table shape is derived from it with
                `get_grid_shape()`, using `num_rows` and `num_cols` as
                requested row / column counts. (default: 0)
            is_portrait: Passed to `get_grid_shape()` when `grid_size` is
                used. (default: True)
            viz_size: Size that images are resized to when they are added,
                parsed with `parse_image_size()`. (default: None)
        """
        if grid_size > 0:
            num_rows, num_cols = get_grid_shape(
                grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait)
        assert num_rows > 0 and num_cols > 0

        self.num_rows = num_rows
        self.num_cols = num_cols
        self.viz_size = parse_image_size(viz_size)
        self.headers = [''] * self.num_cols
        # Every cell gets its own dict so cells can be edited independently.
        self.cells = []
        for _ in range(self.num_rows):
            self.cells.append([
                {'text': '', 'image': '', 'highlight': False}
                for _ in range(self.num_cols)
            ])

    def set_header(self, col_idx, content):
        """Sets the title of the column with index `col_idx`."""
        self.headers[col_idx] = content

    def set_headers(self, contents):
        """Sets the titles of all columns at once.

        Args:
            contents: A list or tuple with one title per column. A single
                string is treated as a one-element list.
        """
        if isinstance(contents, str):
            contents = [contents]
        assert isinstance(contents, (list, tuple))
        assert len(contents) == self.num_cols
        for col_idx, title in enumerate(contents):
            self.set_header(col_idx, title)

    def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False):
        """Fills in a single cell of the table.

        A cell holds some text and an image, either of which can be empty.
        The image is encoded to an html string right away.

        Args:
            row_idx: Row index of the cell to edit.
            col_idx: Column index of the cell to edit.
            text: Text to add into the target cell. (default: '')
            image: Image to show in the target cell. Should be with `RGB`
                channel order. (default: None)
            highlight: Whether to highlight this cell. (default: False)
        """
        cell = self.cells[row_idx][col_idx]
        cell['text'] = text
        cell['image'] = encode_image_to_html_str(image, self.viz_size)
        cell['highlight'] = bool(highlight)

    def save(self, save_path):
        """Writes the table as an html page to `save_path`."""
        # Collect the row fragments and join them once at the end.
        fragments = []
        for row_idx in range(self.num_rows):
            fragments.append('<tr>\n')
            for col_idx in range(self.num_cols):
                cell = self.cells[row_idx][col_idx]
                bgcolor = ' bgcolor="#FF8888"' if cell['highlight'] else ''
                # The text, if any, is placed above the image.
                caption = f"{cell['text']}<br><br>" if cell['text'] else ''
                fragments.append(
                    f"    <td{bgcolor}>{caption}{cell['image']}</td>\n")
            fragments.append('</tr>\n')
        body = ''.join(fragments)

        page = (get_sortable_html_header(self.headers) + body +
                get_sortable_html_footer())

        with open(save_path, 'w') as f:
            f.write(page)