import numpy as np
from scipy.optimize import minimize, LinearConstraint, NonlinearConstraint
from collections import OrderedDict
import pandas as pd
from decimal import Decimal


def round_num(n, decimal=2):
    """Convert *n* to ``Decimal`` and trim it for display.

    Integral values come back as an integral ``Decimal`` (no fractional
    digits). Anything else is normalised and rounded to *decimal* places,
    which quantizes it to exactly that many places (e.g. 2.5 -> 2.50).
    """
    value = Decimal(n)
    whole = value.to_integral()
    if value == whole:
        return whole
    return round(value.normalize(), decimal)


def numerize(n, decimal=2):
    """Format *n* compactly with a magnitude suffix, e.g. 1500 -> "1.5K".

    The magnitude is chosen from powers of 1000 (1e0 ... 1e177). The scaled
    value is rendered with ``round_num`` (integral values lose their decimals).
    Negative inputs get a leading "-".

    Special cases:
        * ``abs(n) < 1`` is always expressed in thousands with a "K" suffix
          and without a sign (so it renders as a rounded fraction of 1K).
        * ``inf`` / ``NaN`` fit no bucket and return ``None``.

    Args:
        n: Number to format.
        decimal: Number of decimal places kept after scaling.

    Returns:
        The formatted string, or ``None`` when no bucket matches.
    """
    # 60 suffixes, one per entry of ``sci_expr`` below.
    # NOTE(review): entries from "Qa" after "Tn" onward repeat earlier names,
    # so suffixes for 1e87 and above are ambiguous; kept for output stability.
    sufixes = [
        "", "K", "M", "B", "T", "Qa", "Qu", "S", "Oc", "No",
        "D", "Ud", "Dd", "Td", "Qt", "Qi", "Se", "Od", "Nd", "V",
        "Uv", "Dv", "Tv", "Qv", "Qx", "Sx", "Ox", "Nx", "Tn",
        "Qa", "Qu", "S", "Oc", "No", "D", "Ud", "Dd", "Td", "Qt",
        "Qi", "Se", "Od", "Nd", "V", "Uv", "Dv", "Tv", "Qv", "Qx",
        "Sx", "Ox", "Nx", "Tn", "x", "xx", "xxx", "X", "XX", "XXX",
        "END",
    ]

    # Literal powers of ten (not 10.0 ** k, which can differ by an ulp).
    sci_expr = [
        1e0, 1e3, 1e6, 1e9, 1e12, 1e15, 1e18, 1e21, 1e24, 1e27,
        1e30, 1e33, 1e36, 1e39, 1e42, 1e45, 1e48, 1e51, 1e54, 1e57,
        1e60, 1e63, 1e66, 1e69, 1e72, 1e75, 1e78, 1e81, 1e84, 1e87,
        1e90, 1e93, 1e96, 1e99, 1e102, 1e105, 1e108, 1e111, 1e114, 1e117,
        1e120, 1e123, 1e126, 1e129, 1e132, 1e135, 1e138, 1e141, 1e144, 1e147,
        1e150, 1e153, 1e156, 1e159, 1e162, 1e165, 1e168, 1e171, 1e174, 1e177,
    ]
    minus_buff = n
    n = abs(n)

    if n < 1:
        return f"{round(n/1000, decimal)}K"

    # Each bucket is [lower, next lower). The final bucket is open-ended so
    # its "END" suffix is reachable. Previously the lookup of a
    # non-existent next bound raised IndexError, which was swallowed, and
    # every value >= 1e177 silently returned None.
    upper_bounds = sci_expr[1:] + [float("inf")]
    for sufix, lower, upper in zip(sufixes, sci_expr, upper_bounds):
        if lower <= n < upper:
            scaled = n / lower if n >= 1e3 else n
            num = str(round_num(scaled, decimal))
            return num + sufix if minus_buff > 0 else "-" + num + sufix

    # Only inf / NaN get here: they satisfy no bucket comparison.
    return None


def class_to_dict(class_instance):
    """Serialise a ``Channel`` or ``Scenario`` into a plain dict.

    Values are stored by reference (no copies). A Scenario's channels are
    serialised recursively into a *list* of channel dicts, the shape
    ``Scenario.from_dict`` iterates over. Channel dicts also carry the
    actual/modified marginal ROI. Any other object yields an empty dict.
    """
    if isinstance(class_instance, Channel):
        chan = class_instance
        return {
            "type": "Channel",
            "name": chan.name,
            "dates": chan.dates,
            "spends": chan.actual_spends,
            "conversion_rate": chan.conversion_rate,
            "modified_spends": chan.modified_spends,
            "modified_sales": chan.modified_sales,
            "response_curve_type": chan.response_curve_type,
            "response_curve_params": chan.response_curve_params,
            "penalty": chan.penalty,
            "bounds": chan.bounds,
            "actual_total_spends": chan.actual_total_spends,
            "actual_total_sales": chan.actual_total_sales,
            "modified_total_spends": chan.modified_total_spends,
            "modified_total_sales": chan.modified_total_sales,
            "actual_mroi": chan.get_marginal_roi("actual"),
            "modified_mroi": chan.get_marginal_roi("modified"),
            "freeze": chan.freeze,
            "correction": chan.correction,
        }

    if isinstance(class_instance, Scenario):
        scen = class_instance
        return {
            "type": "Scenario",
            "name": scen.name,
            "bounds": scen.bounds,
            "channels": [class_to_dict(member) for member in scen.channels.values()],
            "constant": scen.constant,
            "correction": scen.correction,
            "actual_total_spends": scen.actual_total_spends,
            "actual_total_sales": scen.actual_total_sales,
            "modified_total_spends": scen.modified_total_spends,
            "modified_total_sales": scen.modified_total_sales,
        }

    return {}


# def class_convert_to_dict(class_instance):
#     attr_dict = {}
#     if isinstance(class_instance, Channel):
#         attr_dict["type"] = "Channel"
#         attr_dict["name"] = class_instance.name
#         attr_dict["dates"] = class_instance.dates
#         attr_dict["spends"] = class_instance.actual_spends
#         attr_dict["conversion_rate"] = class_instance.conversion_rate
#         attr_dict["modified_spends"] = class_instance.modified_spends
#         attr_dict["modified_sales"] = class_instance.modified_sales
#         attr_dict["response_curve_type"] = class_instance.response_curve_type
#         attr_dict["response_curve_params"] = class_instance.response_curve_params
#         # attr_dict["penalty"] = class_instance.penalty
#         attr_dict["bounds"] = class_instance.bounds
#         attr_dict["actual_total_spends"] = class_instance.actual_total_spends
#         attr_dict["actual_total_sales"] = class_instance.actual_total_sales
#         attr_dict["modified_total_spends"] = class_instance.modified_total_spends
#         attr_dict["modified_total_sales"] = class_instance.modified_total_sales
#         # attr_dict["actual_mroi"] = class_instance.get_marginal_roi("actual")
#         # attr_dict["modified_mroi"] = class_instance.get_marginal_roi("modified")

#         attr_dict["freeze"] = class_instance.freeze
#         attr_dict["correction"] = class_instance.correction

#     elif isinstance(class_instance, Scenario):
#         attr_dict["type"] = "Scenario"
#         attr_dict["name"] = class_instance.name
#         channels = {}
#         for channel in class_instance.channels.values():
#             channels[channel.name] = class_to_dict(channel)
#         attr_dict["channels"] = channels
#         attr_dict["constant"] = class_instance.constant
#         attr_dict["correction"] = class_instance.correction
#         attr_dict["actual_total_spends"] = class_instance.actual_total_spends
#         attr_dict["actual_total_sales"] = class_instance.actual_total_sales
#         attr_dict["modified_total_spends"] = class_instance.modified_total_spends
#         attr_dict["modified_total_sales"] = class_instance.modified_total_sales

#         attr_dict["bound_type"] = class_instance.bound_type

#         attr_dict["bounds"] = class_instance.bounds

#     return attr_dict.copy()

# Function to convert class instance to dictionary
def class_convert_to_dict(class_instance):
    """Snapshot a ``Channel`` or ``Scenario`` into a dict of copied values.

    Unlike ``class_to_dict``:
      * per-period spends, sales, constant and correction are copied
        (``list(...)`` or ``.copy()``);
      * a Scenario's channels are a name -> channel-dict mapping;
      * "penalty" and marginal-ROI entries are not included.

    Any other object yields an empty dict.
    """
    if isinstance(class_instance, Channel):
        chan = class_instance
        return {
            "type": "Channel",
            "name": chan.name,
            "dates": chan.dates,
            "spends": list(chan.actual_spends),
            "conversion_rate": chan.conversion_rate,
            "modified_spends": list(chan.modified_spends),
            "modified_sales": list(chan.modified_sales),
            "response_curve_type": chan.response_curve_type,
            "response_curve_params": chan.response_curve_params.copy(),
            "bounds": chan.bounds.copy(),
            "actual_total_spends": chan.actual_total_spends,
            "actual_total_sales": chan.actual_total_sales,
            "modified_total_spends": chan.modified_total_spends,
            "modified_total_sales": chan.modified_total_sales,
            "freeze": chan.freeze,
            "correction": chan.correction.copy(),
        }

    if isinstance(class_instance, Scenario):
        scen = class_instance
        return {
            "type": "Scenario",
            "name": scen.name,
            "channels": {
                member.name: class_convert_to_dict(member)
                for member in scen.channels.values()
            },
            "constant": list(scen.constant),
            "correction": list(scen.correction),
            "actual_total_spends": scen.actual_total_spends,
            "actual_total_sales": scen.actual_total_sales,
            "modified_total_spends": scen.modified_total_spends,
            "modified_total_sales": scen.modified_total_sales,
            "bound_type": scen.bound_type,
            "bounds": scen.bounds.copy(),
        }

    return {}


def class_from_dict(attr_dict):
    """Rebuild a ``Channel`` or ``Scenario`` from a dict with a "type" key.

    Dispatches to the matching class's ``from_dict``. Returns ``None`` when
    "type" names neither class; raises ``KeyError`` if "type" is missing.
    """
    kind = attr_dict["type"]
    for label, factory in (("Channel", Channel), ("Scenario", Scenario)):
        if kind == label:
            return factory.from_dict(attr_dict)
    return None


class Channel:
    """One marketing channel with a fitted response curve.

    Holds the actual per-period spends, a working copy of modified spends,
    and cached totals/deltas. ``update`` refreshes the modified ones.
    """

    def __init__(
        self,
        name,
        dates,
        spends,
        response_curve_type,
        response_curve_params,
        bounds,
        correction,
        conversion_rate=1,
        modified_spends=None,
        penalty=True,
        freeze=False,
    ):
        """Create a channel and compute its initial totals.

        Args:
            name: Channel label.
            dates: Per-period dates (stored as-is).
            spends: Actual per-period spends; lists/tuples become float arrays.
            response_curve_type: "s-curve" or "linear".
            response_curve_params: For "s-curve": keys K, b, a, x0. For
                "linear": key beta.
            bounds: Percentage bounds used by ``Scenario.optimize``.
            correction: Per-period additive sales correction.
            conversion_rate: Multiplier applied to spends in scenario totals.
            modified_spends: Starting modified spends; defaults to a copy
                of the actual spends.
            penalty: Whether ``response_curve`` soft-caps spends above
                ``upper_limit``.
            freeze: Stored flag; not interpreted by this class.
        """
        self.name = name
        self.dates = dates
        self.conversion_rate = conversion_rate
        # Serialised channels (class_convert_to_dict) store spends as lists,
        # which lack .max()/.std()/.sum(); promote them to arrays.
        self.actual_spends = self._as_array(spends).copy()
        self.correction = correction

        if modified_spends is None:
            self.modified_spends = self.actual_spends.copy()
        else:
            self.modified_spends = self._as_array(modified_spends)

        self.response_curve_type = response_curve_type
        self.response_curve_params = response_curve_params
        self.bounds = bounds
        self.penalty = penalty
        self.freeze = freeze

        # Spend level above which the penalty soft cap kicks in.
        self.upper_limit = self.actual_spends.max() + self.actual_spends.std()
        # Exponent used to rescale spends before the s-curve (applied when >= 0).
        self.power = np.ceil(np.log(self.actual_spends.max()) / np.log(10)) - 3
        self.actual_total_spends = self.actual_spends.sum()
        self.actual_total_sales = self.actual_sales.sum()
        self.modified_sales = self.calculate_sales()
        self.modified_total_spends = self.modified_spends.sum()
        self.modified_total_sales = self.modified_sales.sum()
        self.delta_spends = self.modified_total_spends - self.actual_total_spends
        self.delta_sales = self.modified_total_sales - self.actual_total_sales

    @staticmethod
    def _as_array(values):
        """Return lists/tuples as float ndarrays; leave anything else unchanged."""
        if isinstance(values, (list, tuple)):
            return np.asarray(values, dtype=float)
        return values

    @property
    def actual_sales(self):
        """Per-period sales predicted for the actual spends, plus correction."""
        return self.response_curve(self.actual_spends) + self.correction

    def update_penalty(self, penalty):
        """Enable or disable the soft cap in ``response_curve``."""
        self.penalty = penalty

    def _modify_spends(self, spends_array, total_spends):
        """Rescale *spends_array* proportionally so it sums to *total_spends*."""
        return spends_array * total_spends / spends_array.sum()

    def modify_spends(self, total_spends):
        """Rescale the modified spends to *total_spends*, keeping their shape."""
        self.modified_spends = self._modify_spends(self.modified_spends, total_spends)

    def calculate_sales(self):
        """Per-period sales predicted for the modified spends, plus correction."""
        return self.response_curve(self.modified_spends) + self.correction

    def response_curve(self, x):
        """Evaluate the channel's response curve at spends *x*.

        With ``penalty`` enabled, spends at or above ``upper_limit`` are first
        compressed to ``upper + (x - upper) * upper / x``.

        Raises:
            ValueError: if ``response_curve_type`` is not "s-curve" or "linear".
        """
        if self.penalty:
            x = np.where(
                x < self.upper_limit,
                x,
                self.upper_limit + (x - self.upper_limit) * self.upper_limit / x,
            )
        if self.response_curve_type == "s-curve":
            if self.power >= 0:
                x = x / 10**self.power
            x = x.astype("float64")
            K = self.response_curve_params["K"]
            b = self.response_curve_params["b"]
            a = self.response_curve_params["a"]
            x0 = self.response_curve_params["x0"]
            return K / (1 + b * np.exp(-a * (x - x0)))
        if self.response_curve_type == "linear":
            beta = self.response_curve_params["beta"]
            return beta * x
        # Previously fell through to an UnboundLocalError on `sales`.
        raise ValueError(
            f"Unsupported response_curve_type: {self.response_curve_type!r}"
        )

    def get_marginal_roi(self, flag):
        """Average per-period slope of the response curve.

        Args:
            flag: "actual" evaluates at the actual spends; any other value
                uses the modified spends.

        For "s-curve" this is the logistic derivative ``a * y * (1 - y / K)``
        with respect to the (rescaled) input, averaged over periods. For
        "linear" the slope is constant, so the average is ``beta`` itself.
        Neither case accounts for the penalty's own slope.
        """
        if self.response_curve_type == "linear":
            # Previously raised KeyError("K"), crashing class_to_dict/save.
            return self.response_curve_params["beta"]

        K = self.response_curve_params["K"]
        a = self.response_curve_params["a"]
        spends = self.actual_spends if flag == "actual" else self.modified_spends
        y = self.response_curve(spends)
        mroi = a * y * (1 - y / K)
        return mroi.sum() / len(self.modified_spends)

    def update(self, total_spends):
        """Rescale modified spends to *total_spends* and refresh cached totals."""
        self.modify_spends(total_spends)
        self.modified_sales = self.calculate_sales()
        self.modified_total_spends = self.modified_spends.sum()
        self.modified_total_sales = self.modified_sales.sum()
        self.delta_spends = self.modified_total_spends - self.actual_total_spends
        self.delta_sales = self.modified_total_sales - self.actual_total_sales

    def intialize(self):
        # NOTE(review): `old_spends` is never assigned anywhere in this class,
        # so this raises AttributeError unless set externally; intent unclear.
        self.new_spends = self.old_spends

    def __str__(self):
        return f"{self.name},{self.actual_total_sales}, {self.modified_total_spends}"

    @classmethod
    def from_dict(cls, attr_dict):
        """Rebuild a channel from a dict produced by class_to_dict/class_convert_to_dict.

        Optional keys fall back to constructor defaults. "penalty" is absent
        from class_convert_to_dict output. "conversion_rate" and "freeze"
        were previously dropped on a round trip.
        """
        return cls(
            name=attr_dict["name"],
            dates=attr_dict["dates"],
            spends=attr_dict["spends"],
            bounds=attr_dict["bounds"],
            modified_spends=attr_dict["modified_spends"],
            response_curve_type=attr_dict["response_curve_type"],
            response_curve_params=attr_dict["response_curve_params"],
            penalty=attr_dict.get("penalty", True),
            correction=attr_dict["correction"],
            conversion_rate=attr_dict.get("conversion_rate", 1),
            freeze=attr_dict.get("freeze", False),
        )

    def update_response_curves(self, response_curve_params):
        """Replace the response-curve parameters (cached totals are not refreshed)."""
        self.response_curve_params = response_curve_params


class Scenario:
    """A budget scenario made of several channels plus baseline sales terms.

    Aggregated totals (``actual_total_*``, ``modified_total_*``, ``delta_*``)
    are cached on the instance. ``update`` refreshes the modified and delta
    values after changing one channel.
    """

    def __init__(
        self, name, channels, constant, correction, bound_type=False, bounds=None
    ):
        """Build a scenario and compute its initial totals.

        Args:
            name: Scenario label.
            channels: Mapping of channel name -> Channel.
            constant: Per-period baseline sales (anything ``np.sum`` accepts).
            correction: Per-period sales correction (anything ``np.sum`` accepts).
            bound_type: Stored as-is; not interpreted by this class.
            bounds: Stored as-is. Defaults to a fresh ``[-10, 10]`` list so
                instances never share one mutable default.
        """
        self.name = name
        self.channels = channels
        self.constant = constant
        self.correction = correction
        self.bound_type = bound_type
        self.bounds = [-10, 10] if bounds is None else bounds

        # Previously this called calculate_modified_total_spends, because
        # the "actual" helper had been shadowed by a duplicate method name.
        self.actual_total_spends = self.calculate_actual_total_spends()
        self.actual_total_sales = self.calculate_actual_total_sales()
        self.modified_total_sales = self.calculate_modified_total_sales()
        self.modified_total_spends = self.calculate_modified_total_spends()
        self.delta_spends = self.modified_total_spends - self.actual_total_spends
        self.delta_sales = self.modified_total_sales - self.actual_total_sales

    def update_penalty(self, value):
        """Set the response-curve penalty flag on every channel."""
        for channel in self.channels.values():
            channel.update_penalty(value)

    def calculate_actual_total_spends(self):
        """Sum of each channel's actual total spends times its conversion rate."""
        total_actual_spends = 0.0
        for channel in self.channels.values():
            total_actual_spends += channel.actual_total_spends * channel.conversion_rate
        return total_actual_spends

    def calculate_modified_total_spends(self):
        """Sum of each channel's modified total spends times its conversion rate."""
        total_modified_spends = 0.0
        for channel in self.channels.values():
            total_modified_spends += (
                channel.modified_total_spends * channel.conversion_rate
            )
        return total_modified_spends

    def calculate_actual_total_sales(self):
        """Constant + correction + every channel's actual total sales."""
        # np.sum also handles the plain lists produced by class_convert_to_dict.
        total_actual_sales = np.sum(self.constant) + np.sum(self.correction)
        for channel in self.channels.values():
            total_actual_sales += channel.actual_total_sales
        return total_actual_sales

    def calculate_modified_total_sales(self):
        """Constant + correction + every channel's modified total sales."""
        total_modified_sales = np.sum(self.constant) + np.sum(self.correction)
        for channel in self.channels.values():
            total_modified_sales += channel.modified_total_sales
        return total_modified_sales

    def update(self, channel_name, modified_spends):
        """Set one channel's total spends and refresh the scenario totals.

        *modified_spends* is a total, forwarded to ``Channel.update``.
        """
        self.channels[channel_name].update(modified_spends)
        self.modified_total_sales = self.calculate_modified_total_sales()
        self.modified_total_spends = self.calculate_modified_total_spends()
        self.delta_spends = self.modified_total_spends - self.actual_total_spends
        self.delta_sales = self.modified_total_sales - self.actual_total_sales

    def optimize_spends(self, sales_percent, channels_list, algo="trust-constr"):
        """Minimise total spend on *channels_list* while hitting a sales target.

        The target is ``actual_total_sales * (1 + sales_percent / 100)``,
        accepted within +/-1. Each channel is searched between -50% and +100%
        of its actual total spends. The scenario is left updated with the
        solution.

        Returns:
            A ``zip`` of (channel name, optimised total spends).
        """
        desired_sales = self.actual_total_sales * (1 + sales_percent / 100.0)

        def constraint(x):
            # Side effect: moves the scenario to the candidate point.
            for ch, spends in zip(channels_list, x):
                self.update(ch, spends)
            return self.modified_total_sales - desired_sales

        bounds = []
        for ch in channels_list:
            bounds.append(
                (1 + np.array([-50.0, 100.0]) / 100.0)
                * self.channels[ch].actual_total_spends
            )

        initial_point = []
        for bound in bounds:
            initial_point.append(bound[0])

        # Objective scaling so the optimiser works with values near 1.
        power = np.ceil(np.log(sum(initial_point)) / np.log(10))

        constraints = [NonlinearConstraint(constraint, -1.0, 1.0)]

        res = minimize(
            lambda x: sum(x) / 10 ** (power),
            bounds=bounds,
            x0=initial_point,
            constraints=constraints,
            method=algo,
            options={"maxiter": int(2e7), "xtol": 100},
        )

        for channel_name, modified_spends in zip(channels_list, res.x):
            self.update(channel_name, modified_spends)

        return zip(channels_list, res.x)

    def optimize(self, spends_percent, channels_list):
        """Maximise total sales at a fixed conversion-weighted spend budget.

        The budget is the conversion-weighted actual spend of *channels_list*
        scaled by ``(1 + spends_percent / 100)``. Each channel starts at its
        scaled actual spend, bounded by ``channel.bounds`` interpreted as
        percentages around that value. The scenario is left updated with the
        solution.

        Returns:
            A ``zip`` of (channel name, optimised total spends).
        """
        spends_constant = []
        spends_constraint = 0.0
        for channel_name in channels_list:
            channel = self.channels[channel_name]
            spends_constant.append(channel.conversion_rate)
            spends_constraint += channel.actual_total_spends * channel.conversion_rate
        spends_constraint = spends_constraint * (1 + spends_percent / 100)
        constraint = LinearConstraint(
            np.array(spends_constant),
            lb=spends_constraint,
            ub=spends_constraint,
        )
        bounds = []
        old_spends = []
        for channel_name in channels_list:
            _channel_class = self.channels[channel_name]
            # Bounds may be a plain list after a dict round trip; convert so
            # the percentage arithmetic broadcasts instead of raising.
            channel_bounds = np.asarray(_channel_class.bounds, dtype=float)
            channel_actual_total_spends = _channel_class.actual_total_spends * (
                (1 + spends_percent / 100)
            )
            old_spends.append(channel_actual_total_spends)
            bounds.append((1 + channel_bounds / 100) * channel_actual_total_spends)

        def objective_function(x):
            # Side effect: moves the scenario to the candidate point.
            for channel_name, modified_spends in zip(channels_list, x):
                self.update(channel_name, modified_spends)
            return -1 * self.modified_total_sales

        # Objective scaling so the optimiser works with values near 1.
        power = np.ceil(np.log(self.modified_total_sales) / np.log(10))

        res = minimize(
            lambda x: objective_function(x) / 10 ** (power - 1),
            method="trust-constr",
            x0=old_spends,
            constraints=constraint,
            bounds=bounds,
            options={"maxiter": int(1e7), "xtol": 0.1},
        )

        for channel_name, modified_spends in zip(channels_list, res.x):
            self.update(channel_name, modified_spends)

        return zip(channels_list, res.x)

    def save(self):
        """Collect per-channel and total results for display and download.

        Returns:
            A dict with keys "Actual" and "Modified" (lists of row dicts),
            "Summary" (a DataFrame with two-level columns) and "download"
            (per-period data, per-channel totals and scenario totals).
        """
        details = {}
        actual_list = []
        modified_list = []
        data = {}
        channel_data = []

        summary_rows = []
        actual_list.append(
            {
                "name": "Total",
                "Spends": self.actual_total_spends,
                "Sales": self.actual_total_sales,
            }
        )
        modified_list.append(
            {
                "name": "Total",
                "Spends": self.modified_total_spends,
                "Sales": self.modified_total_sales,
            }
        )
        for channel in self.channels.values():
            name_mod = channel.name.replace("_", " ")
            if name_mod.lower().endswith(" imp"):
                # Replace only the trailing " imp" (any case). The old
                # replace("Imp", " Impressions") produced a double space,
                # skipped lowercase names and rewrote "Imp" mid-name.
                name_mod = name_mod[: -len(" imp")] + " Impressions"
            summary_rows.append(
                [
                    name_mod,
                    channel.actual_total_spends,
                    channel.modified_total_spends,
                    channel.actual_total_sales,
                    channel.modified_total_sales,
                    round(
                        channel.actual_total_sales / channel.actual_total_spends,
                        2,
                    ),
                    round(
                        channel.modified_total_sales / channel.modified_total_spends,
                        2,
                    ),
                    channel.get_marginal_roi("actual"),
                    channel.get_marginal_roi("modified"),
                ]
            )
            data[channel.name] = channel.modified_spends
            data["Date"] = channel.dates
            data["Sales"] = (
                data.get("Sales", np.zeros((len(channel.dates),)))
                + channel.modified_sales
            )
            actual_list.append(
                {
                    "name": channel.name,
                    "Spends": channel.actual_total_spends,
                    "Sales": channel.actual_total_sales,
                    "ROI": round(
                        channel.actual_total_sales / channel.actual_total_spends,
                        2,
                    ),
                }
            )
            modified_list.append(
                {
                    "name": channel.name,
                    "Spends": channel.modified_total_spends,
                    "Sales": channel.modified_total_sales,
                    "ROI": round(
                        channel.modified_total_sales / channel.modified_total_spends,
                        2,
                    ),
                    "Marginal ROI": channel.get_marginal_roi("modified"),
                }
            )

            channel_data.append(
                {
                    "channel": channel.name,
                    "spends_act": channel.actual_total_spends,
                    "spends_mod": channel.modified_total_spends,
                    "sales_act": channel.actual_total_sales,
                    "sales_mod": channel.modified_total_sales,
                }
            )
        summary_rows.append(
            [
                "Total",
                self.actual_total_spends,
                self.modified_total_spends,
                self.actual_total_sales,
                self.modified_total_sales,
                round(self.actual_total_sales / self.actual_total_spends, 2),
                round(self.modified_total_sales / self.modified_total_spends, 2),
                0.0,
                0.0,
            ]
        )
        details["Actual"] = actual_list
        details["Modified"] = modified_list
        columns_index = pd.MultiIndex.from_product(
            [[""], ["Channel"]], names=["first", "second"]
        )
        columns_index = columns_index.append(
            pd.MultiIndex.from_product(
                [["Spends", "NRPU", "ROI", "MROI"], ["Actual", "Simulated"]],
                names=["first", "second"],
            )
        )
        details["Summary"] = pd.DataFrame(summary_rows, columns=columns_index)
        data_df = pd.DataFrame(data)
        channel_list = list(self.channels.keys())
        data_df = data_df[["Date", *channel_list, "Sales"]]

        details["download"] = {
            "data_df": data_df,
            "channels_df": pd.DataFrame(channel_data),
            "total_spends_act": self.actual_total_spends,
            "total_sales_act": self.actual_total_sales,
            "total_spends_mod": self.modified_total_spends,
            "total_sales_mod": self.modified_total_sales,
        }

        return details

    @classmethod
    def from_dict(cls, attr_dict):
        """Rebuild a scenario from class_to_dict or class_convert_to_dict output.

        Channels may be a list of channel dicts (class_to_dict) or a
        name -> channel-dict mapping (class_convert_to_dict). "bound_type"
        and "bounds" are restored when present (previously always discarded).
        """
        channels_list = attr_dict["channels"]
        if isinstance(channels_list, dict):
            channels_list = channels_list.values()
        channels = {
            channel["name"]: class_from_dict(channel) for channel in channels_list
        }
        return cls(
            name=attr_dict["name"],
            channels=channels,
            constant=attr_dict["constant"],
            correction=attr_dict["correction"],
            bound_type=attr_dict.get("bound_type", False),
            bounds=attr_dict.get("bounds"),
        )