from collections import OrderedDict, deque
from typing import Any, NamedTuple
import os

import dm_env
import numpy as np
from dm_env import StepType, specs

import gym
import torch

class ExtendedTimeStep(NamedTuple):
    """Time step record that also carries the action associated with it.

    Fields can be read by attribute (``ts.reward``), by field name
    (``ts['reward']``) or, like any tuple, by position or slice (``ts[1]``).
    """
    step_type: Any
    reward: Any
    discount: Any
    observation: Any
    action: Any

    def first(self):
        """Return whether ``step_type`` equals ``StepType.FIRST``."""
        return self.step_type == StepType.FIRST

    def mid(self):
        """Return whether ``step_type`` equals ``StepType.MID``."""
        return self.step_type == StepType.MID

    def last(self):
        """Return whether ``step_type`` equals ``StepType.LAST``."""
        return self.step_type == StepType.LAST

    def __getitem__(self, attr):
        """Look a field up by name, or fall back to regular tuple indexing.

        Args:
            attr: a field name (``str``), or an ``int``/``slice`` index.
        """
        if isinstance(attr, str):
            return getattr(self, attr)
        # Non-string keys keep the normal tuple behaviour instead of failing
        # inside ``getattr`` with "attribute name must be string".
        return tuple.__getitem__(self, attr)


class FlattenJacoObservationWrapper(dm_env.Environment):
    """Reorganises observations into a ``pixels`` / ``observations`` pair.

    The optional ``front_close`` entry is exposed as ``pixels``; every other
    entry is flattened and concatenated into one ``observations`` vector.
    """

    def __init__(self, env):
        self._env = env
        self._obs_spec = OrderedDict()

        remaining_specs = env.observation_spec().copy()
        if 'front_close' in remaining_specs:
            camera_spec = remaining_specs.pop('front_close')
            # The camera spec has a leading batch axis that is dropped here.
            self._obs_spec['pixels'] = specs.BoundedArray(
                shape=camera_spec.shape[1:],
                dtype=camera_spec.dtype,
                minimum=camera_spec.minimum,
                maximum=camera_spec.maximum,
                name='pixels')

        # Everything that is left must be a plain float64 array spec.
        for feature_spec in remaining_specs.values():
            assert feature_spec.dtype == np.float64
            assert type(feature_spec) is specs.Array

        flat_sizes = (int(np.prod(feature_spec.shape))
                      for feature_spec in remaining_specs.values())
        total_dim = np.sum(np.fromiter(flat_sizes, np.int32))

        # NOTE: the spec advertises float32, but this wrapper itself does not
        # cast the concatenated features.
        self._obs_spec['observations'] = specs.Array(shape=(total_dim,),
                                                     dtype=np.float32,
                                                     name='observations')

    def _transform_observation(self, time_step):
        """Return ``time_step`` with its observation rebuilt as an OrderedDict.

        The ``front_close`` entry, when present, is removed from the incoming
        observation mapping (in place) and squeezed into ``pixels``.
        """
        raw = time_step.observation
        rebuilt = OrderedDict()

        if 'front_close' in raw:
            rebuilt['pixels'] = np.squeeze(raw.pop('front_close'))

        flattened = [value.ravel() for value in raw.values()]
        rebuilt['observations'] = np.concatenate(flattened, axis=0)
        return time_step._replace(observation=rebuilt)

    def reset(self):
        """Reset the wrapped env and transform the returned time step."""
        return self._transform_observation(self._env.reset())

    def step(self, action):
        """Step the wrapped env and transform the returned time step."""
        return self._transform_observation(self._env.step(action))

    def observation_spec(self):
        """Return the spec built in ``__init__``."""
        return self._obs_spec

    def action_spec(self):
        """Return the wrapped env's action spec unchanged."""
        return self._env.action_spec()

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)


class ActionRepeatWrapper(dm_env.Environment):
    """Applies each action up to ``num_repeats`` times, accumulating reward."""

    def __init__(self, env, num_repeats):
        self._env = env
        self._num_repeats = num_repeats

    def reset(self):
        """Reset the wrapped env."""
        return self._env.reset()

    def step(self, action):
        """Repeat ``action`` and return the last time step.

        The returned reward is the discounted sum of the per-step rewards and
        the returned discount is the product of the per-step discounts.
        Repetition stops early once a step reports ``last()``.
        """
        total_reward = 0.0
        running_discount = 1.0
        for _ in range(self._num_repeats):
            time_step = self._env.step(action)
            # A missing (falsy) reward counts as zero.
            step_reward = time_step.reward or 0.0
            total_reward += step_reward * running_discount
            running_discount *= time_step.discount
            if time_step.last():
                break

        return time_step._replace(reward=total_reward,
                                  discount=running_discount)

    def observation_spec(self):
        """Return the wrapped env's observation spec unchanged."""
        return self._env.observation_spec()

    def action_spec(self):
        """Return the wrapped env's action spec unchanged."""
        return self._env.action_spec()

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)


class FramesWrapper(dm_env.Environment):
    """Replaces the observation with a stack of the most recent pixel frames.

    Frames are transposed with ``transpose(2, 0, 1)`` and concatenated along
    the first axis, so the last axis of the incoming image is treated as the
    channel axis.
    """

    def __init__(self, env, num_frames=1, pixels_key='pixels'):
        self._env = env
        self._num_frames = num_frames
        self._pixels_key = pixels_key
        self._frames = deque([], maxlen=num_frames)

        source_spec = env.observation_spec()
        assert pixels_key in source_spec

        frame_shape = source_spec[pixels_key].shape
        if len(frame_shape) == 4:
            # A 4-d spec carries a leading batch axis; drop it.
            frame_shape = frame_shape[1:]

        stacked_shape = np.concatenate(
            [[frame_shape[2] * num_frames], frame_shape[:2]], axis=0)
        self._obs_spec = specs.BoundedArray(shape=stacked_shape,
                                            dtype=np.uint8,
                                            minimum=0,
                                            maximum=255,
                                            name='observation')

    def _extract_pixels(self, time_step):
        """Return the current frame as a fresh array with the last axis first."""
        frame = time_step.observation[self._pixels_key]
        if len(frame.shape) == 4:
            # Drop the leading batch axis.
            frame = frame[0]
        return frame.transpose(2, 0, 1).copy()

    def _transform_observation(self, time_step):
        """Return ``time_step`` with the stacked frames as its observation."""
        assert len(self._frames) == self._num_frames
        stacked = np.concatenate(list(self._frames), axis=0)
        return time_step._replace(observation=stacked)

    def reset(self):
        """Reset the wrapped env and fill the whole stack with the first frame."""
        time_step = self._env.reset()
        first_frame = self._extract_pixels(time_step)
        self._frames.extend([first_frame] * self._num_frames)
        return self._transform_observation(time_step)

    def step(self, action):
        """Step the wrapped env and push the new frame onto the stack."""
        time_step = self._env.step(action)
        self._frames.append(self._extract_pixels(time_step))
        return self._transform_observation(time_step)

    def observation_spec(self):
        """Return the stacked-frames spec built in ``__init__``."""
        return self._obs_spec

    def action_spec(self):
        """Return the wrapped env's action spec unchanged."""
        return self._env.action_spec()

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)


class OneHotAction(gym.Wrapper):
    """Exposes a ``Discrete`` action space as one-hot float vectors.

    NOTE: this module defines ``OneHotAction`` twice with the same body; the
    later definition is the one bound to the name after import.
    """

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        super().__init__(env)
        self._random = np.random.RandomState()
        num_actions = self.env.action_space.n
        one_hot_space = gym.spaces.Box(low=0, high=1, shape=(num_actions,),
                                       dtype=np.float32)
        # Tag the Box so consumers can tell it encodes a discrete choice.
        one_hot_space.discrete = True
        self.action_space = one_hot_space

    def step(self, action):
        """Validate the one-hot ``action`` and forward its index to the env.

        Raises:
            ValueError: if ``action`` is not (numerically close to) one-hot.
        """
        hot_index = np.argmax(action).astype(int)
        expected = np.zeros_like(action)
        expected[hot_index] = 1
        if np.allclose(expected, action):
            return self.env.step(hot_index)
        raise ValueError(f"Invalid one-hot action:\n{action}")

    def reset(self):
        """Reset the wrapped env."""
        return self.env.reset()

    def _sample_action(self):
        """Return a uniformly random one-hot float32 vector."""
        num_actions = self.env.action_space.n
        one_hot = np.zeros(num_actions, dtype=np.float32)
        one_hot[self._random.randint(0, num_actions)] = 1.0
        return one_hot


class ActionDTypeWrapper(dm_env.Environment):
    """Advertises the action spec with a different dtype.

    Incoming actions are cast back to the wrapped env's own action dtype
    before being forwarded.
    """

    def __init__(self, env, dtype):
        self._env = env
        inner_spec = env.action_spec()
        # Same shape and bounds as the wrapped spec, only the dtype differs.
        self._action_spec = specs.BoundedArray(shape=inner_spec.shape,
                                               dtype=dtype,
                                               minimum=inner_spec.minimum,
                                               maximum=inner_spec.maximum,
                                               name='action')

    def reset(self):
        """Reset the wrapped env."""
        return self._env.reset()

    def step(self, action):
        """Cast ``action`` to the wrapped env's dtype and step it."""
        native_dtype = self._env.action_spec().dtype
        return self._env.step(action.astype(native_dtype))

    def observation_spec(self):
        """Return the wrapped env's observation spec unchanged."""
        return self._env.observation_spec()

    def action_spec(self):
        """Return the re-typed action spec built in ``__init__``."""
        return self._action_spec

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)


class ObservationDTypeWrapper(dm_env.Environment):
    """Replaces the observation dict by its ``'observations'`` entry, cast to ``dtype``."""

    def __init__(self, env, dtype):
        self._env = env
        self._dtype = dtype
        flat_spec = env.observation_spec()['observations']
        self._obs_spec = specs.Array(shape=flat_spec.shape,
                                     dtype=dtype,
                                     name='observation')

    def _transform_observation(self, time_step):
        """Return ``time_step`` with the cast ``'observations'`` array as observation."""
        flat = time_step.observation['observations']
        return time_step._replace(observation=flat.astype(self._dtype))

    def reset(self):
        """Reset the wrapped env and transform the returned time step."""
        return self._transform_observation(self._env.reset())

    def step(self, action):
        """Step the wrapped env and transform the returned time step."""
        return self._transform_observation(self._env.step(action))

    def observation_spec(self):
        """Return the single-array spec built in ``__init__``."""
        return self._obs_spec

    def action_spec(self):
        """Return the wrapped env's action spec unchanged."""
        return self._env.action_spec()

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)


class ExtendedTimeStepWrapper(dm_env.Environment):
    """Converts every time step of the wrapped env into an ``ExtendedTimeStep``.

    The extended step carries the action passed to ``step`` (zeros on reset)
    and replaces a missing reward/discount with 0.0 / 1.0.
    """

    def __init__(self, env):
        self._env = env

    def reset(self):
        """Reset the wrapped env; the recorded action is all zeros."""
        time_step = self._env.reset()
        return self._augment_time_step(time_step)

    def step(self, action):
        """Step the wrapped env and record ``action`` on the result."""
        time_step = self._env.step(action)
        return self._augment_time_step(time_step, action)

    def _augment_time_step(self, time_step, action=None):
        """Build an ``ExtendedTimeStep`` from ``time_step``.

        Args:
            time_step: the time step returned by the wrapped env.
            action: the action taken; when ``None`` a zero action matching
                the action spec is used.
        """
        if action is None:
            action_spec = self.action_spec()
            action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
        # Only a *missing* discount defaults to 1.0. Using ``discount or 1.0``
        # would also turn a terminal discount of 0.0 into 1.0 and hide
        # terminations from downstream consumers.
        discount = 1.0 if time_step.discount is None else time_step.discount
        return ExtendedTimeStep(observation=time_step.observation,
                                step_type=time_step.step_type,
                                action=action,
                                # Falsy rewards (None or 0) all map to 0.0.
                                reward=time_step.reward or 0.0,
                                discount=discount)

    def observation_spec(self):
        """Return the wrapped env's observation spec unchanged."""
        return self._env.observation_spec()

    def action_spec(self):
        """Return the wrapped env's action spec unchanged."""
        return self._env.action_spec()

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)


class DMC:
  """Adapter returning ``(time_step, obs_dict)`` pairs from a dm_env-style env.

  ``obs_dict`` holds the keys ``reward``, ``is_first``, ``is_last``,
  ``is_terminal``, ``observation``, ``action`` and ``discount``.
  ``obs_space`` and ``act_space`` are served through ``__getattr__``.
  """

  def __init__(self, env):
    self._env = env
    self._ignored_keys = []

  def step(self, action):
    """Step the wrapped env.

    Returns:
      A ``(time_step, obs)`` tuple; ``obs['is_terminal']`` is true when the
      step's discount is 0.
    """
    time_step = self._env.step(action)
    assert time_step.discount in (0, 1)
    obs = {
        'reward': time_step.reward,
        'is_first': False,
        'is_last': time_step.last(),
        'is_terminal': time_step.discount == 0,
        'observation': time_step.observation,
        'action' : action,
        'discount': time_step.discount
    }
    return time_step, obs

  def reset(self):
    """Reset the wrapped env; the reported action is all zeros."""
    time_step = self._env.reset()
    obs = {
        'reward': 0.0,
        'is_first': True,
        'is_last': False,
        'is_terminal': False,
        'observation': time_step.observation,
        # Only the shape/dtype of the sample is used.
        'action' : np.zeros_like(self.act_space['action'].sample()),
        'discount': time_step.discount
    }
    return time_step, obs

  def __getattr__(self, name):
    """Serve ``obs_space``/``act_space``; delegate the rest to the wrapped env."""
    if name == '_env':
        # ``_env`` is missing from the instance dict (e.g. an instance created
        # without ``__init__``); delegating would recurse forever.
        raise AttributeError(name)
    if name == 'obs_space':
        obs_spaces = {
            # NOTE: this entry is the env's observation spec, not a gym space.
            'observation': self._env.observation_spec(),
            'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
            'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
            'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
        }
        return obs_spaces
    if name == 'act_space':
        spec = self._env.action_spec()
        # Expand (possibly scalar) bounds to the action shape. The previous
        # code multiplied the bounds by ``spec.shape[0]``, which scaled the
        # limits by the action dimensionality instead of replicating them.
        low = np.broadcast_to(spec.minimum, spec.shape).astype(np.float32)
        high = np.broadcast_to(spec.maximum, spec.shape).astype(np.float32)
        action = gym.spaces.Box(low, high, dtype=np.float32)
        act_space = {'action': action}
        return act_space
    return getattr(self._env, name)



class OneHotAction(gym.Wrapper):
    """Exposes a ``Discrete`` action space as one-hot float vectors.

    NOTE: an identical ``OneHotAction`` is defined earlier in this module;
    this later definition is the one bound to the name after import.
    """

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        super().__init__(env)
        self._random = np.random.RandomState()
        num_actions = self.env.action_space.n
        one_hot_space = gym.spaces.Box(low=0, high=1, shape=(num_actions,),
                                       dtype=np.float32)
        # Tag the Box so consumers can tell it encodes a discrete choice.
        one_hot_space.discrete = True
        self.action_space = one_hot_space

    def step(self, action):
        """Validate the one-hot ``action`` and forward its index to the env.

        Raises:
            ValueError: if ``action`` is not (numerically close to) one-hot.
        """
        hot_index = np.argmax(action).astype(int)
        expected = np.zeros_like(action)
        expected[hot_index] = 1
        if np.allclose(expected, action):
            return self.env.step(hot_index)
        raise ValueError(f"Invalid one-hot action:\n{action}")

    def reset(self):
        """Reset the wrapped env."""
        return self.env.reset()

    def _sample_action(self):
        """Return a uniformly random one-hot float32 vector."""
        num_actions = self.env.action_space.n
        one_hot = np.zeros(num_actions, dtype=np.float32)
        one_hot[self._random.randint(0, num_actions)] = 1.0
        return one_hot


class KitchenWrapper:
    """Wraps a kitchen task from ``envs.kitchen_extra`` in the dict-based interface.

    ``step`` and ``reset`` return a ``(dm_env.TimeStep, obs_dict)`` pair.

    Args:
        name: task key; one of ``microwave``, ``kettle``, ``burner``,
            ``light``, ``hinge``, ``slide``, ``top_burner``.
        seed: stored on the wrapper (not otherwise used in this class).
        action_repeat: number of times each action is applied per ``step``.
        size: image resolution as a ``(height, width)``-style tuple; it is
            passed to the env's ``render`` and used for the observation space.
    """

    def __init__(
        self,
        name,
        seed=0,
        action_repeat=1,
        size=(64, 64),
    ):
        import envs.kitchen_extra as kitchen_extra
        self._env  = {
            'microwave' : kitchen_extra.KitchenMicrowaveV0,
            'kettle' : kitchen_extra.KitchenKettleV0,
            'burner' : kitchen_extra.KitchenBurnerV0,
            'light'  : kitchen_extra.KitchenLightV0,
            'hinge'  : kitchen_extra.KitchenHingeV0,
            'slide'  : kitchen_extra.KitchenSlideV0,
            'top_burner' : kitchen_extra.KitchenTopBurnerV0,
        }[name]()

        self._size = size
        self._action_repeat = action_repeat
        self._seed = seed
        self._eval = False

    def eval_mode(self,):
        """Switch to evaluation: sets ``dense = False`` on the env and caps rewards."""
        self._env.dense = False
        self._eval = True

    @property
    def obs_space(self):
        """Dict of gym spaces describing the entries of the observation dict."""
        spaces = {
            "observation": gym.spaces.Box(0, 255, (3,) + self._size, dtype=np.uint8),
            "is_first": gym.spaces.Box(0, 1, (), dtype=bool),
            "is_last": gym.spaces.Box(0, 1, (), dtype=bool),
            "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
            "state": self._env.observation_space,
        }
        return spaces

    @property
    def act_space(self):
        """Dict holding the wrapped env's action space under ``'action'``."""
        action = self._env.action_space
        return {"action": action}

    def step(self, action):
        """Apply ``action`` ``action_repeat`` times and sum the rewards.

        In eval mode the reward is capped at 1 and any positive reward marks
        the step as the last one of the episode.
        """
        # assert np.isfinite(action["action"]).all(), action["action"]
        reward = 0.0
        for _ in range(self._action_repeat):
            # The env receives a copy so it cannot modify the caller's action.
            state, rew, done, info = self._env.step(action.copy())
            reward += rew
        obs = {
            "reward": reward,
            "is_first": False,
            "is_last": False,  # will be handled by timelimit wrapper
            "is_terminal": False,  # will be handled by per_episode function
            "observation": info['images'].transpose(2, 0, 1).copy(),
            "state": state.astype(np.float32),
            'action' : action,
            'discount' : 1
        }
        if self._eval:
            obs['reward'] = min(obs['reward'], 1)
            if obs['reward'] > 0:
                obs['is_last'] = True
        return dm_env.TimeStep(
                step_type=dm_env.StepType.MID if not obs['is_last'] else dm_env.StepType.LAST,
                reward=obs['reward'],
                discount=1,
                observation=obs['observation']), obs

    def reset(self,):
        """Reset the env and return the first ``(time_step, obs)`` pair."""
        state = self._env.reset()
        obs = {
            "reward": 0.0,
            "is_first": True,
            "is_last": False,
            "is_terminal": False,
            "observation": self.get_visual_obs(self._size),
            "state": state.astype(np.float32),
            # Only the shape/dtype of the sample is used.
            'action' : np.zeros_like(self.act_space['action'].sample()),
            'discount' : 1
        }
        return dm_env.TimeStep(
                step_type=dm_env.StepType.FIRST,
                reward=None,
                discount=None,
                observation=obs['observation']), obs

    def __getattr__(self, name):
        """Delegate attributes not found on the wrapper to the wrapped env.

        ``obs_space`` and ``act_space`` are properties, so they only reach
        this method when the property itself raised ``AttributeError``;
        calling the property again from here would recurse without end, so
        they are simply delegated like any other name.
        """
        if name == '_env':
            # ``_env`` is not set yet; delegating would recurse forever.
            raise AttributeError(name)
        return getattr(self._env, name)

    def get_visual_obs(self, resolution):
        """Render at ``resolution`` and return the image with its last axis moved first."""
        img = self._env.render(resolution=resolution,).transpose(2, 0, 1).copy()
        return img


class ViClipWrapper:
    """Adds a ``clip_video`` embedding to the observations of a dict-based env.

    Recent frames are kept in a buffer of ``n_frames`` entries. While
    ``accumulate`` is True, full clips are queued in ``accumulate_buffer``
    and a zero vector is reported; ``process_accumulate`` embeds the queued
    clips later.

    Args:
        env: wrapped env whose ``step``/``reset`` return ``(ts, obs)``.
        hd_rendering: when True, frames are re-rendered at 224x224 instead
            of using ``obs['observation']``.
        device: passed to ``instantiate`` when the model is not loaded yet.
    """

    def __init__(self, env, hd_rendering=False, device='cuda'):
        self._env = env
        try:
            from tools.genrl_utils import viclip_global_instance
        except ImportError:
            # The shared instance is not importable; build a holder instead.
            from tools.genrl_utils import ViCLIPGlobalInstance
            viclip_global_instance = ViCLIPGlobalInstance()

        if not viclip_global_instance._instantiated:
            viclip_global_instance.instantiate(device)
        self.viclip_model = viclip_global_instance.viclip
        self.n_frames = self.viclip_model.n_frames
        self.viclip_emb_dim = viclip_global_instance.viclip_emb_dim
        self.buffer = deque(maxlen=self.n_frames)
        # NOTE: these are hardcoded for now, as they are the best settings
        self.accumulate = True
        self.accumulate_buffer = []
        self.anticipate_conv1 = False
        self.hd_rendering = hd_rendering

    def hd_render(self, obs):
        """Return the frame to buffer for ``obs``.

        Without HD rendering this is ``obs['observation']``. Otherwise a
        224x224 frame is rendered, via ``get_visual_obs`` for the
        ``mw``/``kitchen``/``mujoco`` domains and via ``physics.render``
        for all others.
        """
        if not self.hd_rendering:
            return obs['observation']
        if self._env._domain_name in ['mw', 'kitchen', 'mujoco']:
            return self.get_visual_obs((224,224,))
        else:
            render_kwargs = {**getattr(self, '_render_kwargs', {})}
            render_kwargs.update({'width' : 224, 'height' : 224})
            return self._env.physics.render(**render_kwargs).transpose(2,0,1)

    def preprocess(self, x):
        """Identity hook applied to clips before they are queued."""
        return x

    def process_accumulate(self, process_at_once=4): # NOTE: this could be varied for increasing FPS, depending on the size of the GPU
        """Embed every queued clip in chunks of ``process_at_once``.

        Returns:
            ``(embeddings, 'clip_video')`` where ``embeddings`` is a list
            with one numpy array per queued clip. The queue is emptied and
            accumulation is switched back on afterwards.
        """
        # Accumulation is turned off so ``clip_process`` really embeds.
        self.accumulate = False
        x = np.stack(self.accumulate_buffer, axis=0)
        # Splitting in chunks
        chunks = []
        chunk_idxs = list(range(0, x.shape[0] + 1, process_at_once))
        if chunk_idxs[-1] != x.shape[0]:
            # Cover the remainder that does not fill a whole chunk.
            chunk_idxs.append(x.shape[0])
        start = 0
        for end in chunk_idxs[1:]:
            embeds = self.clip_process(x[start:end], bypass=True)
            chunks.append(embeds.cpu())
            start = end
        embeds = torch.cat(chunks, dim=0)
        assert embeds.shape[0] == len(self.accumulate_buffer)
        self.accumulate = True
        self.accumulate_buffer = []
        return [*embeds.cpu().numpy()], 'clip_video'

    def process_episode(self, obs, process_at_once=8):
        """Embed every sliding window of ``n_frames`` frames of ``obs``.

        The first ``n_frames - 1`` rows of the result are zeros, so the
        output has one row per frame of ``obs``.

        NOTE(review): ``accumulate`` is set to False here and not restored
        by this method — confirm that this is intended.
        """
        self.accumulate = False
        sequences = []
        for j in range(obs.shape[0] - self.n_frames + 1):
            sequences.append(obs[j:j+self.n_frames].copy())
        sequences = np.stack(sequences, axis=0)

        idx_start = 0
        clip_vid = []
        for idx_end in range(process_at_once, sequences.shape[0] + process_at_once, process_at_once):
            x = sequences[idx_start:idx_end]
            with torch.no_grad(): # , torch.cuda.amp.autocast():
                x = self.clip_process(x, bypass=True)
            clip_vid.append(x)
            idx_start = idx_end
        if len(clip_vid) == 1: # process all at once
            embeds = clip_vid[0]
        else:
            embeds = torch.cat(clip_vid, dim=0)
        # Frames that do not yet have a full window get a zero embedding.
        pad = torch.zeros( (self.n_frames - 1, *embeds.shape[1:]), device=embeds.device, dtype=embeds.dtype)
        embeds = torch.cat([pad, embeds], dim=0)
        assert embeds.shape[0] == obs.shape[0], f"Shapes are different {embeds.shape[0]} {obs.shape[0]}"
        return embeds.cpu().numpy()

    def get_sequence(self,):
        """Return the buffered frames stacked, with a leading axis of size 1."""
        return np.expand_dims(np.stack(self.buffer, axis=0), axis=0)

    def clip_process(self, x, bypass=False):
        """Embed the clips in ``x``, or queue / skip them.

        Args:
            x: array unpacked as ``(B, n_frames, C, H, W)``.
            bypass: process ``x`` even when the frame buffer is not full.

        Returns:
            The video embeddings, or a zero vector of size
            ``viclip_emb_dim`` when the buffer is not full yet or when the
            clip was only queued because ``accumulate`` is True.
        """
        if len(self.buffer) == self.n_frames or bypass:
            if self.accumulate:
                self.accumulate_buffer.append(self.preprocess(x)[0])
                return torch.zeros(self.viclip_emb_dim)
            with torch.no_grad():
                B, n_frames, C, H, W = x.shape
                obs = torch.from_numpy(x.copy().reshape(B * n_frames, C, H, W)).to(self.viclip_model.device)
                # Pixel values are divided by 255 before the model transform.
                processed_obs = self.viclip_model.preprocess_transf(obs / 255)
                reshaped_obs = processed_obs.reshape(B, n_frames, 3,processed_obs.shape[-2],processed_obs.shape[-1])
                video_embed = self.viclip_model.get_vid_features(reshaped_obs)
            return video_embed.detach()
        else:
            return torch.zeros(self.viclip_emb_dim)

    def step(self, action):
        """Step the env, buffer the new frame and attach ``obs['clip_video']``."""
        ts, obs = self._env.step(action)
        self.buffer.append(self.hd_render(obs))
        obs['clip_video'] = self.clip_process(self.get_sequence()).cpu().numpy()
        return ts, obs

    def reset(self,):
        """Reset the env and start a fresh frame buffer."""
        # Important to reset the buffer
        self.buffer = deque(maxlen=self.n_frames)

        ts, obs = self._env.reset()
        self.buffer.append(self.hd_render(obs))
        obs['clip_video'] = self.clip_process(self.get_sequence()).cpu().numpy()
        return ts, obs

    def __getattr__(self, name):
        """Serve ``obs_space`` with a ``clip_video`` entry; delegate the rest."""
        if name == '_env':
            # ``_env`` is not set yet; delegating would recurse forever.
            raise AttributeError(name)
        if name == 'obs_space':
            space = self._env.obs_space
            space['clip_video'] = gym.spaces.Box(-np.inf, np.inf, (self.viclip_emb_dim,), dtype=np.float32)
            return space
        return getattr(self._env, name)


class TimeLimit:
  """Ends an episode once ``duration`` steps have been taken since reset.

  A falsy ``duration`` (``0`` or ``None``) disables the limit.
  """

  def __init__(self, env, duration):
    self._env = env
    self._duration = duration
    # ``None`` means "not reset yet" (also set again after the limit hits).
    self._step = None

  def __getattr__(self, name):
    # Dunder lookups are never delegated to the wrapped env.
    if not name.startswith('__'):
      return getattr(self._env, name)
    raise AttributeError(name)

  def step(self, action):
    """Step the wrapped env; mark the step as last when the limit is hit."""
    assert self._step is not None, 'Must reset environment.'
    ts, obs = self._env.step(action)
    self._step += 1
    limit_reached = self._duration and self._step >= self._duration
    if not limit_reached:
      return ts, obs
    # Rebuild the time step with a LAST step type and flag the obs dict.
    ts = dm_env.TimeStep(dm_env.StepType.LAST, ts.reward, ts.discount, ts.observation)
    obs['is_last'] = True
    self._step = None
    return ts, obs

  def reset(self):
    """Restart the step counter and reset the wrapped env."""
    self._step = 0
    return self._env.reset()

  def reset_with_task_id(self, task_id):
    """Restart the step counter and reset the wrapped env with ``task_id``."""
    self._step = 0
    return self._env.reset_with_task_id(task_id)
  
class ClipActionWrapper:
  """Clips every action to ``[low, high]`` before passing it to the env.

  Args:
    env: the wrapped environment.
    low: lower clipping bound (scalar or array broadcastable to the action).
    high: upper clipping bound (scalar or array broadcastable to the action).
  """

  def __init__(self, env, low=-1.0, high=1.0):
    self._env = env
    self._low = low
    self._high = high

  def __getattr__(self, name):
    # Dunder lookups are never delegated to the wrapped env.
    if name.startswith('__'):
      raise AttributeError(name)
    return getattr(self._env, name)

  def step(self, action):
    """Clip ``action`` with ``np.clip`` and step the wrapped env."""
    clipped_action = np.clip(action, self._low, self._high)
    return self._env.step(clipped_action)

  def reset(self):
    """Reset the wrapped env."""
    # No ``self._step`` bookkeeping here: this wrapper does not count steps,
    # and setting the attribute would shadow the ``_step`` counter of a
    # wrapped env (e.g. a time limit) that ``__getattr__`` delegates to.
    return self._env.reset()

  def reset_with_task_id(self, task_id):
    """Reset the wrapped env with ``task_id``."""
    return self._env.reset_with_task_id(task_id)

class NormalizeAction:
  """Lets callers act in ``[-1, 1]`` and rescales to the env's finite bounds.

  Dimensions whose original bounds are not both finite are passed through
  unchanged.
  """

  def __init__(self, env, key='action'):
    self._env = env
    self._key = key
    original = env.act_space[key]
    # True where both bounds are finite, i.e. where rescaling is possible.
    self._mask = np.isfinite(original.low) & np.isfinite(original.high)
    self._low = np.where(self._mask, original.low, -1)
    self._high = np.where(self._mask, original.high, 1)

  def __getattr__(self, name):
    # Dunder lookups are never delegated to the wrapped env.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      value = getattr(self._env, name)
    except AttributeError:
      # Missing attributes are reported as ValueError by this wrapper.
      raise ValueError(name)
    return value

  @property
  def act_space(self):
    """The wrapped action spaces with ``key`` replaced by the normalized Box."""
    ones = np.ones_like(self._low)
    low = np.where(self._mask, -ones, self._low)
    high = np.where(self._mask, ones, self._high)
    normalized_space = gym.spaces.Box(low, high, dtype=np.float32)
    return {**self._env.act_space, self._key: normalized_space}

  def step(self, action):
    """Map ``action[key]`` from ``[-1, 1]`` to the original range and step."""
    normalized = action[self._key]
    span = self._high - self._low
    rescaled = (normalized + 1) / 2 * span + self._low
    # Dimensions without finite bounds keep the caller's value.
    rescaled = np.where(self._mask, rescaled, normalized)
    return self._env.step({**action, self._key: rescaled})

def _make_jaco(obs_type, domain, task, action_repeat, seed, img_size,):
    """Build a Jaco env wrapped for float32 actions, action repeat and flat observations.

    ``domain`` is accepted for signature parity with ``_make_dmc`` and is not
    used here. The returned env has ``_size`` set to ``(img_size, img_size)``.
    """
    import envs.custom_dmc_tasks as cdmc

    base_env = cdmc.make_jaco(task, obs_type, seed, img_size,)
    typed_env = ActionDTypeWrapper(base_env, np.float32)
    repeated_env = ActionRepeatWrapper(typed_env, action_repeat)
    jaco_env = FlattenJacoObservationWrapper(repeated_env)
    jaco_env._size = (img_size, img_size)
    return jaco_env


def _make_dmc(obs_type, domain, task, action_repeat, seed, img_size,):
    """Build a DeepMind Control env with float32 actions and action repeat.

    Tasks listed in ``suite.ALL_TASKS`` are loaded from ``dm_control.suite``;
    anything else is created through ``envs.custom_dmc_tasks``. For
    ``obs_type == 'pixels'`` the env is additionally wrapped so that it
    returns rendered pixels only.
    """
    from dm_control import manipulation, suite
    import envs.custom_dmc_tasks as cdmc

    # Both loaders are called with the same arguments.
    loader = suite.load if (domain, task) in suite.ALL_TASKS else cdmc.make
    env = loader(domain,
                 task,
                 task_kwargs=dict(random=seed),
                 environment_kwargs=dict(flat_observation=True),
                 visualize_reward=False)
    env = ActionRepeatWrapper(ActionDTypeWrapper(env, np.float32),
                              action_repeat)
    if obs_type != 'pixels':
        return env

    from dm_control.suite.wrappers import pixels
    # zoom in camera for quadruped
    camera_id = {'locom_rodent': 1, 'quadruped': 2}.get(domain, 0)
    env = pixels.Wrapper(env,
                         pixels_only=True,
                         render_kwargs=dict(height=img_size,
                                            width=img_size,
                                            camera_id=camera_id))
    env._size = (img_size, img_size)
    env._camera = camera_id
    return env


def make(name, obs_type, action_repeat, seed, img_size=64, viclip_encode=False, clip_hd_rendering=False, device='cuda'):
    """Create a fully wrapped environment from a ``'<domain>_<task>'`` name.

    Args:
        name: env name; the part before the first underscore is the domain.
        obs_type: ``'states'`` or ``'pixels'``.
        action_repeat: number of times each action is repeated.
        seed: seed forwarded to the env constructors.
        img_size: side length of the square rendered images.
        viclip_encode: wrap the env in ``ViClipWrapper`` when True.
        clip_hd_rendering: forwarded to ``ViClipWrapper`` as ``hd_rendering``.
        device: forwarded to ``ViClipWrapper``.
    """
    assert obs_type in ['states', 'pixels']
    domain, task = name.split('_', 1)

    if domain != 'kitchen':
        # Select EGL for rendering before the env is created.
        os.environ['PYOPENGL_PLATFORM'] = 'egl'
        os.environ['MUJOCO_GL'] = 'egl'

        # Expand the short domain aliases.
        aliases = {'cup': 'ball_in_cup', 'point': 'point_mass'}
        domain = aliases.get(domain, domain)

        builder = _make_jaco if domain == 'jaco' else _make_dmc
        env = builder(obs_type, domain, task, action_repeat, seed, img_size,)

        if obs_type == 'pixels':
            env = FramesWrapper(env,)
        else:
            env = ObservationDTypeWrapper(env, np.float32)

        from dm_control.suite.wrappers import action_scale
        env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
        env = DMC(ExtendedTimeStepWrapper(env))
    else:
        kitchen_env = KitchenWrapper(task,
                                     seed=seed,
                                     action_repeat=action_repeat,
                                     size=(img_size, img_size))
        env = TimeLimit(kitchen_env, 280 // action_repeat)
    env._domain_name = domain

    # Box action spaces get their actions clipped.
    if isinstance(env.act_space['action'], gym.spaces.Box):
        env = ClipActionWrapper(env,)

    if not viclip_encode:
        return env
    return ViClipWrapper(env, hd_rendering=clip_hd_rendering, device=device)