import io
from ditk import logging
import os
import pickle
import time
from functools import lru_cache
from typing import Union

import torch

from .import_helper import try_import_ceph, try_import_redis, try_import_rediscluster, try_import_mc
from .lock_helper import get_file_lock

# Module-level client singletons, created lazily: ``_memcached`` is filled in by ``_ensure_memcached`` and
# ``_redis_cluster`` by ``_ensure_rediscluster`` the first time each backend is used.
_memcached = _redis_cluster = None

# DI-store integration is opt-in through the ``DI_STORE`` environment variable ("on", case-insensitive).
# When disabled, both helpers are exported as ``None`` so callers can test for availability.
if os.environ.get('DI_STORE', 'off').lower() != 'on':
    save_to_di_store = read_from_di_store = None
else:
    print('Enable DI-store')
    from di_store import Client

    di_store_config_path = os.environ.get("DI_STORE_CONFIG_PATH", './di_store.yaml')
    di_store_client = Client(di_store_config_path)

    def save_to_di_store(data):
        """
        Overview:
            Put ``data`` into DI-store and return whatever reference ``Client.put`` hands back
        """
        return di_store_client.put(data)

    def read_from_di_store(object_ref):
        """
        Overview:
            Fetch the object behind ``object_ref`` and then delete it from DI-store, so each reference is
            meant to be read only once
        """
        fetched = di_store_client.get(object_ref)
        di_store_client.delete(object_ref)
        return fetched


@lru_cache()
def get_ceph_package():
    """
    Overview:
        Import the ceph client package once and memoize the result (via ``lru_cache``) for later calls
    Returns:
        - package: Whatever ``try_import_ceph`` yields; presumably ``None`` when ceph is unavailable, \
            since ``read_from_path`` checks the result against ``None``
    """
    ceph_package = try_import_ceph()
    return ceph_package


@lru_cache()
def get_redis_package():
    """
    Overview:
        Import the redis package once and memoize the result (via ``lru_cache``) for later calls
    Returns:
        - package: Whatever ``try_import_redis`` yields; it is expected to expose ``StrictRedis``
    """
    redis_package = try_import_redis()
    return redis_package


@lru_cache()
def get_rediscluster_package():
    """
    Overview:
        Import the rediscluster package once and memoize the result (via ``lru_cache``) for later calls
    Returns:
        - package: Whatever ``try_import_rediscluster`` yields; it is expected to expose ``RedisCluster``
    """
    rediscluster_package = try_import_rediscluster()
    return rediscluster_package


@lru_cache()
def get_mc_package():
    """
    Overview:
        Import the memcached client package once and memoize the result (via ``lru_cache``) for later calls
    Returns:
        - package: Whatever ``try_import_mc`` yields; presumably ``None`` when memcached is unavailable, \
            since ``read_file``/``save_file`` compare the result against ``None``
    """
    mc_package = try_import_mc()
    return mc_package


def read_from_ceph(path: str) -> object:
    """
    Overview:
        Fetch the raw bytes stored at ``path`` in ceph and unpickle them
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``
    Returns:
        - (:obj:`data`): Deserialized data
    Raises:
        - FileNotFoundError: If ceph returns an empty / falsy value for ``path``
    """
    raw_bytes = get_ceph_package().Get(path)
    if raw_bytes:
        # NOTE: pickle.loads can execute arbitrary code; only read from trusted buckets.
        return pickle.loads(raw_bytes)
    raise FileNotFoundError("File({}) doesn't exist in ceph".format(path))


@lru_cache()
def _get_redis(host='localhost', port=6379):
    """
    Overview:
        Build a ``StrictRedis`` client on database 0. ``lru_cache`` memoizes it per call signature, so repeated
        calls with the same arguments reuse one client object.
    Arguments:
        - host (:obj:`str`): Host string
        - port (:obj:`int`): Port number
    Returns:
        - (:obj:`Redis(object)`): Redis object with given ``host``, ``port``, and ``db=0``
    """
    redis_package = get_redis_package()
    client = redis_package.StrictRedis(host=host, port=port, db=0)
    return client


def read_from_redis(path: str) -> object:
    """
    Overview:
        Read file from redis (default ``_get_redis()`` client) and unpickle it
    Arguments:
        - path (:obj:`str`): File path in redis, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    Raises:
        - FileNotFoundError: If ``path`` is not a key in redis
    """
    value = _get_redis().get(path)
    # redis returns None for a missing key; pickle.loads(None) would raise a confusing TypeError.
    if value is None:
        raise FileNotFoundError("File({}) doesn't exist in redis".format(path))
    return pickle.loads(value)


def _ensure_rediscluster(startup_nodes=None):
    """
    Overview:
        Lazily create the module-level RedisCluster client on the first call and return it
    Arguments:
        - startup_nodes (:obj:`list` of :obj:`dict` or :obj:`None`): Startup nodes, each a dict with ``host`` \
            and ``port``. ``None`` means ``[{"host": "127.0.0.1", "port": "7000"}]``. Only used when the \
            client does not exist yet; later calls reuse the existing client and ignore this argument.
    Returns:
        - (:obj:`RedisCluster(object)`): RedisCluster object built with ``startup_nodes`` and \
            ``decode_responses=False``
    """
    global _redis_cluster
    if _redis_cluster is None:
        # Build the default per call instead of using a shared mutable default argument.
        if startup_nodes is None:
            startup_nodes = [{"host": "127.0.0.1", "port": "7000"}]
        _redis_cluster = get_rediscluster_package().RedisCluster(startup_nodes=startup_nodes, decode_responses=False)
    return _redis_cluster


def read_from_rediscluster(path: str) -> object:
    """
    Overview:
        Read file from rediscluster (creating the module-level client on first use) and unpickle it
    Arguments:
        - path (:obj:`str`): File path in rediscluster, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    Raises:
        - FileNotFoundError: If ``path`` is not a key in the cluster
    """
    _ensure_rediscluster()
    value_bytes = _redis_cluster.get(path)
    # A missing key comes back as None; report it clearly instead of failing inside pickle.loads.
    if value_bytes is None:
        raise FileNotFoundError("File({}) doesn't exist in rediscluster".format(path))
    return pickle.loads(value_bytes)


def read_from_file(path: str) -> object:
    """
    Overview:
        Unpickle the object stored in a local file
    Arguments:
        - path (:obj:`str`): File path in local file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _ensure_memcached(
        server_list_config_file: str = "/mnt/lustre/share/memcached_client/server_list.conf",
        client_config_file: str = "/mnt/lustre/share/memcached_client/client.conf",
):
    """
    Overview:
        Lazily create the module-level memcached client on the first call and return it
    Arguments:
        - server_list_config_file (:obj:`str`): Path of the memcached ``server_list.conf``; only used on the \
            first call
        - client_config_file (:obj:`str`): Path of the memcached ``client.conf``; only used on the first call
    Returns:
        - (:obj:`MemcachedClient instance`): Instance obtained from ``MemcachedClient.GetInstance`` with the \
            given config files
    """
    global _memcached
    if _memcached is None:
        _memcached = get_mc_package().MemcachedClient.GetInstance(server_list_config_file, client_config_file)
    return _memcached


def read_from_mc(path: str, flush=False, max_retry: Union[None, int] = None) -> object:
    """
    Overview:
        Read file from memcache, file must be saved by `torch.save()`. A failed attempt is reported and
        retried after a 0.01s pause.
    Arguments:
        - path (:obj:`str`): Path of the file, also used as the memcached key
        - flush (:obj:`bool`): If True, only issue a ``Get`` with ``MC_READ_THROUGH`` (presumably forcing \
            memcached to reload the file) and return ``None`` without deserializing
        - max_retry (:obj:`int` or :obj:`None`): Maximum number of retries after the first failed attempt. \
            ``None`` (default) retries forever, which keeps the historical behavior.
    Returns:
        - (:obj:`data`): Deserialized data loaded onto CPU, or ``None`` when ``flush`` is True
    Raises:
        - Exception: The last read error, once ``max_retry`` retries are exhausted
    """
    _ensure_memcached()
    mc = get_mc_package()  # lru_cache'd; hoisted out of the retry loop
    retries = 0
    while True:
        try:
            value = mc.pyvector()
            if flush:
                _memcached.Get(path, value, mc.MC_READ_THROUGH)
                return None
            _memcached.Get(path, value)
            value_buf = mc.ConvertBuffer(value)
            return torch.load(io.BytesIO(value_buf), map_location='cpu')
        except Exception:
            # Without a bound, a persistent failure (e.g. a missing file) would hang here forever.
            if max_retry is not None and retries >= max_retry:
                raise
            retries += 1
            print('read mc failed, retry...')
            time.sleep(0.01)


def read_from_path(path: str):
    """
    Overview:
        Read pickled data from ceph when the ceph package is available, otherwise fall back to the local
        file system (logging a warning-style info message)
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``, or a local path when ceph is missing
    Returns:
        - (:obj:`data`): Deserialized data
    """
    if get_ceph_package() is not None:
        return read_from_ceph(path)
    logging.info(
        "You do not have ceph installed! Loading local file!"
        " If you are not testing locally, something is wrong!"
    )
    return read_from_file(path)


def save_file_ceph(path, data):
    """
    Overview:
        Pickle ``data`` and store it in ceph. Without ceph installed, write it to the local path instead,
        unless ``os.path.dirname(path) == 'do_not_save'``, in which case nothing is written.
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``; treated as a local path without ceph
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    Raises:
        - RuntimeError: If the ceph package has neither ``save_from_string`` nor ``put``
    """
    payload = pickle.dumps(data)
    directory = os.path.dirname(path)
    file_name = os.path.basename(path)
    ceph = get_ceph_package()

    if ceph is None:
        size = len(payload)
        if directory == 'do_not_save':
            logging.info(
                "You do not have ceph installed! ignored file {} of size {}!".format(file_name, size) +
                " If you are not testing locally, something is wrong!"
            )
            return
        local_path = os.path.join(directory, file_name)
        with open(local_path, 'wb') as out_file:
            logging.info(
                "You do not have ceph installed! Saving as local file at {} of size {}!".format(local_path, size) +
                " If you are not testing locally, something is wrong!"
            )
            out_file.write(payload)
        return

    # Different ceph client versions expose different upload APIs.
    if hasattr(ceph, 'save_from_string'):
        ceph.save_from_string(directory, file_name, payload)
    elif hasattr(ceph, 'put'):
        ceph.put(os.path.join(directory, file_name), payload)
    else:
        raise RuntimeError('ceph can not save file, check your ceph installation')


def save_file_redis(path, data):
    """
    Overview:
        Pickle ``data`` and store it under key ``path`` through the default ``_get_redis()`` client
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    client = _get_redis()
    client.set(path, pickle.dumps(data))


def save_file_rediscluster(path, data):
    """
    Overview:
        Pickle ``data`` and store it under key ``path`` in rediscluster, creating the module-level client
        on first use
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    _ensure_rediscluster()
    serialized = pickle.dumps(data)
    _redis_cluster.set(path, serialized)


def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object:
    """
    Overview:
        Read file from path
    Arguments:
        - path (:obj:`str`): The path of file to read
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph', 'mc'}``. \
            When ``None`` it is inferred: paths starting with ``s3`` (case-insensitive) use ``'ceph'``; \
            otherwise ``'mc'`` if the memcached package is available, else ``'normal'``.
        - use_lock (:obj:`bool`): Whether to hold a read lock from ``get_file_lock`` while loading; only \
            applies to the ``'normal'`` file system
    Returns:
        - (:obj:`data`): Deserialized data (``'normal'`` and ``'mc'`` are loaded with ``torch.load`` onto CPU)
    Raises:
        - AssertionError: If ``fs_type`` is not supported (``ValueError`` when asserts are disabled by ``-O``)
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'ceph':
        data = read_from_path(path)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'read'):
                data = torch.load(path, map_location='cpu')
        else:
            data = torch.load(path, map_location='cpu')
    elif fs_type == 'mc':
        data = read_from_mc(path)
    else:
        # Only reachable when asserts are stripped (python -O); fail loudly instead of an UnboundLocalError.
        raise ValueError("Unsupported fs_type: {}".format(fs_type))
    return data


def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None:
    """
    Overview:
        Save data to file of path
    Arguments:
        - path (:obj:`str`): The path of file to save to
        - data (:obj:`object`): The data to save
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph', 'mc'}``. \
            When ``None`` it is inferred: paths starting with ``s3`` (case-insensitive) use ``'ceph'``; \
            otherwise ``'mc'`` if the memcached package is available, else ``'normal'``.
        - use_lock (:obj:`bool`): Whether to hold a write lock from ``get_file_lock`` while saving; only \
            applies to the ``'normal'`` file system
    Raises:
        - AssertionError: If ``fs_type`` is not supported (``ValueError`` when asserts are disabled by ``-O``)
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'ceph':
        save_file_ceph(path, data)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'write'):
                torch.save(data, path)
        else:
            torch.save(data, path)
    elif fs_type == 'mc':
        # Write to the backing path, then issue a read-through Get on it (presumably refreshing memcached).
        torch.save(data, path)
        read_from_mc(path, flush=True)
    else:
        # Only reachable when asserts are stripped (python -O); don't silently drop the data.
        raise ValueError("Unsupported fs_type: {}".format(fs_type))


def remove_file(path: str, fs_type: Union[None, str] = None) -> None:
    """
    Overview:
        Remove a file or a directory tree, best-effort: a missing path is not an error. Removal has finished
        when this function returns. The path is used literally; no shell is involved, so glob patterns
        are not expanded.
    Arguments:
        - path (:obj:`str`): The path of file (or directory) you want to remove
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph'}``. \
            When ``None`` it is ``'ceph'`` for paths starting with ``s3`` (case-insensitive), else ``'normal'``.
    Raises:
        - AssertionError: If ``fs_type`` is not supported (``ValueError`` when asserts are disabled by ``-O``)
    """
    import shutil
    import subprocess

    if fs_type is None:
        fs_type = 'ceph' if path.lower().startswith('s3') else 'normal'
    assert fs_type in ['normal', 'ceph']
    if fs_type == 'ceph':
        # Argument list, no shell: the path can't be interpreted as shell syntax (spaces, ';', '$(...)').
        try:
            result = subprocess.run(["aws", "s3", "rm", "--recursive", path], check=False)
        except OSError as e:
            logging.warning("Failed to remove {} from ceph: {}".format(path, e))
        else:
            if result.returncode != 0:
                logging.warning("aws s3 rm exited with code {} for {}".format(result.returncode, path))
    elif fs_type == 'normal':
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # same as `rm -rf`: nothing to remove is fine
            except OSError as e:
                logging.warning("Failed to remove {}: {}".format(path, e))
    else:
        raise ValueError("Unsupported fs_type: {}".format(fs_type))