"""
aggregate.py - module for 'reducing' multiple 'summary chunks' into one

an overly complicated class for legacy compatibility reasons, for usage of the
2024 map-reduce models see hf.co/pszemraj/bart-large-summary-map-reduce#usage
"""

import logging
import pprint as pp
import time
from typing import List, Optional

import torch
from transformers import GenerationConfig, Pipeline, pipeline

# Setting up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    """
    BatchAggregator is a class for aggregating text from multiple sources.

    Usage:
        from aggregate import BatchAggregator
        aggregator = BatchAggregator()
        agg = aggregator.infer_aggregate(["This is a test", "This is another test"])
        print(agg)
    """

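    # Default generation settings applied to the loaded model (beam search decoding, no sampling).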
    GENERIC_CONFIG = GenerationConfig(
        max_new_tokens=512,
        num_beams=4,
        early_stopping=True,
        do_sample=False,
        truncation=True,
    )

    def __init__(
        self,
        model_name: str = "pszemraj/bart-large-summary-map-reduce",
        force_cpu: bool = False,
        **kwargs,
    ):
        """
        __init__ initializes the BatchAggregator class.

        :param str model_name: model name to use, default: "pszemraj/bart-large-summary-map-reduce"
        :param bool force_cpu: force the model to run on CPU, default: False
        """
        self.device = None
        self.is_compiled = False
        self.model_name = None
        self.aggregator = None
        self.force_cpu = force_cpu
        self.logger = logging.getLogger(__name__)
        self.init_model(model_name)

    def init_model(self, model_name: str) -> None:
        """
        Initialize the model.

        :param model_name: The name of the model to use.
        """
        # Free up memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.info(f"Setting model to {model_name}")
        self.model_name = model_name
        self.aggregator = self._create_pipeline(model_name)
        self._configure_model()

    def _create_pipeline(
        self, model_name: str = "pszemraj/bart-large-summary-map-reduce"
    ) -> Pipeline:
        """
        _create_pipeline creates a pipeline for the model.

        :param str model_name: model name to use
        :return pipeline: the pipeline for the model

        :raises Exception: if the pipeline cannot be created
        """
        device_map = (
            "auto" if torch.cuda.is_available() and not self.force_cpu else "cpu"
        )
        try:
            self.logger.info(
                f"Creating pipeline with model {model_name} on device {device_map}"
            )
            return pipeline(
                "text2text-generation",
                model=model_name,
                device_map=device_map,
                torch_dtype=torch.float32,
            )
        except Exception as e:
            self.logger.error(f"Failed to create pipeline: {e}")
            raise

    def _configure_model(self):
        """
        Configure the model for generation.
        """
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
            self.is_compiled = True
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")

        self._set_default_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def _set_default_generation_config(self):
        """
        Set the default generation configuration for the model.
        """
        self.aggregator.model.generation_config.update(
            **self.GENERIC_CONFIG.to_diff_dict()
        )

    def update_generation_config(self, **kwargs):
        """
        Update the generation configuration with the specified parameters.

        Args:
            **kwargs: The parameters to update in the generation configuration.
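
        Example (illustrative; accepts standard transformers generation parameters):
            aggregator.update_generation_config(max_new_tokens=1024, num_beams=2)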
        """
        self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
        self.aggregator.model.generation_config.update(**kwargs)

    def get_generation_config(self) -> dict:
        """
        Get the current generation configuration.

        Returns:
            dict: The current generation configuration.
        """
        return self.aggregator.model.generation_config.to_dict()

    def update_loglevel(self, level: str = "INFO"):
        """
        Update the log level.

        Args:
            level (str): The log level to set. Defaults to "INFO".
        """
        self.logger.setLevel(level)

    def infer_aggregate(
        self,
        text_list: List[str],
        instruction: Optional[str] = None,  # kept for backward compatibility but not used
        **kwargs,
    ) -> str:
        """
        infer_aggregate - infers a consolidated summary from a list of texts.

        Args:
            text_list (list): The texts to summarize.
            instruction (str): Not used by this model, kept for compatibility.
            **kwargs: Additional parameters to update in the generation configuration.

        Returns:
            The generated summary.
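
        Example (illustrative):
            summary = aggregator.infer_aggregate(
                ["summary chunk one ...", "summary chunk two ..."], num_beams=2
            )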
        """
        joined_text = "\n\n".join(text_list)
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"inference on {len(text_list)} texts ...")
        result = self.aggregator(
            joined_text,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(joined_text)}. Output tokens:\t{self.count_tokens(result)}"
        )
        self.logger.debug(f"Generated text:\n{result}")

        return result

    def count_tokens(self, text: str) -> int:
        """count the number of tokens in a text"""
        return (
            len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
            if text
            else 0
        )
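

if __name__ == "__main__":
    # Minimal usage sketch mirroring the class docstring; downloads the default
    # checkpoint from the Hugging Face Hub on first run. The input texts are illustrative.
    aggregator = BatchAggregator()
    agg = aggregator.infer_aggregate(
        [
            "The report finds that energy demand rose sharply in 2023.",
            "The report also notes that renewable capacity grew by roughly 10%.",
        ]
    )
    print(agg)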