import json
from typing import List
import logging
import os
import sys

sys.path.append(os.getcwd())
from ..base import Action, NON_FILE_TYPES

# from cllm.services.tog import TaskSolver, TaskDecomposer, config
# from cllm.services.nlp.llms import ChatOpenAI, MessageMemory
from cllm.services.tog.api import tog, task_decomposer
from collections import OrderedDict
import copy


# Module-level logger; Planner reports decomposition, validation and solution
# progress through it.
logger = logging.getLogger(__name__)


class Planner:
    """Decompose a user request into sub-tasks and solve them.

    With ``backend="remote"`` the decomposer and solver are the
    ``task_decomposer`` and ``tog`` callables imported from
    ``cllm.services.tog.api``. The ``"local"`` backend is currently disabled
    (its implementation is commented out) and leaves the planner without a
    decomposer or solver.
    """

    def __init__(
        self, streaming=False, backend="remote", device="cuda:0", **llm_kwargs
    ):
        """Select the planning backend.

        Args:
            streaming: If True, ``plan`` returns a generator (see
                ``solve_streaming``), and ``streaming=True`` is forwarded to
                the decomposer and solver there.
            backend: ``"remote"`` or ``"local"``.
            device: Only referenced by the disabled local backend.
            **llm_kwargs: Only referenced by the disabled local backend.

        Raises:
            ValueError: If ``backend`` is neither ``"remote"`` nor ``"local"``.
        """
        self.streaming = streaming
        if backend == "local":
            # The local backend is disabled: ``self.decomposer`` and
            # ``self.tog`` are not set, so any planning call will fail with
            # AttributeError. The original wiring is kept for reference.
            logger.warning(
                "The 'local' planner backend is disabled; "
                "decomposer and tog are not configured."
            )
            # self.cfg = config
            # self.device = device
            # self.mem = MessageMemory(**self.cfg.memory)
            # self.llm = ChatOpenAI(temperature=0.2, **llm_kwargs)
            # self.tog = TaskSolver(self.llm, self.cfg.task_solver_config, device).solve
            # self.decomposer = TaskDecomposer(device, self.cfg.task_decomposer_cfg).solve
        elif backend == "remote":
            self.decomposer = task_decomposer
            self.tog = tog
        else:
            raise ValueError("Backend should be chosen from [remote, local]")

    def _find_latest_resource(self, resources, type):
        """Return the most recently added resource key whose type matches.

        A stored type matches when it equals ``type`` exactly, or when its
        last dot-separated component equals ``type``. The second form is the
        same normalization that ``wrap_request`` and
        ``_check_task_decomposition`` apply to stored types.

        Args:
            resources: Ordered mapping of resource name -> type string.
            type: Type name to look for.

        Returns:
            The newest matching key, or None if no resource matches.
        """
        for key, val in reversed(list(resources.items())):
            if val == type or val.split(".")[-1] == type:
                return key
        return None

    def _check_task_decomposition(
        self, task_decomposition: str | list, available_resources: dict
    ):
        """Validate and repair resource references in a task decomposition.

        Every argument whose type is not in ``NON_FILE_TYPES`` must name a
        known resource of the same type. If it does not, it is rebound to the
        most recent resource of that type. Each sub-task's first return value
        is then registered as a resource for the sub-tasks after it. Neither
        input is mutated.

        Args:
            task_decomposition: List of sub-task dicts, or its JSON encoding.
            available_resources: Mapping of resource name -> type string.

        Returns:
            The repaired decomposition as an indented JSON string, or None if
            the decomposition is empty, is not valid JSON, or needs a resource
            type that is not available.
        """
        if not task_decomposition:
            logger.error("Empty task decomposition.")
            return None

        if isinstance(task_decomposition, str):
            try:
                subtasks = json.loads(task_decomposition)
            except json.JSONDecodeError:
                logger.error(
                    f"Task decomposition is not valid JSON: {task_decomposition}"
                )
                return None
        else:
            subtasks = copy.deepcopy(task_decomposition)

        if not subtasks:
            logger.error("Empty task decomposition.")
            return None

        resources = copy.deepcopy(available_resources)

        for subtask in subtasks:
            for arg in subtask.get("args", []):
                if arg["type"] in NON_FILE_TYPES:
                    continue

                # Stored types may carry a dotted prefix; compare on the last
                # component only, as shown to the decomposer by wrap_request.
                r_type = resources.get(arg["value"], "None").split(".")[-1]
                if arg["value"] in resources and arg["type"] == r_type:
                    continue

                new_value = self._find_latest_resource(resources, arg["type"])
                if new_value is None:
                    logger.error(
                        f"No available resource for {arg['value']} with type {arg['type']}"
                    )
                    return None
                arg["value"] = new_value

            # Only the first return value is registered as a new resource.
            returns = subtask.get("returns") or []
            if returns:
                resources[returns[0]["value"]] = returns[0]["type"]

        return json.dumps(subtasks, indent=2, ensure_ascii=False)

    def wrap_request(self, request, memory):
        """Prefix ``request`` with the resources currently held in ``memory``.

        Args:
            request: The user request text.
            memory: Mapping of resource name -> type string. Only the last
                dot-separated component of each type is shown.

        Returns:
            The request preceded by a ``Resource list: {...}`` line.
        """
        logger.info(memory)
        resource_list = {k: v.split(".")[-1] for k, v in memory.items()}
        request = f"Resource list: {resource_list}\n{request}"
        logger.info(f"Input: {request}")
        return request

    def solve_streaming(self, request: str, memory: dict | None = None):
        """Decompose and solve ``request``, yielding intermediate results.

        Args:
            request: The user request text.
            memory: Mapping of resource name -> type string; defaults to empty.

        Yields:
            First the validated decomposition (JSON string, or None if
            validation failed). Then either the solver's output, or None when
            there is nothing to solve.
        """
        if memory is None:
            memory = OrderedDict()
        request = self.wrap_request(request, memory)
        sub_tasks = self.decomposer(request, streaming=self.streaming)
        logger.info(f"Task decomposition: \n{sub_tasks}")
        sub_tasks = self._check_task_decomposition(sub_tasks, memory)
        yield sub_tasks
        if not sub_tasks:
            yield None
        else:
            yield self.tog(request, sub_tasks, streaming=self.streaming)

    def solve(self, request: str, memory: dict | None = None) -> tuple:
        """Decompose and solve ``request`` in a single blocking call.

        Args:
            request: The user request text.
            memory: Mapping of resource name -> type string; defaults to empty.

        Returns:
            ``(sub_tasks, solutions)``: the raw decomposer output and the
            solver output.
        """
        if memory is None:
            memory = OrderedDict()
        # Send the resource-annotated request, matching solve_streaming.
        request = self.wrap_request(request, memory)
        sub_tasks = self.decomposer(request)
        solutions = self.tog(request, sub_tasks)
        logger.info(f"solutions: {solutions}")
        return sub_tasks, solutions

    def plan(self, task, memory: dict | None = None) -> List:
        """Plan ``task`` using the mode selected at construction.

        Returns:
            The ``solve_streaming`` generator when ``self.streaming`` is True,
            otherwise the ``(sub_tasks, solutions)`` tuple from ``solve``.
        """
        if self.streaming:
            return self.solve_streaming(task, memory)
        return self.solve(task, memory)

    def _check_solutions(self, solution: List | str) -> bool:
        """Return True if every stage of ``solution`` has a non-empty candidate.

        Args:
            solution: One entry per sub-task, each a list of candidates, or
                its JSON encoding.

        Returns:
            False for a missing/empty solution, or if any stage has no
            candidate or an empty ``"solution"`` list; True otherwise.
        """
        if isinstance(solution, str):
            solution = json.loads(solution)
        if not solution:
            return False

        valid = True
        for idx, stage_candidate in enumerate(solution, start=1):
            if len(stage_candidate) == 0:
                logger.error(f"No solution is found in {idx}-th subtask.")
                valid = False
            elif (
                "solution" in stage_candidate[0]
                and len(stage_candidate[0]["solution"]) == 0
            ):
                logger.error(f"No solution is found in {idx}-th subtask.")
                valid = False
            else:
                logger.info(f"Solutions for {idx}-th subtask:\n{stage_candidate}")
        return valid

    def parse(self, solution: List | str) -> List[List[Action]]:
        """Convert solver output into stages of ``Action`` objects.

        Args:
            solution: Solver output (or its JSON encoding). For each stage,
                the first candidate's ``"solution"`` list of tool calls is
                used.

        Returns:
            A list of stages, each a list of ``Action``. The input is returned
            unchanged if it already holds ``Action`` objects. Returns None if
            validation fails.
        """
        if isinstance(solution, str):
            solution = json.loads(solution)

        if not self._check_solutions(solution):
            return None

        if isinstance(solution[0], Action):
            return solution

        stages = []
        for stage_candidate in solution:
            stage = stage_candidate[0]["solution"]
            stages.append(
                [
                    Action(
                        call["tool_name"],
                        inputs={arg["name"]: arg["value"] for arg in call["args"]},
                        outputs=[ret["value"] for ret in call["returns"]],
                    )
                    for call in stage
                ]
            )
        return stages

    def __call__(
        self, request: str, memory: dict | None = None
    ) -> List[List[Action]]:
        """Solve ``request`` and return the solution parsed into Action stages.

        Returns None if the solver output fails validation (see ``parse``).
        """
        # solve() returns (sub_tasks, solutions); only the solutions are parsed.
        _, solutions = self.solve(request, memory)
        return self.parse(solutions)