from collections import namedtuple

import numpy as np
import pytest
import torch
import treetensor.torch as ttorch

from ding.utils.default_helper import lists_to_dicts, dicts_to_lists, squeeze, default_get, override, error_wrapper, \
    list_split, LimitedSpaceContainer, set_pkg_seed, deep_merge_dicts, deep_update, flatten_dict, RunningMeanStd, \
    one_time_warning, split_data_generator, get_shape0


@pytest.mark.unittest
class TestDefaultHelper:

    def test_get_shape0(self):
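        # get_shape0 is expected to return the first (batch) dimension of the tensors
        # nested inside treetensor / list / tuple containers; unsupported leaves
        # (e.g. a plain dict holding a list of strings) should raise.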
        a = {
            'a': {
                'b': torch.randn(4, 3)
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        a = ttorch.as_tensor(a)
        b = [a, a]
        c = (a, a)
        d = {
            'a': {
                'b': ["a", "b", "c", "d"]
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        assert get_shape0(a) == 4
        assert get_shape0(b) == 4
        assert get_shape0(c) == 4
        with pytest.raises(Exception) as e_info:
            assert get_shape0(d) == 4

    def test_lists_to_dicts(self):
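        # lists_to_dicts transposes a list of dicts (or namedtuples) into a dict
        # (or namedtuple) of lists; recursive=True also transposes nested dicts.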
        set_pkg_seed(12)
        with pytest.raises(ValueError):
            lists_to_dicts([])
        with pytest.raises(TypeError):
            lists_to_dicts([1])
        assert lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}]) == {1: [1, 2], 10: [3, 4]}
        T = namedtuple('T', ['location', 'race'])
        data = [T({'x': 1, 'y': 2}, 'zerg') for _ in range(3)]
        output = lists_to_dicts(data)
        assert isinstance(output, T) and output.__class__ == T
        assert len(output.location) == 3
        data = [{'value': torch.randn(1), 'obs': {'scalar': torch.randn(4)}} for _ in range(3)]
        output = lists_to_dicts(data, recursive=True)
        assert isinstance(output, dict)
        assert len(output['value']) == 3
        assert len(output['obs']['scalar']) == 3

    def test_dicts_to_lists(self):
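        # dicts_to_lists is the inverse transform: a dict of lists becomes a list of dicts.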
        assert dicts_to_lists({1: [1, 2], 10: [3, 4]}) == [{1: 1, 10: 3}, {1: 2, 10: 4}]

    def test_squeeze(self):
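        # squeeze unwraps single-element containers: a 1-tuple or single-value dict
        # collapses to the value, a list becomes a tuple, and arrays pass through unchanged.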
        assert squeeze((4, )) == 4
        assert squeeze({'a': 4}) == 4
        assert squeeze([1, 3]) == (1, 3)
        data = np.random.randn(3)
        output = squeeze(data)
        assert (output == data).all()

    def test_default_get(self):
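        # default_get returns the stored value when the key exists, otherwise falls back to
        # default_value / default_fn; judge_fn validates the result and a failing check
        # raises AssertionError.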
        assert default_get({}, 'a', default_value=1, judge_fn=lambda x: x < 2) == 1
        assert default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 2) == 1
        with pytest.raises(AssertionError):
            default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 0)
        assert default_get({'val': 1}, 'val', default_value=2) == 1

    def test_override(self):
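        # override(parent) marks a method as overriding one defined on parent and raises
        # NameError at class-definition time when parent has no method of that name.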

        class foo(object):

            def fun(self):
                raise NotImplementedError

        class foo1(foo):

            @override(foo)
            def fun(self):
                return "a"

        with pytest.raises(NameError):

            class foo2(foo):

                @override(foo)
                def func(self):
                    pass

        with pytest.raises(NotImplementedError):
            foo().fun()
        foo1().fun()

    def test_error_wrapper(self):
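        # error_wrapper wraps a function so that exceptions are swallowed and a default
        # return value is produced instead; the optional third argument customizes the
        # message reported when an exception is caught.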

        def good_ret(a, b=1):
            return a + b

        wrap_good_ret = error_wrapper(good_ret, 0)
        assert good_ret(1) == wrap_good_ret(1)

        def bad_ret(a, b=0):
            return a / b

        wrap_bad_ret = error_wrapper(bad_ret, 0)
        assert wrap_bad_ret(1) == 0
        wrap_bad_ret_with_customized_log = error_wrapper(bad_ret, 0, 'customized_information')
        assert wrap_bad_ret_with_customized_log(1) == 0

    def test_list_split(self):
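        # list_split chops the list into consecutive chunks of length `step` and returns
        # (chunks, residual); residual is None when the length divides evenly.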
        data = list(range(10))
        output, residual = list_split(data, step=4)
        assert len(output) == 2
        assert output[1] == [4, 5, 6, 7]
        assert residual == [8, 9]
        output, residual = list_split(data, step=5)
        assert len(output) == 2
        assert output[1] == [5, 6, 7, 8, 9]
        assert residual is None


@pytest.mark.unittest
class TestLimitedSpaceContainer:

    def test_container(self):
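        # LimitedSpaceContainer(min_val, max_val): acquire_space claims one slot while below
        # max_val, get_residual_space claims every remaining slot and returns the count,
        # increase_space/decrease_space adjust max_val, and release_space frees one slot.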
        container = LimitedSpaceContainer(0, 5)
        first = container.acquire_space()
        assert first
        assert container.cur == 1
        left = container.get_residual_space()
        assert left == 4
        assert container.cur == container.max_val == 5
        no_space = container.acquire_space()
        assert not no_space
        container.increase_space()
        six = container.acquire_space()
        assert six
        for i in range(6):
            container.release_space()
            assert container.cur == 5 - i
        container.decrease_space()
        assert container.max_val == 5


@pytest.mark.unittest
class TestDict:

    def test_deep_merge_dicts(self):
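        # deep_merge_dicts recursively merges dict2 onto dict1; on conflicting keys the
        # values from dict2 take precedence, and nested dicts are merged recursively.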
        dict1 = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                }
            }
        }
        dict2 = {
            'b': {
                'c': 5,
                'd': 6,
                'g': 4,
            }
        }
        new_dict = deep_merge_dicts(dict1, dict2)
        assert new_dict['a'] == 3
        assert isinstance(new_dict['b'], dict)
        assert new_dict['b']['c'] == 5
        assert new_dict['b']['d'] == 6
        assert new_dict['b']['g'] == 4

    def test_deep_update(self):
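        # deep_update with new_keys_allowed=False rejects unknown keys unless the subtree is
        # whitelisted; override_all_if_type_changes replaces a whole subtree (dropping old
        # keys such as 'z') when its 'type' field changes.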
        dict1 = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                },
                'z': 4,
            }
        }
        dict2 = {
            'b': {
                'c': 5,
                'd': 6,
                'g': 4,
            }
        }
        with pytest.raises(RuntimeError):
            new1 = deep_update(dict1, dict2, new_keys_allowed=False)
        new2 = deep_update(dict1, dict2, new_keys_allowed=False, whitelist=['b'])
        assert new2['a'] == 3
        assert new2['b']['c'] == 5
        assert new2['b']['d'] == 6
        assert new2['b']['g'] == 4
        assert new2['b']['z'] == 4

        dict1 = {
            'a': 3,
            'b': {
                'type': 'old',
                'z': 4,
            }
        }
        dict2 = {
            'b': {
                'type': 'new',
                'c': 5,
            }
        }
        new3 = deep_update(dict1, dict2, new_keys_allowed=True, whitelist=[], override_all_if_type_changes=['b'])
        assert new3['a'] == 3
        assert new3['b']['type'] == 'new'
        assert new3['b']['c'] == 5
        assert 'z' not in new3['b']

    def test_flatten_dict(self):
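        # flatten_dict collapses nested dicts into a single level, joining keys with '/'.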
        nested_dict = {
            'a': 3,
            'b': {
                'c': 3,
                'd': {
                    'e': 6,
                    'f': 5,
                },
                'z': 4,
            }
        }
        flat = flatten_dict(nested_dict)
        assert flat['a'] == 3
        assert flat['b/c'] == 3
        assert flat['b/d/e'] == 6
        assert flat['b/d/f'] == 5
        assert flat['b/z'] == 4

    def test_one_time_warning(self):
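        # Smoke test: emitting the warning should simply not raise.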
        one_time_warning('test_one_time_warning')

    def test_running_mean_std(self):
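        # RunningMeanStd accumulates running mean/std over successive update() batches and
        # reset() restores the initial state; with a non-scalar shape the statistics are
        # returned as torch tensors of that shape.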
        running = RunningMeanStd()
        running.reset()
        running.update(np.arange(1, 10))
        assert running.mean == pytest.approx(5, abs=1e-4)
        assert running.std == pytest.approx(2.582030, abs=1e-6)
        running.update(np.arange(2, 11))
        assert running.mean == pytest.approx(5.5, abs=1e-4)
        assert running.std == pytest.approx(2.629981, abs=1e-6)
        running.reset()
        running.update(np.arange(1, 10))
        assert running.mean == pytest.approx(5, abs=1e-4)
        assert running.std == pytest.approx(2.582030, abs=1e-6)
        new_shape = running.new_shape((2, 4), (3, ), (1, ))
        assert isinstance(new_shape, tuple) and len(new_shape) == 3

        running = RunningMeanStd(shape=(4, ))
        running.reset()
        running.update(np.random.random((10, 4)))
        assert isinstance(running.mean, torch.Tensor) and running.mean.shape == (4, )
        assert isinstance(running.std, torch.Tensor) and running.std.shape == (4, )

    def test_split_data_generator(self):
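        # split_data_generator yields mini-batches of the given size from a batched dict,
        # slicing tensor fields (including nested ones) and passing None fields through.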

        def get_data():
            return {
                'obs': torch.randn(5),
                'action': torch.randint(0, 10, size=(1, )),
                'prev_state': [None, None],
                'info': {
                    'other_obs': torch.randn(5)
                },
            }

        data = [get_data() for _ in range(4)]
        data = lists_to_dicts(data)
        data['obs'] = torch.stack(data['obs'])
        data['action'] = torch.stack(data['action'])
        data['info'] = {'other_obs': torch.stack([t['other_obs'] for t in data['info']])}
        assert len(data['obs']) == 4
        data['NoneKey'] = None
        generator = split_data_generator(data, 3)
        generator_result = list(generator)
        assert len(generator_result) == 2
        assert generator_result[0]['NoneKey'] is None
        assert len(generator_result[0]['obs']) == 3
        assert generator_result[0]['info']['other_obs'].shape == (3, 5)
        assert generator_result[1]['NoneKey'] is None
        assert len(generator_result[1]['obs']) == 3
        assert generator_result[1]['info']['other_obs'].shape == (3, 5)

        generator = split_data_generator(data, 3, shuffle=False)