import os
import sys
import random
import torch
from pathlib import Path
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
import logging
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes, SaveImage  # <-- Node SaveImage
from comfy import model_management
import folder_paths

# 1. Logging: timestamped INFO-level output to help debug startup and requests.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# 2. Path setup: make this script's own directory importable.
#    NOTE(review): the `nodes`, `comfy` and `folder_paths` imports at the top of
#    the file execute before this append, so it cannot influence them; it only
#    affects imports performed later.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

# 3. Directory layout: output/ and models/ live next to this script (symlinks
#    resolved) and are created on startup; ComfyUI writes generated images to
#    output/.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir, models_dir = (
    os.path.join(BASE_DIR, _sub) for _sub in ("output", "models")
)
for _needed_dir in (output_dir, models_dir):
    os.makedirs(_needed_dir, exist_ok=True)
folder_paths.set_output_directory(output_dir)

# 4. One sub-directory per model type under models/, created if missing and
#    registered with ComfyUI's folder_paths; the loader calls further down
#    refer to model files by bare filename.
MODEL_FOLDERS = ["style_models", "text_encoders", "vae", "unet", "clip_vision"]
for model_folder in MODEL_FOLDERS:
    folder_path = os.path.join(models_dir, model_folder)
    os.makedirs(folder_path, exist_ok=True)
    folder_paths.add_model_folder_path(model_folder, folder_path)
    logger.info(f"Pasta de modelo configurada: {model_folder}")

# 5. CUDA diagnostics: log interpreter/torch versions and the visible GPUs.
_cuda_available = torch.cuda.is_available()
logger.info(f"Python version: {sys.version}")
logger.info(f"Torch version: {torch.__version__}")
logger.info(f"CUDA disponível: {_cuda_available}")
logger.info(f"Quantidade de GPUs: {torch.cuda.device_count()}")
if _cuda_available:
    logger.info(f"GPU atual: {torch.cuda.get_device_name(0)}")

# 6. ComfyUI initialization. A failure while loading the extra (custom) nodes
#    is deliberately non-fatal: it is logged as a warning and startup continues.
#    NOTE(review): in some ComfyUI versions init_extra_nodes is a coroutine
#    function that must be awaited — confirm against the pinned ComfyUI version.
logger.info("Inicializando ComfyUI...")
try:
    init_extra_nodes()
except Exception as exc:
    logger.warning(f"Aviso na inicialização de nós extras: {exc}")
    logger.info("Continuando mesmo com avisos nos nós extras...")

# 7. Helper Functions
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback is used only when direct indexing raises ``KeyError`` (i.e.
    ``obj`` is a mapping that lacks ``index`` as a key); any other error from
    direct indexing propagates unchanged.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value

def verify_file_exists(folder: str, filename: str) -> bool:
    """Check whether ``models_dir/folder/filename`` exists on disk.

    Returns True when the path exists; otherwise logs an error naming the
    missing path and returns False.
    """
    candidate = os.path.join(models_dir, folder, filename)
    if os.path.exists(candidate):
        return True
    logger.error(f"Arquivo não encontrado: {candidate}")
    return False

# 8. Model downloads: fetch every checkpoint the pipeline needs from the
#    Hugging Face Hub into the matching models/<sub-folder>. Any download
#    failure is logged and re-raised, aborting startup.
logger.info("Baixando modelos necessários...")
_MODEL_DOWNLOADS = (
    # (repo_id, filename, sub-folder of models_dir)
    ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
    ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
    ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors", "text_encoders"),
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "unet"),
    ("Comfy-Org/sigclip_vision_384", "sigclip_vision_patch14_384.safetensors", "clip_vision"),
)
try:
    for _repo_id, _filename, _subfolder in _MODEL_DOWNLOADS:
        hf_hub_download(
            repo_id=_repo_id,
            filename=_filename,
            local_dir=os.path.join(models_dir, _subfolder),
        )
except Exception as e:
    logger.error(f"Erro ao baixar modelos: {str(e)}")
    raise

# 9. Model initialization: load every model once at import time into the
#    module-level globals CLIP_MODEL, CLIP_VISION, STYLE_MODEL, VAE_MODEL and
#    UNET_MODEL, which generate_image() reads. Each loader result is indexed
#    with [0] wherever the model object itself is needed. Any failure here is
#    fatal: it is logged and re-raised so the app never starts without models.
logger.info("Inicializando modelos...")
try:
    # Loading only — no gradients are needed.
    with torch.no_grad():
        # CLIP: dual text encoder (T5-XXL + CLIP-L text model) with type "flux";
        # both files were downloaded into models/text_encoders.
        logger.info("Carregando CLIP...")
        dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
        CLIP_MODEL = dualcliploader.load_clip(
            clip_name1="t5xxl_fp16.safetensors",
            clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
            type="flux"
        )
        # Fail fast if the loader produced nothing.
        if CLIP_MODEL is None:
            raise ValueError("Falha ao carregar CLIP model")

        # CLIP Vision: SigCLIP image encoder that ReduxAdvanced uses to embed
        # the reference image.
        logger.info("Carregando CLIP Vision...")
        clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
        CLIP_VISION = clipvisionloader.load_clip(
            clip_name="sigclip_vision_patch14_384.safetensors"
        )
        if CLIP_VISION is None:
            raise ValueError("Falha ao carregar CLIP Vision model")

        # Style model: the FLUX.1 Redux checkpoint.
        logger.info("Carregando Style Model...")
        stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
        STYLE_MODEL = stylemodelloader.load_style_model(
            style_model_name="flux1-redux-dev.safetensors"
        )
        if STYLE_MODEL is None:
            raise ValueError("Falha ao carregar Style Model")

        # VAE: FLUX autoencoder, used later to decode sampled latents.
        logger.info("Carregando VAE...")
        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        VAE_MODEL = vaeloader.load_vae(
            vae_name="ae.safetensors"
        )
        if VAE_MODEL is None:
            raise ValueError("Falha ao carregar VAE model")

        # UNET: FLUX.1-dev diffusion model, loaded with fp8 (e4m3fn) weights.
        logger.info("Carregando UNET...")
        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        UNET_MODEL = unetloader.load_unet(
            unet_name="flux1-dev.safetensors",
            weight_dtype="fp8_e4m3fn"  # adjust if needed
        )
        if UNET_MODEL is None:
            raise ValueError("Falha ao carregar UNET model")

        # Move the models onto the GPU up front. For objects that expose a
        # `.patcher` attribute, that patcher is what gets loaded; other objects
        # (presumably already patcher-like, e.g. the UNET — confirm) are passed
        # as-is. STYLE_MODEL is not part of this list.
        logger.info("Carregando modelos na GPU...")
        model_loaders = [CLIP_MODEL, VAE_MODEL, CLIP_VISION, UNET_MODEL]
        model_management.load_models_gpu([
            loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0]
            for loader in model_loaders
        ])
        logger.info("Modelos carregados com sucesso")
except Exception as e:
    logger.error(f"Erro ao inicializar modelos: {str(e)}")
    raise

# 10. Generation function
@spaces.GPU
def generate_image(
    prompt, input_image, lora_weight, guidance, downsampling_factor,
    weight, seed, width, height, batch_size, steps,
    progress=gr.Progress(track_tqdm=True)
):
    """Run the FLUX Redux pipeline for one request and return the saved image path.

    Pipeline: CLIPTextEncode -> FluxGuidance -> ReduxAdvanced (conditioned on
    the reference image) -> EmptyLatentImage -> KSampler -> VAEDecode ->
    SaveImage, using the models loaded at import time.

    Args:
        prompt: Text prompt encoded with the dual CLIP model.
        input_image: Filesystem path of the reference image (the Gradio
            component uses ``type="filepath"``). Must not be empty.
        lora_weight: Accepted for UI/signature compatibility but currently
            unused; no LoRA is loaded anywhere in this script.
        guidance: Value passed to the FluxGuidance node.
        downsampling_factor: Passed to ReduxAdvanced.
        weight: Style-model weight passed to ReduxAdvanced.
        seed, width, height, batch_size, steps: Sampler / latent settings,
            coerced to ``int`` before being handed to the ComfyUI nodes.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm
            progress bars emitted during sampling into the UI.

    Returns:
        Path (under ``output_dir``) of the first image written by SaveImage.

    Raises:
        gr.Error: if no input image was given, or if any pipeline step fails.
            Gradio shows the message to the user instead of an empty output.
    """
    # Fail fast with a readable message instead of a LoadImage traceback.
    if not input_image:
        raise gr.Error("Please upload an input image.")

    try:
        # The ComfyUI nodes expect plain ints. gr.Number may deliver a float,
        # or None when the field is cleared. int() normalizes the former and
        # rejects the latter, and the rejection is reported through the
        # handler below.
        seed = int(seed)
        width = int(width)
        height = int(height)
        batch_size = int(batch_size)
        steps = int(steps)

        with torch.no_grad():
            logger.info(f"Iniciando geração com prompt: {prompt}")

            # Encode the text prompt.
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=CLIP_MODEL[0]
            )

            # Load the reference image.
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)
            if loaded_image is None:
                raise ValueError("Erro ao carregar a imagem de entrada")
            logger.info("Imagem carregada com sucesso")

            # Flux guidance applied to the text conditioning.
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            flux_guidance = fluxguidance.append(
                guidance=guidance,
                conditioning=encoded_text[0]
            )

            # Redux: merge the reference image's style into the conditioning.
            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
            redux_result = reduxadvanced.apply_stylemodel(
                downsampling_factor=downsampling_factor,
                downsampling_function="area",
                mode="keep aspect ratio",
                weight=weight,
                conditioning=flux_guidance[0],
                style_model=STYLE_MODEL[0],
                clip_vision=CLIP_VISION[0],
                image=loaded_image[0]
            )

            # Empty starting latent of the requested size.
            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
            empty_latent = emptylatentimage.generate(
                width=width,
                height=height,
                batch_size=batch_size
            )

            # Sampling. With cfg=1 the negative conditioning has no effect, so
            # the plain guidance conditioning is simply reused there.
            logger.info("Iniciando sampling...")
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            sampled = ksampler.sample(
                seed=seed,
                steps=steps,
                cfg=1,
                sampler_name="euler",
                scheduler="simple",
                denoise=1,
                model=UNET_MODEL[0],
                positive=redux_result[0],
                negative=flux_guidance[0],
                latent_image=empty_latent[0]
            )

            # Decode latents to pixels.
            logger.info("Decodificando imagem...")
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=sampled[0],
                vae=VAE_MODEL[0]
            )

            # Save via the SaveImage node. SaveImage writes into the output
            # directory configured at startup.
            logger.info("Salvando imagem via node SaveImage...")
            saveimage_node = NODE_CLASS_MAPPINGS["SaveImage"]()
            result_dict = saveimage_node.save_images(
                filename_prefix="Flux_",
                images=decoded[0]
            )
            # Each saved entry reports its filename *and* subfolder relative to
            # the output directory. Join both so the path stays correct if the
            # prefix ever contains a directory component.
            first_image = result_dict["ui"]["images"][0]
            saved_path = os.path.join(
                output_dir, first_image.get("subfolder", ""), first_image["filename"]
            )
            logger.info(f"Imagem salva em: {saved_path}")
            return saved_path

    except Exception as e:
        # Keep the full traceback in the logs. Also surface the failure in the
        # UI rather than silently returning an empty image.
        logger.exception(f"Erro ao gerar imagem: {str(e)}")
        raise gr.Error(f"Image generation failed: {e}") from e

# 11. Gradio interface

# Upper bound for random seeds. Gradio's Number component round-trips values
# through JavaScript numbers, which are exact only up to 2**53 - 1. Larger
# values get rounded and can even land on 2**64, which is outside the 64-bit
# seed range torch/ComfyUI accept.
MAX_SEED = 2**53 - 1

with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath"
            )

            with gr.Row():
                with gr.Column():
                    # NOTE: generate_image() currently ignores this value;
                    # no LoRA is loaded by this app.
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight"
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance"
                    )
                    downsampling_factor = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor"
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Model Weight"
                    )
                with gr.Column():
                    seed = gr.Number(
                        # Callable default: Gradio re-evaluates it on every page
                        # load, so each visitor starts with a fresh random seed
                        # within the exactly-representable range.
                        value=lambda: random.randint(1, MAX_SEED),
                        label="Seed",
                        precision=0
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0
                    )
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0
                    )

            generate_btn = gr.Button("Generate Image")

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="filepath")

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            input_image,
            lora_weight,
            guidance,
            downsampling_factor,
            weight,
            seed,
            width,
            height,
            batch_size,
            steps
        ],
        outputs=[output_image]
    )

if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    app.launch()