import logging
import os
import sys

import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F

from .norm import SimpleRMSNorm as SimpleRMSNormTorch
from .srmsnorm_triton import SimpleRMSNorm as SimpleRMSNormTriton

# Parse boolean flags from the environment without eval(), which would
# execute arbitrary code from an attacker-controlled variable.
use_triton = os.environ.get("use_triton", "True").lower() == "true"
debug = os.environ.get("debug", "False").lower() == "true"

if use_triton:
    SimpleRMSNorm = SimpleRMSNormTriton
else:
    SimpleRMSNorm = SimpleRMSNormTorch
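
# Both flags are plain environment variables; a hypothetical invocation
# (train.py stands in for whatever entry point imports this module):
#     use_triton=False debug=True python train.py
# picks the pure-PyTorch SimpleRMSNorm and turns on activation logging.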

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

BASE_DIM = 256  # convert_to_multiple_of_base rounds dimensions up to a multiple of this


def is_dist_avail_and_initialized():
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0
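
# These helpers degrade gracefully without torch.distributed: single-process
# runs report a world size of 1 and rank 0, so `if is_main_process():` guards
# behave identically in both settings.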


def logging_info(string):
    if is_main_process():
        logger.info(string)


def print_params(**kwargs):
    if is_main_process():
        logger.info(f"start print config of {kwargs['__class__']}")
        for key in kwargs:
            if key in ["__class__", "self"]:
                continue
            logger.info(f"{key}: {kwargs[key]}")
        logger.info(f"end print config of {kwargs['__class__']}")


def print_config(config):
    if is_main_process():
        logger.info(f"start print config of {config['__class__']}")
        for key in config:
            if key in ["__class__", "self"]:
                continue
            logger.info(f"{key}: {config[key]}")
        logger.info(f"end print config of {config['__class__']}")


def print_module(module):
    # Collect the names of all submodules so that only parameters attached
    # directly to `module` (rather than to a child module) are reported.
    named_modules = {name for name, _ in module.named_modules()}

    lines = []
    for name, param in module.named_parameters():
        prefix = name.split(".")[0]
        if prefix not in named_modules:
            lines.append(f"({prefix}): Tensor({tuple(param.shape)}, "
                         f"requires_grad={param.requires_grad})")

    return "\n".join(lines)
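
# Sample return value (shapes illustrative) for a module that owns a direct
# nn.Parameter called "weight":
#     (weight): Tensor((512, 256), requires_grad=True)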


def get_activation_fn(activation):
    if debug:
        logger.info(f"activation: {activation}")
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return F.sigmoid
    elif activation == "exp":

        def f(x):
            with torch.no_grad():
                x_max = torch.max(x, dim=-1, keepdims=True).values
            y = torch.exp(x - x_max)

            return y

        return f
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":

        def f(x):
            return 1 + F.elu(x)

        return f
    elif activation == "2+elu":

        def f(x):
            return 2 + F.elu(x)

        return f
    elif activation == "silu" or activation == "swish":
        return F.silu
    elif activation == "sine":
        return torch.sin
    else:
        logger.warning(
            f"activation {activation} is not supported, falling back to identity")
        return lambda x: x

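# Usage sketch (input shape arbitrary; "1+elu" yields strictly positive
# outputs, since elu(x) > -1):
#     act = get_activation_fn("1+elu")
#     y = act(torch.randn(2, 8))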

def get_norm_fn(norm_type):
    if norm_type == "simplermsnorm":
        return SimpleRMSNorm
    else:
        return nn.LayerNorm
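
# Note: get_norm_fn returns a class, not an instance, so it still needs the
# feature dimension, e.g. get_norm_fn("simplermsnorm")(embed_dim); this assumes
# the norm's constructor takes the dimension as its first argument.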


def convert_to_multiple_of_base(x):
    # Round x up to the nearest multiple of BASE_DIM (ceiling division).
    return BASE_DIM * ((x + BASE_DIM - 1) // BASE_DIM)
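
# For example, with BASE_DIM = 256: convert_to_multiple_of_base(300) == 512,
# and convert_to_multiple_of_base(256) == 256 (exact multiples are unchanged).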