from abc import ABC, abstractmethod

import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal


class MammalObjectBroker:
    """Lazy holder for a Mammal model and its modular tokenizer.

    Both objects are loaded from ``model_path`` on first access (or eagerly when
    ``force_preload=True``) and cached on the instance afterwards.
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
        *,
        force_preload=False,
    ) -> None:
        """
        Args:
            model_path: location passed to ``Mammal.from_pretrained`` and
                ``ModularTokenizerOp.from_pretrained``.
            name: name for this broker; defaults to ``model_path`` when None.
            task_list: task names associated with this model (presumably the
                tasks it can serve). The list is copied, so the broker never
                shares it with the caller. Defaults to an empty list.
            force_preload: when True, load the model and tokenizer right away
                instead of on first access.
        """
        self.model_path = model_path
        if name is None:
            name = model_path
        self.name = name

        # Copy so that mutating ``self.tasks`` cannot leak into a list the
        # caller may reuse for other brokers (and vice versa).
        self.tasks: list[str] = list(task_list) if task_list is not None else []
        self._model: Mammal | None = None
        self._tokenizer_op: ModularTokenizerOp | None = None
        if force_preload:
            # ``force_preload`` the argument is a bool; ``self.force_preload``
            # is the loader method below.
            self.force_preload()

    @property
    def model(self) -> Mammal:
        """The Mammal model, loaded from ``model_path`` on first access.

        ``.eval()`` is called once right after loading; later accesses return
        the cached instance.
        """
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self) -> ModularTokenizerOp:
        """The tokenizer op, loaded from ``model_path`` on first access and cached."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op

    def force_preload(self):
        """Pre-load the model and then the tokenizer (in this order)."""
        _ = self.model
        _ = self.tokenizer_op


class MammalTask(ABC):
    """Abstract base for a Mammal task exposed in the demo.

    Subclasses provide prompt formatting (``crate_sample_dict``), output
    decoding (``decode_output``) and, optionally, ``run_model`` and a Gradio
    UI (``create_demo``).
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        """
        Args:
            name: task name (used as the key by ``TaskRegistry.register_task``).
            model_dict: mapping to model brokers; presumably keyed by model
                name, as ``ModelRegistry`` does — confirm against callers.
        """
        self.name = name
        self.description = None
        self._demo = None  # cache filled by ``demo()``
        self.model_dict = model_dict

    # NOTE: the method name is spelled "crate" (sic); subclasses override it
    # under this name, so it cannot be renamed without breaking them.
    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Format the task inputs to match the pre-training prompt syntax.

        Args:
            sample_inputs (dict): raw inputs for this task (schema defined by
                each subclass).
            model_holder (MammalObjectBroker): broker giving access to the model
                and tokenizer to use.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    # Deliberately not abstract: subclasses may leave it unimplemented.
    # @abstractmethod
    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError()

    # NOTE(review): the annotation is ``gr.component`` while the docstring says
    # ``gr.Component``; confirm which one the installed gradio version provides.
    def create_demo(self, model_name_widget: gr.component) -> gr.Group:
        """Create a Gradio demo group for this task.

        Args:
            model_name_widget (gr.Component): widget holding the model name to
                use. This is needed to create Gradio actions that take the
                current model name as an input.

        Returns:
            gr.Group: the demo UI for this task.

        Raises:
            NotImplementedError: always, in this base class; subclasses that
                provide a UI must override it.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: gr.component = None):
        """Return this task's demo group, building it on the first call.

        The result of the first ``create_demo`` call is cached.
        ``model_name_widgit`` (spelling kept for keyword callers) is only used
        on that first call; later calls return the cached group unchanged.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Decode the model output contained in ``batch_dict`` into a list."""
        raise NotImplementedError()

    # classification helpers
    @staticmethod
    def positive_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """Token id for positive binding.

        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer used to look up the
                id of the ``"<1>"`` token.

        Returns:
            int: id of the positive binding token.
        """
        return tokenizer_op.get_token_id("<1>")

    @staticmethod
    def negative_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """Token id for negative binding.

        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer used to look up the
                id of the ``"<0>"`` token.

        Returns:
            int: id of the negative binding token.
        """
        return tokenizer_op.get_token_id("<0>")

    @staticmethod
    def get_label_from_token(tokenizer_op: ModularTokenizerOp, token_id):
        """Map a binary-classification token id to a human-readable label.

        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer used to resolve the
                ``"<0>"`` / ``"<1>"`` token ids.
            token_id: token id to translate. It must hash and compare equal to
                the ids returned by ``get_token_id``; presumably a plain int.

        Returns:
            ``"negative"`` for the ``"<0>"`` id, ``"positive"`` for the
            ``"<1>"`` id, otherwise ``token_id`` unchanged.
        """

        label_mapping = {
            MammalTask.negative_token_id(tokenizer_op): "negative",
            MammalTask.positive_token_id(tokenizer_op): "positive",
        }
        return label_mapping.get(token_id, token_id)


class TaskRegistry(dict[str, MammalTask]):
    """Mapping from task name to ``MammalTask``, filled through ``register_task``."""

    def register_task(self, task: MammalTask):
        """Store ``task`` under its ``name`` and hand that name back.

        An existing entry with the same name is replaced.
        """
        key = task.name
        self[key] = task
        return key


class ModelRegistry(dict[str, MammalObjectBroker]):
    """Mapping from model name to ``MammalObjectBroker``, filled through ``register_model``."""

    def register_model(
        self, model_path, task_list=None, name=None, *, force_preload=False
    ):
        """Wrap ``model_path`` in a broker, store it, and return its name.

        Args:
            model_path (str): location of the pretrained model/tokenizer.
            task_list (optional list[str]): task names associated with the model.
            name (optional str): explicit name for the model; the broker falls
                back to ``model_path`` when omitted.
            force_preload (bool): load the model and tokenizer immediately.

        Returns:
            str: the model name, also used as the registry key. An existing
            entry with the same name is replaced.
        """
        broker = MammalObjectBroker(
            model_path=model_path,
            name=name,
            task_list=task_list,
            force_preload=force_preload,
        )
        key = broker.name
        self[key] = broker
        return key