import abc
from typing import List

from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer

from .type_aliases import ENCODER_DEVICE_TYPE


class Encoder(abc.ABC):
    @abc.abstractmethod
    def encode(
        self,
        prediction: List[str],
        *,
        device: ENCODER_DEVICE_TYPE = "cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        """
        Abstract method to encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to print verbose information during encoding.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError("Method 'encode' must be implemented in subclass.")


class SBertEncoder(Encoder):
    def __init__(self, model_name: str):
        """
        Initialize SBertEncoder instance.

        Args:
            model_name (str): Name or path of the Sentence Transformer model.
        """
        self.model = SentenceTransformer(model_name, trust_remote_code=True)

    def encode(
        self,
        prediction: List[str],
        *,
        device: ENCODER_DEVICE_TYPE = "cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to print verbose information during encoding.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """

        # SBert output is always Batch x Dim
        if isinstance(device, list):
            # Use multiprocess encoding for list of devices
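            # Note: `verbose` is not forwarded in this branch.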
            pool = self.model.start_multi_process_pool(target_devices=device)
            embeddings = self.model.encode_multi_process(
                prediction, pool=pool, batch_size=batch_size
            )
            self.model.stop_multi_process_pool(pool)
        else:
            # Single device encoding
            embeddings = self.model.encode(
                prediction,
                device=device,
                batch_size=batch_size,
                show_progress_bar=verbose,
            )
        return embeddings


def get_encoder(model_name: str) -> Encoder:
    """
    Get the encoder instance based on the specified model name.

    Args:
        model_name (str): Name or path of the Sentence Transformer model to instantiate.
            Options include:
                paraphrase-distilroberta-base-v1,
                stsb-roberta-large,
                sentence-transformers/use-cmlm-multilingual
            Additionally, any model hosted on Hugging Face that is supported by
            the SentenceTransformer library can be used.

    Returns:
        Encoder: Instance of the selected encoder based on the model_name.

    Raises:
        EnvironmentError: If the specified model cannot be found or downloaded.
        RuntimeError: If any other error occurs while loading the model.
    """

    try:
        encoder = SBertEncoder(model_name)
    except EnvironmentError as err:
        raise EnvironmentError(str(err)) from None
    except Exception as err:
        raise RuntimeError(str(err)) from None

    return encoder
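

# Minimal usage sketch (not part of the library API). It assumes the
# "paraphrase-distilroberta-base-v1" model listed above can be downloaded and
# that this module is executed within its package so the relative import resolves.
if __name__ == "__main__":
    encoder = get_encoder("paraphrase-distilroberta-base-v1")
    sentences = ["The cat sat on the mat.", "A quick brown fox jumps over a lazy dog."]

    # Single-device encoding; pass a list such as ["cuda:0", "cuda:1"] to use
    # the multi-process branch instead.
    embeddings = encoder.encode(sentences, device="cpu", batch_size=32, verbose=True)
    print(embeddings.shape)  # (2, embedding_dim)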