import gradio as gr
import numpy as np
import random
import os
import spaces #[uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, AutoencoderTiny, FluxPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import sys
sys.path.append('.')
from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV



# Model configurations
# Concept names offered in the "SliderSpace trained on" dropdown for each base model.
# For SDXL, each name is also a folder under sliderspace_weights/ (see infer()).
SDXL_CONCEPTS = [
    "alien",
    "ancient ruins",
    "animal",
    "bike",
    "car",
    "Citadel",
    "coral",
    "cowboy",
    "face",
    "futuristic cities",
    "monster",
    "mystical creature",
    "planet",
    "plant",
    "robot",
    "sculpture",
    "spaceship",
    "statue",
    "studio",
    "video game",
    "wizard",
]

FLUX_CONCEPTS = [
    "alien", "ancient ruins", "animal", "bike", "car", "Citadel",
    "face", "futuristic cities", "mystical creature", "planet",
    "plant", "robot", "spaceship", "statue", "studio", "video game",
    "wizard",
]




# Base SDXL repository (used for the UNet config and the rest of the pipeline) and the
# DMD2 UNet checkpoint that replaces the base UNet weights. The checkpoint name suggests a
# 4-step distilled model, which matches the default of 4 inference steps in the UI.
model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "tianweiy/DMD2"
ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"


# Run on GPU in bfloat16 when CUDA is available, otherwise on CPU in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32

# Load model.
# Build the UNet from the SDXL UNet config, then load the DMD2 weights downloaded from the Hub.
unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet").to(device, torch_dtype)
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name)))
# Assemble the SDXL pipeline around that UNet, switch to the LCM scheduler, and swap in
# the TAESD-XL tiny autoencoder (presumably for faster decoding).
pipe = DiffusionPipeline.from_pretrained(model_repo_id, unet=unet, torch_dtype=torch_dtype).to(device)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch_dtype).to(device)

# Cast every pipeline component to the chosen dtype and keep a handle on the pipeline's UNet,
# which the LoRA networks below attach to.
pipe = pipe.to(torch_dtype)
unet = pipe.unet

## Change these parameters based on how you trained your sliderspace sliders
train_method = 'xattn-strict'
rank = 1 
alpha =1 
# One LoRA network (key 0) on the UNet; infer() loads each slider's state_dict into it.
networks = {}
# NOTE(review): `modules` aliases DEFAULT_TARGET_REPLACE, so `+=` extends that imported
# object in place (if it is a list) instead of building a new list, and `modules` itself
# is not used afterwards. Whether LoRANetwork reads the mutated DEFAULT_TARGET_REPLACE
# depends on utils.lora. Confirm before changing this, because a different set of target
# modules could make the saved slider state_dicts fail to load.
modules = DEFAULT_TARGET_REPLACE
modules += UNET_TARGET_REPLACE_MODULE_CONV
for i in range(1):
    networks[i] = LoRANetwork(
        unet,
        rank=int(rank),
        multiplier=1.0,
        alpha=int(alpha),
        train_method=train_method,
        fast_init=True,
    ).to(device, dtype=torch_dtype)



# Upper bound for seeds in the UI and for random seeds: the largest signed 32-bit integer.
MAX_SEED = 2**31 - 1
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 1024


# base_model_id = "black-forest-labs/FLUX.1-schnell"
# max_sequence_length = 256
# flux_pipe = FluxPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
# flux_pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch_dtype)
# flux_pipe = flux_pipe.to(device).to(torch_dtype)
# # pipe.enable_sequential_cpu_offload()
# transformer = flux_pipe.transformer

# ## Change these parameters based on how you trained your sliderspace sliders
# train_method = 'flux-attn'
# rank = 1 
# alpha =1 

# flux_networks = {}
# modules = DEFAULT_TARGET_REPLACE
# modules += UNET_TARGET_REPLACE_MODULE_CONV
# for i in range(1):
#     flux_networks[i] = LoRANetwork(
#         transformer,
#         rank=int(rank),
#         multiplier=1.0,
#         alpha=int(alpha),
#         train_method=train_method,
#         fast_init=True,
#     ).to(device, dtype=torch_dtype)


def update_sliderspace_choices(model_choice):
    """Rebuild the concept dropdown for the selected base model.

    Args:
        model_choice: value of the "Model" dropdown. "SDXL-DMD" selects
            SDXL_CONCEPTS; any other value selects FLUX_CONCEPTS.

    Returns:
        A gr.Dropdown listing the matching concepts, with the first one selected.
    """
    concepts = SDXL_CONCEPTS if model_choice == "SDXL-DMD" else FLUX_CONCEPTS
    # Use the same label as the dropdown created in the UI (and named in the user
    # guide) so that switching models does not rename the control.
    return gr.Dropdown(
        choices=concepts,
        label="SliderSpace trained on",
        value=concepts[0],
    )

def _load_slider_direction(nets, weights_dir, slider_space, discovered_directions, slider_scale):
    """Load one discovered SliderSpace direction into every network in ``nets``.

    ``discovered_directions`` is the dropdown label, e.g. "Slider 3". Its trailing
    number is 1-based while the files on disk are 0-based, so "Slider 3" reads
    ``{weights_dir}/{slider_space}/slider_2.pt``.
    """
    index = int(discovered_directions.split(' ')[-1]) - 1
    sliderspace_path = f"{weights_dir}/{slider_space}/slider_{index}.pt"
    # map_location keeps the CPU code path working if the weights were saved from a GPU.
    state_dict = torch.load(sliderspace_path, map_location=device)
    for net in nets.values():
        net.load_state_dict(state_dict)
        # NOTE(review): the UI scale is negated before being applied; confirm this sign
        # convention against how the sliders were trained.
        net.set_lora_slider(-1 * slider_scale)
    # Enter and exit the first network once before generating, as this app always has.
    # NOTE(review): presumably this (re)applies the freshly loaded weights/scale; confirm
    # against LoRANetwork.__enter__/__exit__ in utils.lora.
    with nets[0]:
        pass


def _generate_pair(pipeline, net, seed, **pipe_kwargs):
    """Generate the base image and the slider-edited image from the same seed.

    Each call gets a fresh CPU ``torch.Generator`` seeded with ``seed``. The only
    difference between the two runs is that the second executes inside ``net``'s
    context manager.

    Returns:
        (image without the slider, image with the slider)
    """
    generator = torch.Generator().manual_seed(seed)
    image = pipeline(generator=generator, **pipe_kwargs).images[0]

    generator = torch.Generator().manual_seed(seed)
    with net:
        slider_image = pipeline(generator=generator, **pipe_kwargs).images[0]
    return image, slider_image


@spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    slider_space,
    discovered_directions,
    slider_scale,
    model_choice,
    progress=gr.Progress(track_tqdm=True),
):
    """Render ``prompt`` with and without the selected SliderSpace direction.

    Args:
        prompt, negative_prompt: text passed to the diffusion pipeline.
        seed: seed for both generations; replaced by a random one if ``randomize_seed``.
        width, height, guidance_scale, num_inference_steps: pipeline sampling settings.
        slider_space: concept name, used as the folder under ``sliderspace_weights/``.
        discovered_directions: dropdown label such as "Slider 3".
        slider_scale: strength of the direction (negated before being applied).
        model_choice: base model; only "SDXL-DMD" is currently supported.
        progress: Gradio progress tracker (tracks tqdm bars inside the pipeline).

    Returns:
        (original image, edited image, seed actually used)

    Raises:
        gr.Error: if ``model_choice`` is not "SDXL-DMD".
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # torch.Generator.manual_seed only accepts ints; guard against a float slider value.
    seed = int(seed)

    if model_choice != 'SDXL-DMD':
        # The FLUX pipeline (flux_pipe / flux_networks) is commented out at module level.
        # To re-enable it, restore those objects and replace this raise with:
        #   _load_slider_direction(flux_networks, "flux_sliderspace_weights", slider_space,
        #                          discovered_directions, slider_scale)
        #   image, slider_image = _generate_pair(
        #       flux_pipe, flux_networks[0], seed, prompt=prompt,
        #       guidance_scale=guidance_scale, num_inference_steps=num_inference_steps,
        #       width=width, height=height, max_sequence_length=256)
        raise gr.Error(f"Unsupported model: {model_choice}")

    _load_slider_direction(
        networks, "sliderspace_weights", slider_space, discovered_directions, slider_scale
    )
    image, slider_image = _generate_pair(
        pipe,
        networks[0],
        seed,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
    )
    return image, slider_image, seed


# Example prompts (currently unused: the gr.Examples call in the UI is commented out).
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Custom CSS passed to gr.Blocks: centre the main column at up to 800px and let the page
# scroll. The #col-container rule is closed before the page-level rules so that
# body/html/.gradio-container are top-level selectors, not descendants of #col-container.
css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}

body, html {
    height: 100%;
    overflow-y: auto;
}

.gradio-container {
    overflow-y: auto;
    max-height: 100vh;
}
"""

ORIGINAL_SPACE_ID = 'baulab/SliderSpace'
SPACE_ID = os.getenv('SPACE_ID')

SHARED_UI_WARNING = f'''## You can duplicate and use it with a gpu with at least 24GB, or clone this repository to run on your own machine.
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
'''
# Simple instructions (original)
# Markdown user guide rendered under the images in the UI; toggle_instructions() also returns it.
simple_instructions = """ User Guide: We released sliderspace directions for a few concepts. To test them: <br> 
- Choose the concept from the `SliderSpace trained on` menu. (eg. "Wizard") 
- Select one of the many directions we discovered for that concept from the `Discovered Directions` menu. (eg. "Direction 3") 
- Choose a slider weight between -1 to 1 (recommended). But feel free to explore the scale (eg. +1.5) 
- Finally, write a prompt for the concept you chose. (eg. "picture of a wizard fighting a magical war") 
- Click Generate and discover what the direction is !! """

# Toggle between the simple and detailed user guide (not wired to any UI event in this file).
def toggle_instructions(is_detailed):
    """Flip between the simple and detailed instructions.

    Args:
        is_detailed: whether the detailed guide is currently shown.

    Returns:
        (markdown text to show, label for the toggle button, new ``is_detailed`` state)
    """
    is_detailed = not is_detailed
    if is_detailed:
        # `detailed_instructions` is never assigned in this file. Fall back to the simple
        # guide instead of raising NameError when no detailed text has been provided.
        detailed = globals().get("detailed_instructions", simple_instructions)
        return detailed, "Show Simple Instructions", is_detailed
    return simple_instructions, "Show Detailed Instructions", is_detailed

# Gradio UI: a prompt box with a Run button; model / concept / direction / scale controls;
# the original and edited images side by side; collapsed advanced settings; and the user
# guide. Clicking Run (or submitting the prompt) calls infer().
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # SliderSpace: Decomposing Visual Capabilities of Diffusion Models")
        # Adding links under the title
        gr.Markdown("""
        🔗 [Project Page](https://sliderspace.baulab.info) | 
        💻 [GitHub Code](https://github.com/rohitgandikota/sliderspace)
        """)

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        # Base-model selector. Only SDXL-DMD is listed because the FLUX pipeline is
        # disabled in this file.
        model_choice = gr.Dropdown(
            choices=["SDXL-DMD"],
            label="Model",
            value="SDXL-DMD"
        )
        # Concept, direction and scale pickers side by side. infer() maps "Slider N"
        # to slider_{N-1}.pt inside the chosen concept's weights folder.
        with gr.Row():
            slider_space = gr.Dropdown(
                choices=SDXL_CONCEPTS,
                label="SliderSpace trained on",
                value=SDXL_CONCEPTS[0]
            )
            discovered_directions = gr.Dropdown(
                choices=[f"Slider {i}" for i in range(1, 11)],
                label="Discovered Directions",
                value="Slider 1"
            )

            slider_scale = gr.Slider(
                label="Slider Scale",
                minimum=-4,
                maximum=4,
                step=0.1,
                value=1,
            )

        with gr.Row():
            result = gr.Image(label="Original Image", show_label=True)
            slider_result = gr.Image(label="Discovered Edit Direction", show_label=True)

        with gr.Accordion("Advanced Settings", open=False):
            # Hidden input; its value is still passed to infer().
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # Replace with defaults that work for your model
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # Replace with defaults that work for your model
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=0.0,  # Replace with defaults that work for your model
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4,  # Replace with defaults that work for your model
                )
        with gr.Row():
            gr.Markdown(simple_instructions)
    # Add event handler for model selection
    model_choice.change(
        fn=update_sliderspace_choices,
        inputs=[model_choice],
        outputs=[slider_space]
    )
    # gr.Examples(examples=examples, inputs=[prompt])
    # The seed output lets the UI show the seed actually used when randomize_seed is on.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            slider_space,
            discovered_directions,
            slider_scale,
            model_choice
        ],
        outputs=[result, slider_result, seed],
    )

def _main():
    """Launch the Gradio app built above as ``demo``, with ``share=True``."""
    demo.launch(share=True)


if __name__ == "__main__":
    _main()

















# import gradio as gr
# import numpy as np
# import random
# import os
# import spaces #[uncomment to use ZeroGPU]
# from diffusers import DiffusionPipeline
# import torch
# from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
# from huggingface_hub import hf_hub_download
# from safetensors.torch import load_file
# import sys
# sys.path.append('.')
# from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV

# model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
# repo_name = "tianweiy/DMD2"
# ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"


# device = "cuda" if torch.cuda.is_available() else "cpu"
# if torch.cuda.is_available():
#     torch_dtype = torch.bfloat16
# else:
#     torch_dtype = torch.float32

# # Load model.
# unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet").to(device, torch_dtype)
# unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name)))
# pipe = DiffusionPipeline.from_pretrained(model_repo_id, unet=unet, torch_dtype=torch_dtype).to(device)
# pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)


# unet = pipe.unet

# ## Change these parameters based on how you trained your sliderspace sliders
# train_method = 'xattn-strict'
# rank = 1 
# alpha =1 
# networks = {}
# modules = DEFAULT_TARGET_REPLACE
# modules += UNET_TARGET_REPLACE_MODULE_CONV
# for i in range(1):
#     networks[i] = LoRANetwork(
#         unet,
#         rank=int(rank),
#         multiplier=1.0,
#         alpha=int(alpha),
#         train_method=train_method,
#         fast_init=True,
#     ).to(device, dtype=torch_dtype)



# MAX_SEED = np.iinfo(np.int32).max
# MAX_IMAGE_SIZE = 1024


# @spaces.GPU #[uncomment to use ZeroGPU]
# def infer(
#     prompt,
#     negative_prompt,
#     seed,
#     randomize_seed,
#     width,
#     height,
#     guidance_scale,
#     num_inference_steps,
#     slider_space,
#     discovered_directions,
#     slider_scale,
#     progress=gr.Progress(track_tqdm=True),
# ):
#     if randomize_seed:
#         seed = random.randint(0, MAX_SEED)

#     sliderspace_path = f"sliderspace_weights/{slider_space}/slider_{int(discovered_directions.split(' ')[-1])-1}.pt"
    
#     for net in networks:
#         networks[net].load_state_dict(torch.load(sliderspace_path))

#     for net in networks:
#         networks[net].set_lora_slider(slider_scale)

#     with networks[0]:
#         pass
    
#     # original image
#     generator = torch.Generator().manual_seed(seed)
#     image = pipe(
#             prompt=prompt,
#             negative_prompt=negative_prompt,
#             guidance_scale=guidance_scale,
#             num_inference_steps=num_inference_steps,
#             width=width,
#             height=height,
#             generator=generator,
#         ).images[0]

#     # edited image
#     generator = torch.Generator().manual_seed(seed)
#     with  networks[0]:
#         slider_image = pipe(
#             prompt=prompt,
#             negative_prompt=negative_prompt,
#             guidance_scale=guidance_scale,
#             num_inference_steps=num_inference_steps,
#             width=width,
#             height=height,
#             generator=generator,
#         ).images[0]

    
#     return image, slider_image, seed


# examples = [
#     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
#     "An astronaut riding a green horse",
#     "A delicious ceviche cheesecake slice",
# ]

# css = """
# #col-container {
#     margin: 0 auto;
#     max-width: 640px;
# }
# """

# ORIGINAL_SPACE_ID = 'baulab/SliderSpace'
# SPACE_ID = os.getenv('SPACE_ID')

# SHARED_UI_WARNING = f'''## You can duplicate and use it with a gpu with at least 24GB, or clone this repository to run on your own machine.
# <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
# '''

# with gr.Blocks(css=css) as demo:
#     with gr.Column(elem_id="col-container"):
#         gr.Markdown(" # SliderSpace: Decomposing Visual Capabilities of Diffusion Models")
#         # Adding links under the title
#         gr.Markdown("""
#         🔗 [Project Page](https://sliderspace.baulab.info) | 
#         💻 [GitHub Code](https://github.com/rohitgandikota/sliderspace)
#         """)

#         with gr.Row():
#             prompt = gr.Text(
#                 label="Prompt",
#                 show_label=False,
#                 max_lines=1,
#                 placeholder="Enter your prompt",
#                 container=False,
#             )

#             run_button = gr.Button("Run", scale=0, variant="primary")


#         # New dropdowns side by side
#         with gr.Row():
#             slider_space = gr.Dropdown(
#                 choices= [
#                             "alien",
#                             "ancient ruins",
#                             "animal",
#                             "bike",
#                             "car",
#                             "Citadel",
#                             "coral",
#                             "cowboy",
#                             "face",
#                             "futuristic cities",
#                             "monster",
#                             "mystical creature",
#                             "planet",
#                             "plant",
#                             "robot",
#                             "sculpture",
#                             "spaceship",
#                             "statue",
#                             "studio",
#                             "video game",
#                             "wizard"
#                         ],
#                 label="SliderSpace",
#                 value="spaceship"
#             )
#             discovered_directions = gr.Dropdown(
#                 choices=[f"Slider {i}" for i in range(1, 11)],
#                 label="Discovered Directions",
#                 value="Slider 1"
#             )

#             slider_scale =  gr.Slider(
#                     label="Slider Scale",
#                     minimum=-4,
#                     maximum=4,
#                     step=0.1,
#                     value=1,  
#                 )
        
#         with gr.Row():
#             result = gr.Image(label="Original Image", show_label=True)
#             slider_result = gr.Image(label="Discovered Edit Direction", show_label=True)
        

#         with gr.Accordion("Advanced Settings", open=False):
#             negative_prompt = gr.Text(
#                 label="Negative prompt",
#                 max_lines=1,
#                 placeholder="Enter a negative prompt",
#                 visible=False,
#             )

#             seed = gr.Slider(
#                 label="Seed",
#                 minimum=0,
#                 maximum=MAX_SEED,
#                 step=1,
#                 value=0,
#             )

#             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

#             with gr.Row():
#                 width = gr.Slider(
#                     label="Width",
#                     minimum=256,
#                     maximum=MAX_IMAGE_SIZE,
#                     step=32,
#                     value=1024,  # Replace with defaults that work for your model
#                 )

#                 height = gr.Slider(
#                     label="Height",
#                     minimum=256,
#                     maximum=MAX_IMAGE_SIZE,
#                     step=32,
#                     value=1024,  # Replace with defaults that work for your model
#                 )

#             with gr.Row():
#                 guidance_scale = gr.Slider(
#                     label="Guidance scale",
#                     minimum=0.0,
#                     maximum=2.0,
#                     step=0.1,
#                     value=0.0,  # Replace with defaults that work for your model
#                 )

#                 num_inference_steps = gr.Slider(
#                     label="Number of inference steps",
#                     minimum=1,
#                     maximum=50,
#                     step=1,
#                     value=4,  # Replace with defaults that work for your model
#                 )

#         # gr.Examples(examples=examples, inputs=[prompt])
#     gr.on(
#         triggers=[run_button.click, prompt.submit],
#         fn=infer,
#         inputs=[
#             prompt,
#             negative_prompt,
#             seed,
#             randomize_seed,
#             width,
#             height,
#             guidance_scale,
#             num_inference_steps,
#             slider_space,
#             discovered_directions,
#             slider_scale
#         ],
#         outputs=[result, slider_result, seed],
#     )

# if __name__ == "__main__":
#     demo.launch(share=True)