from dataclasses import dataclass, field
import os
import json
import logging
from argparse import Namespace
from typing import List, Literal, Optional, Union
from pydantic import AnyHttpUrl, BaseSettings, HttpUrl, validator, BaseModel


CURRENT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))

# request body
# 使用pydantic对请求中的body数据进行验证
class RequestDataStructure(BaseModel):
    input_text: List[str] = [""]
    uuid: Optional[int]
    
    # parameters for text2image model
    input_image: Optional[str]
    skip_steps: Optional[int]
    clip_guidance_scale: Optional[int]
    init_scale: Optional[int]

# API config
@dataclass
class APIConfig:

    # server config
    SERVER_HOST: AnyHttpUrl = "127.0.0.1"
    SERVER_PORT: int = 8990
    SERVER_NAME: str = ""
    PROJECT_NAME: str = ""
    API_PREFIX_STR: str = "/api"

    # api config
    API_method: Literal["POST","GET","PUT","OPTIONS","WEBSOCKET","PATCH","DELETE","TRACE","CONNECT"] = "POST"
    API_path: str = "/TextClassification"
    API_tags: List[str] = field(default_factory = lambda: [""])
    
    # CORS config
    BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = field(default_factory = lambda: ["*"])
    allow_credentials: bool = True
    allow_methods: List[str] = field(default_factory = lambda: ["*"])
    allow_headers: List[str] = field(default_factory = lambda: ["*"])
    
    # log config
    log_file_path: str = ""
    log_level: str = "INFO"
        
    # pipeline config
    pipeline_type: str = ""
    model_name: str = ""

    # model config
    # device: int = -1
    # texta_name: Optional[str] = "sentence"
    # textb_name: Optional[str] = "sentence2"
    # label_name: Optional[str] = "label"
    # max_length: int = 512
    # return_tensors: str = "pt"
    # padding: str = "longest"
    # truncation: bool = True
    # skip_special_tokens: bool = True
    # clean_up_tkenization_spaces: bool = True

    # # parameters for text2image model
    # skip_steps: Optional[int] = 0
    # clip_guidance_scale: Optional[int] = 0
    # init_scale: Optional[int] = 0

    def setup_config(self, args:Namespace) -> None:
        
        # load config file
        with open(CURRENT_DIR_PATH + "/" + args.config_path, "r") as jsonfile:
            config = json.load(jsonfile)

        server_config = config["SERVER"]
        logging_config = config["LOGGING"]
        pipeline_config = config["PIPELINE"]
        
        # server config
        self.SERVER_HOST: AnyHttpUrl = server_config["SERVER_HOST"]
        self.SERVER_PORT: int = server_config["SERVER_PORT"]
        self.SERVER_NAME: str = server_config["SERVER_NAME"]
        self.PROJECT_NAME: str = server_config["PROJECT_NAME"]
        self.API_PREFIX_STR: str = server_config["API_PREFIX_STR"]

        # api config
        self.API_method: Literal["POST","GET","PUT","OPTIONS","WEBSOCKET","PATCH","DELETE","TRACE","CONNECT"] = server_config["API_method"]
        self.API_path: str = server_config["API_path"]
        self.API_tags: List[str] = server_config["API_tags"]

        # CORS config
        self.BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = server_config["BACKEND_CORS_ORIGINS"]
        self.allow_credentials: bool = server_config["allow_credentials"]
        self.allow_methods: List[str] = server_config["allow_methods"]
        self.allow_headers: List[str] = server_config["allow_headers"]
        
        # log config
        self.log_file_path: str = logging_config["log_file_path"]
        self.log_level: str = logging_config["log_level"]
        
        # pipeline config
        self.pipeline_type: str = pipeline_config["pipeline_type"]
        self.model_name: str = pipeline_config["model_name"]

        # general model config
        self.model_settings: dict = pipeline_config["model_settings"]

        # 由于pipeline本身会解析参数,后续参数可以不要
        # 直接将model_settings字典转为Namespace后作为pipeline的args参数即可

        # self.device: int = self.model_settings["device"]
        # self.texta_name: Optional[str] = self.model_settings["texta_name"]
        # self.textb_name: Optional[str] = self.model_settings["textb_name"]
        # self.label_name: Optional[str] = self.model_settings["label_name"]
        # self.max_length: int = self.model_settings["max_length"]
        # self.return_tensors: str = self.model_settings["return_tensors"]
        # self.padding: str = self.model_settings["padding"]
        # self.truncation: bool = self.model_settings["truncation"]
        # self.skip_special_tokens: bool = self.model_settings["skip_special_tokens"]
        # self.clean_up_tkenization_spaces: bool = self.model_settings["clean_up_tkenization_spaces"]

        # # specific parameters for text2image model
        # self.skip_steps: Optional[int] = self.model_settings["skip_steps"]
        # self.clip_guidance_scale: Optional[int] = self.model_settings["clip_guidance_scale"]
        # self.init_scale: Optional[int] = self.model_settings["init_scale"]
        


def setup_logger(logger, user_config: APIConfig):
        
        # default level: INFO 

        logger.setLevel(getattr(logging, user_config.log_level, "INFO"))
        ch = logging.StreamHandler()
        
        if(user_config.log_file_path == ""):
            fh = logging.FileHandler(filename = CURRENT_DIR_PATH + "/"  + user_config.SERVER_NAME  + ".log")
        elif(".log" not in user_config.log_file_path[-5:-1]):
            fh = logging.FileHandler(filename = user_config.log_file_path + "/" + user_config.SERVER_NAME + ".log")
        else:
            fh = logging.FileHandler(filename = user_config.log_file_path)


        formatter = logging.Formatter(
            "%(asctime)s - %(module)s - %(funcName)s - line:%(lineno)d - %(levelname)s - %(message)s"
        )

        ch.setFormatter(formatter)
        fh.setFormatter(formatter)
        logger.addHandler(ch)  # Exporting logs to the screen
        logger.addHandler(fh)  # Exporting logs to a file

        return logger

user_config = APIConfig()
api_logger = logging.getLogger()