# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from argparse import ArgumentParser
from functools import partial

import cv2
import numpy as np
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.parallel import collate
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from torch import nn

from mmocr.apis import init_detector
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
from mmocr.datasets.pipelines.crop import crop_img  # noqa: F401
from mmocr.utils import is_2dlist


def _convert_batchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _prepare_data(cfg, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.
    Returns:
        result (dict): Predicted results.
    """
    if isinstance(imgs, (list, tuple)):
        if not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')

    elif isinstance(imgs, (np.ndarray, str)):
        imgs = [imgs]
    else:
        raise AssertionError('imgs must be strings or numpy arrays')

    is_ndarray = isinstance(imgs[0], np.ndarray)

    if is_ndarray:
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    data = []
    for img in imgs:
        # prepare data
        if is_ndarray:
            # directly add img
            datum = dict(img=img)
        else:
            # add information into dict
            datum = dict(img_info=dict(filename=img), img_prefix=None)

        # build the data pipeline
        datum = test_pipeline(datum)
        # get tensor from list to stack for batch mode (text detection)
        data.append(datum)

    if isinstance(data[0]['img'], list) and len(data) > 1:
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(data)}')

    data = collate(data, samples_per_gpu=len(imgs))

    # process img_metas
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data

    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data
    return data


def _arrays_close(lhs, rhs, rtol=1e-4, atol=1e-4):
    """Return True if ``lhs`` and ``rhs`` share a shape and are element-wise
    close within ``rtol``/``atol``.

    A shape mismatch is reported as "not close" rather than letting
    ``np.allclose`` raise on non-broadcastable operands.
    """
    lhs, rhs = np.array(lhs), np.array(rhs)
    return lhs.shape == rhs.shape and np.allclose(
        lhs, rhs, rtol=rtol, atol=atol)


def _outputs_match(onnx_out, pytorch_out, model_type):
    """Return whether ONNX and PyTorch ``simple_test`` results agree.

    Recognition results are compared item by item on ``text`` and
    ``score``. Detection results are compared on the ``boundary_result`` of
    the first image. A different number of results counts as a mismatch.
    """
    if model_type == 'recog':
        if len(onnx_out) != len(pytorch_out):
            return False
        return all(
            onnx_res['text'] == torch_res['text']
            and _arrays_close(onnx_res['score'], torch_res['score'])
            for onnx_res, torch_res in zip(onnx_out, pytorch_out))

    onnx_boundaries = onnx_out[0]['boundary_result']
    torch_boundaries = pytorch_out[0]['boundary_result']
    if len(onnx_boundaries) != len(torch_boundaries):
        return False
    return all(
        _arrays_close(onnx_box, torch_box)
        for onnx_box, torch_box in zip(onnx_boundaries, torch_boundaries))


def pytorch2onnx(model: nn.Module,
                 model_type: str,
                 img_path: str,
                 verbose: bool = False,
                 show: bool = False,
                 opset_version: int = 11,
                 output_file: str = 'tmp.onnx',
                 verify: bool = False,
                 dynamic_export: bool = False,
                 device_id: int = 0):
    """Export PyTorch model to ONNX model and verify the outputs are same
    between PyTorch and ONNX.

    Args:
        model (nn.Module): PyTorch model we want to export. It must expose a
            ``cfg`` attribute (used to build the test pipeline) and a
            ``simple_test`` method.
        model_type (str): Model type: ``'det'`` for detection or
            ``'recog'`` for recognition.
        img_path (str): Sample image used to trace the model, and to
            verify and visualize the results.
        verbose (bool): Whether to print the computation graph.
            Default: False.
        show (bool): Whether to visualize the PyTorch and ONNX results.
            Only takes effect when ``verify`` is True. Default: False.
        opset_version (int): The ONNX opset version. Default: 11.
        output_file (str): Path where the exported ONNX model is written.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs of PyTorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to export with dynamic axes.
            Default: False.
        device_id (int): CUDA device id on which the model and data are
            placed. Default: 0.
    """
    device = torch.device(type='cuda', index=device_id)
    model.to(device).eval()
    # SyncBatchNorm children are replaced in place. The return value would
    # only differ from ``model`` if the root itself were a SyncBatchNorm.
    _convert_batchnorm(model)

    # prepare inputs
    mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
    imgs = mm_inputs.pop('img')
    img_metas = mm_inputs.pop('img_metas')

    if isinstance(imgs, list):
        # The pipeline produced several outputs (e.g. augmentations);
        # only the first one is used for export.
        imgs = imgs[0]

    # Split the batch into single-image tensors with a leading batch axis.
    img_list = [img[None, :].to(device) for img in imgs]

    # torch.onnx.export only passes the image tensor, so the remaining
    # forward arguments are bound temporarily (undone after export).
    origin_forward = model.forward
    if model_type == 'det':
        model.forward = partial(
            model.simple_test, img_metas=img_metas, rescale=True)
    else:
        model.forward = partial(
            origin_forward,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)

    # mmcv replaces some existing ONNX symbolics to work around export
    # bugs in older PyTorch releases (the original note cites 1.3).
    register_extra_symbolics(opset_version)
    dynamic_axes = None
    if dynamic_export and model_type == 'det':
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    elif dynamic_export and model_type == 'recog':
        dynamic_axes = {
            'input': {
                0: 'batch',
                3: 'width'
            },
            'output': {
                0: 'batch',
                1: 'seq_len',
                2: 'num_classes'
            }
        }
    try:
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list[0], ),
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=verbose,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
    finally:
        # Always undo the monkey-patch, even when export fails, so the
        # model is left usable for the caller.
        model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')
    if not verify:
        return

    # check by onnx
    import onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
    if dynamic_export:
        # Resize the inputs so verification runs on a shape different from
        # the traced one.
        img_list = [
            nn.functional.interpolate(img, scale_factor=scale_factor)
            for img in img_list
        ]
        if model_type == 'det':
            # ``scale_factor * 2`` is tuple repetition: (sx, sy, sx, sy).
            # NOTE(review): assumes the stored scale_factor is an array
            # with 4 entries — confirm.
            img_metas[0][0]['scale_factor'] = (
                img_metas[0][0]['scale_factor'] * (scale_factor * 2))

    # get pytorch output
    with torch.no_grad():
        pytorch_out = model.simple_test(
            img_list[0], img_metas[0], rescale=True)

    # get onnx output
    if model_type == 'det':
        onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
    else:
        onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg, device_id)
    onnx_out = onnx_model.simple_test(
        img_list[0], img_metas[0], rescale=True)

    # compare results
    outputs_match = _outputs_match(onnx_out, pytorch_out, model_type)
    same_diff = 'same' if outputs_match else 'different'
    print('The outputs are {} between PyTorch and ONNX'.format(same_diff))

    if show:
        onnx_img = onnx_model.show_result(
            img_path, onnx_out[0], out_file='onnx.jpg', show=False)
        pytorch_img = model.show_result(
            img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
        # Fall back to the raw image when show_result returns nothing.
        if onnx_img is None:
            onnx_img = cv2.imread(img_path)
        if pytorch_img is None:
            pytorch_img = cv2.imread(img_path)

        cv2.imshow('PyTorch', pytorch_img)
        cv2.imshow('ONNXRuntime', onnx_img)
        cv2.waitKey()


def main():
    """Command-line entry point.

    Parses the arguments, builds the MMOCR model from its config and
    checkpoint, resolves a flat test pipeline, and exports the model to
    ONNX with :func:`pytorch2onnx`.
    """
    parser = ArgumentParser(
        description='Convert MMOCR models from pytorch to ONNX')
    parser.add_argument('model_config', type=str, help='Config file.')
    parser.add_argument(
        'model_ckpt', type=str, help='Checkpoint file (local or url).')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument('image_path', type=str, help='Input Image file.')
    parser.add_argument(
        '--output-file',
        type=str,
        help='Output file name of the onnx model.',
        default='tmp.onnx')
    # type=int is required: without it argparse passes a user-supplied value
    # as a str, which is not a valid index for torch.device.
    parser.add_argument(
        '--device-id',
        type=int,
        default=0,
        help='Device used for inference.')
    parser.add_argument(
        '--opset-version',
        type=int,
        help='ONNX opset version, default to 11.',
        default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Whether verify the outputs of onnx and pytorch are same.',
        default=False)
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether print the computation graph.',
        default=False)
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether visualize final output.',
        default=False)
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether dynamically export onnx model.',
        default=False)
    args = parser.parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    device = torch.device(type='cuda', index=args.device_id)

    # build model
    model = init_detector(args.model_config, args.model_ckpt, device=device)
    if hasattr(model, 'module'):
        # unwrap a (data-)parallel wrapper if present
        model = model.module
    test_cfg = model.cfg.data.test
    # Concatenated test sets carry their pipeline per dataset; borrow the
    # first dataset's pipeline (nested lists are indexed one level deeper).
    if test_cfg.get('pipeline', None) is None:
        if is_2dlist(test_cfg.datasets):
            test_cfg.pipeline = test_cfg.datasets[0][0].pipeline
        else:
            test_cfg.pipeline = test_cfg['datasets'][0].pipeline
    if is_2dlist(test_cfg.pipeline):
        test_cfg.pipeline = test_cfg.pipeline[0]

    pytorch2onnx(
        model,
        model_type=args.model_type,
        output_file=args.output_file,
        img_path=args.image_path,
        opset_version=args.opset_version,
        verify=args.verify,
        verbose=args.verbose,
        show=args.show,
        device_id=args.device_id,
        dynamic_export=args.dynamic_export)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()