from typing import Any, ByteString, Callable
import pickle
import cloudpickle
import zlib
import numpy as np


class CloudPickleWrapper:
    """
    Overview:
        A thin wrapper that serializes the wrapped object with ``cloudpickle`` whenever the wrapper itself is
        pickled, so that more kinds of python objects (e.g. an object holding a lambda expression) can be dumped.
    Interfaces:
        ``__init__``, ``__getstate__``, ``__setstate__``.
    """

    def __init__(self, data: Any) -> None:
        """
        Overview:
            Store the object that should be serialized through ``cloudpickle``.
        Arguments:
            - data (:obj:`Any`): The object to be wrapped and later dumped.
        """
        self.data = data

    def __getstate__(self) -> bytes:
        """
        Overview:
            Produce the pickling state of the wrapper by dumping the wrapped object with ``cloudpickle``.
        Returns:
            - data (:obj:`bytes`): The result of ``cloudpickle.dumps`` applied to the wrapped object.
        """
        payload = cloudpickle.dumps(self.data)
        return payload

    def __setstate__(self, data: bytes) -> None:
        """
        Overview:
            Restore the wrapped object from a previously dumped state.
        Arguments:
            - data (:obj:`bytes`): The dumped state. A ``tuple``, ``list`` or ``np.ndarray`` state is handed to \
                ``pickle.loads``, any other state is handed to ``cloudpickle.loads``.
        """
        # NOTE(review): ``pickle.loads`` expects a bytes-like object, so a tuple or list state would presumably
        # be rejected by it -- confirm whether this branch is still intended.
        use_plain_pickle = isinstance(data, (tuple, list, np.ndarray))  # pickle is faster
        loader = pickle.loads if use_plain_pickle else cloudpickle.loads
        self.data = loader(data)


def dummy_compressor(data: Any) -> Any:
    """
    Overview:
        The no-op compressor: hand back the input object untouched (no copy, no serialization).
    Arguments:
        - data (:obj:`Any`): The object passed to the compressor.
    Returns:
        - output (:obj:`Any`): The very same object that was passed in.
    Examples:
        >>> dummy_compressor("Hello")
        'Hello'
    """
    return data


def zlib_data_compressor(data: Any) -> bytes:
    """
    Overview:
        Serialize the input object with ``pickle`` and compress the serialized bytes with ``zlib``.
    Arguments:
        - data (:obj:`Any`): The object to be compressed.
    Returns:
        - output (:obj:`bytes`): The zlib-compressed pickle bytes of ``data``.
    Examples:
        >>> compressed_data = zlib_data_compressor("Hello")
        >>> pickle.loads(zlib.decompress(compressed_data))
        'Hello'
    """
    serialized = pickle.dumps(data)
    return zlib.compress(serialized)


def lz4_data_compressor(data: Any) -> bytes:
    """
    Overview:
        Serialize the input object with ``pickle`` and compress the serialized bytes with the ``lz4`` block \
        compressor. ``lz4`` is imported lazily; if the import fails, a warning is logged and ``sys.exit(1)`` \
        is called.
    Arguments:
        - data (:obj:`Any`): The object to be compressed.
    Returns:
        - output (:obj:`bytes`): The lz4-compressed pickle bytes of ``data``.
    Examples:
        >>> compressed_data = lz4_data_compressor("Hello")
        >>> pickle.loads(lz4.block.decompress(compressed_data))
        'Hello'
    """
    try:
        import lz4.block
    except ImportError:
        # The optional dependency is missing: tell the user how to install it, then stop.
        from ditk import logging
        import sys
        logging.warning("Please install lz4 first, such as `pip3 install lz4`")
        sys.exit(1)
    serialized = pickle.dumps(data)
    return lz4.block.compress(serialized)


def jpeg_data_compressor(data: np.ndarray) -> bytes:
    """
    Overview:
        Encode an observation numpy array into jpeg bytes with OpenCV, so that the (smaller) jpeg strings \
        can be stored in the buffer instead of the raw array. ``cv2`` is imported lazily; if the import \
        fails, a warning is logged and ``sys.exit(1)`` is called.
    Arguments:
        - data (:obj:`np.ndarray`): The observation numpy array to be encoded.
    Returns:
        - img_str (:obj:`bytes`): The jpeg-encoded bytes of the observation.
    """
    try:
        import cv2
    except ImportError:
        # The optional dependency is missing: tell the user, then stop.
        from ditk import logging
        import sys
        logging.warning("Please install opencv-python first.")
        sys.exit(1)
    # ``cv2.imencode`` returns a pair; the second item holds the encoded buffer.
    encoded = cv2.imencode('.jpg', data)[1]
    return encoded.tobytes()


# Registry mapping each supported compressor name to its compress function.
# ``get_data_compressor`` looks names up in this dict.
_COMPRESSORS_MAP = {
    'lz4': lz4_data_compressor,
    'zlib': zlib_data_compressor,
    'jpeg': jpeg_data_compressor,
    'none': dummy_compressor,
}


def get_data_compressor(name: str) -> Callable:
    """
    Overview:
        Get the data compressor according to the input name.
    Arguments:
        - name (:obj:`str`): Name of the compressor, support ``['lz4', 'zlib', 'jpeg', 'none']``.
    Returns:
        - compressor (:obj:`Callable`): Corresponding data compressor, taking input data and returning \
            the compressed data.
    Raises:
        - KeyError: If ``name`` is not one of the supported compressor names.
    Examples:
        >>> compress_fn = get_data_compressor('lz4')
        >>> compressed_data = compress_fn(input_data)
    """
    try:
        return _COMPRESSORS_MAP[name]
    except KeyError:
        # Keep the exception type callers may already catch, but tell them which names are valid.
        supported = list(_COMPRESSORS_MAP.keys())
        raise KeyError(
            "Unsupported data compressor: {!r}, supported ones are {}".format(name, supported)
        ) from None


def dummy_decompressor(data: Any) -> Any:
    """
    Overview:
        The no-op decompressor: hand back the input object untouched (no copy, no deserialization).
    Arguments:
        - data (:obj:`Any`): The object passed to the decompressor.
    Returns:
        - output (:obj:`Any`): The very same object that was passed in.
    Examples:
        >>> dummy_decompressor(b"raw")
        b'raw'
    """
    return data


def lz4_data_decompressor(compressed_data: bytes) -> Any:
    """
    Overview:
        Decompress lz4 block-compressed bytes and unpickle the result to recover the original object. \
        ``lz4`` is imported lazily; if the import fails, a warning is logged and ``sys.exit(1)`` is called.
    Arguments:
        - compressed_data (:obj:`bytes`): The lz4-compressed pickle bytes.
    Returns:
        - output (:obj:`Any`): The decompressed and unpickled object.

    .. note::

        The decompressed bytes are passed to ``pickle.loads``, so only feed data from a trusted source.
    """
    try:
        import lz4.block
    except ImportError:
        # The optional dependency is missing: tell the user how to install it, then stop.
        from ditk import logging
        import sys
        logging.warning("Please install lz4 first, such as `pip3 install lz4`")
        sys.exit(1)
    serialized = lz4.block.decompress(compressed_data)
    return pickle.loads(serialized)


def zlib_data_decompressor(compressed_data: bytes) -> Any:
    """
    Overview:
        Decompress zlib-compressed bytes and unpickle the result to recover the original object.
    Arguments:
        - compressed_data (:obj:`bytes`): The zlib-compressed pickle bytes.
    Returns:
        - output (:obj:`Any`): The decompressed and unpickled object.

    .. note::

        The decompressed bytes are passed to ``pickle.loads``, so only feed data from a trusted source.
    """
    serialized = zlib.decompress(compressed_data)
    return pickle.loads(serialized)


def jpeg_data_decompressor(compressed_data: bytes, gray_scale=False) -> np.ndarray:
    """
    Overview:
        Decode jpeg bytes back into an observation numpy array with OpenCV. This is the counterpart of \
        storing jpeg strings instead of raw arrays in the buffer to reduce memory usage. ``cv2`` is imported \
        lazily; if the import fails, a warning is logged and ``sys.exit(1)`` is called.
    Arguments:
        - compressed_data (:obj:`bytes`): The jpeg strings.
        - gray_scale (:obj:`bool`): If true, decode with ``cv2.IMREAD_GRAYSCALE`` and append a trailing \
            axis to the decoded result; otherwise decode with ``cv2.IMREAD_COLOR``.
    Returns:
        - arr (:obj:`np.ndarray`): The decoded numpy array.
    """
    try:
        import cv2
    except ImportError:
        # The optional dependency is missing: tell the user, then stop.
        from ditk import logging
        import sys
        logging.warning("Please install opencv-python first.")
        sys.exit(1)
    # View the raw jpeg bytes as a uint8 buffer, which is what ``cv2.imdecode`` consumes.
    jpeg_buffer = np.frombuffer(compressed_data, np.uint8)
    if not gray_scale:
        # NOTE(review): OpenCV presumably yields BGR channel order here -- confirm callers expect that.
        return cv2.imdecode(jpeg_buffer, cv2.IMREAD_COLOR)
    gray = cv2.imdecode(jpeg_buffer, cv2.IMREAD_GRAYSCALE)
    # Append a trailing axis to the gray-scale result.
    return np.expand_dims(gray, -1)


# Registry mapping each supported decompressor name to its decompress function.
# ``get_data_decompressor`` looks names up in this dict.
_DECOMPRESSORS_MAP = {
    'lz4': lz4_data_decompressor,
    'zlib': zlib_data_decompressor,
    'jpeg': jpeg_data_decompressor,
    'none': dummy_decompressor,
}


def get_data_decompressor(name: str) -> Callable:
    """
    Overview:
        Get the data decompressor according to the input name.
    Arguments:
        - name (:obj:`str`): Name of the decompressor, support ``['lz4', 'zlib', 'jpeg', 'none']``.

    .. note::

        The ``'lz4'``, ``'zlib'`` and ``'jpeg'`` decompressors require a bytes-like object as input, while \
        the ``'none'`` decompressor returns whatever it is given.

    Returns:
        - decompressor (:obj:`Callable`): Corresponding data decompressor.
    Raises:
        - KeyError: If ``name`` is not one of the supported decompressor names.
    Examples:
        >>> decompress_fn = get_data_decompressor('lz4')
        >>> origin_data = decompress_fn(compressed_data)
    """
    try:
        return _DECOMPRESSORS_MAP[name]
    except KeyError:
        # Keep the exception type callers may already catch, but tell them which names are valid.
        supported = list(_DECOMPRESSORS_MAP.keys())
        raise KeyError(
            "Unsupported data decompressor: {!r}, supported ones are {}".format(name, supported)
        ) from None