import spaces
import gradio as gr
import subprocess
from PIL import Image,ImageEnhance,ImageFilter
import json
import numpy as np
from skimage.exposure  import match_histograms

import mp_box
'''
Face detection based on face landmark detection.
https://ai.google.dev/edge/mediapipe/solutions/vision/face_landmarker
from model card
https://storage.googleapis.com/mediapipe-assets/MediaPipe%20BlazeFace%20Model%20Card%20(Short%20Range).pdf
Licensed Apache License, Version 2.0
Trained on Google's dataset (see the model card for details).

It is not based on the Face Detector:
https://ai.google.dev/edge/mediapipe/solutions/vision/face_detector

Because this is part of a landmark-extraction program that needs control over the face edge, the landmarker is used here.
I don't know which approach is better; I have never compared them.
'''

def color_match(base_image, cropped_image):
    """Recolor ``cropped_image`` so its per-channel histograms match ``base_image``.

    Both images are converted with ``np.array`` and passed to skimage's
    ``match_histograms`` with the last axis treated as the channel axis.

    Args:
        base_image: image whose color distribution is used as the reference.
        cropped_image: image to be recolored.

    Returns:
        A new PIL image built from the histogram-matched array.
    """
    source_pixels = np.array(cropped_image)
    reference_pixels = np.array(base_image)
    matched_pixels = match_histograms(source_pixels, reference_pixels, channel_axis=-1)
    return Image.fromarray(matched_pixels)

def select_box(boxes, box_type, image=None):
    """Pick one of the detected face boxes and convert it to corner form.

    Args:
        boxes: sequence of at least three boxes (as returned by
            ``mp_box.mediapipe_to_box``); each is read as ``[x, y, w, h]`` --
            index 2 is used as the width and index 3 as the height.
        box_type: ``"type-1"``, ``"type-2"`` or ``"type-3"``, selecting
            ``boxes[0]``, ``boxes[1]`` or ``boxes[2]`` respectively.
        image: optional PIL image. When given and ``box_type`` is not one of
            the known types, the whole image is used as the box.

    Returns:
        ``(box, box_width, box_height)``: ``box`` is the selected box passed
        through ``mp_box.xywh_to_xyxy``; width/height come from the selected
        ``[x, y, w, h]`` box before conversion.

    Raises:
        ValueError: if ``box_type`` is unknown and no ``image`` was supplied.
    """
    type_to_index = {"type-1": 0, "type-2": 1, "type-3": 2}
    if box_type in type_to_index:
        box = boxes[type_to_index[box_type]]
    elif image is not None:
        # Fallback: treat the whole image as the face box.
        box = [0, 0, image.size[0], image.size[1]]
    else:
        # The previous fallback referenced a name that is not in scope here
        # (``image``); fail with a clear message instead.
        raise ValueError(f"unknown box_type: {box_type!r}")
    box_width = box[2]
    box_height = box[3]
    box = mp_box.xywh_to_xyxy(box)
    return box, box_width, box_height

def resize_image_in_box(image,box_width,box_height,keep_aspect=True,resampling=None):
    """Resize ``image`` to fit a ``box_width`` x ``box_height`` box.

    Args:
        image: PIL image (anything exposing ``width``, ``height`` and
            ``resize(size, resample)``).
        box_width: target box width in pixels.
        box_height: target box height in pixels.
        keep_aspect: if True (default), scale to the largest size that fits
            inside the box while keeping the image's aspect ratio; if False,
            stretch the image to exactly fill the box.
        resampling: PIL resampling filter. ``None`` picks BICUBIC when the box
            area is larger than the image area (upscaling) and LANCZOS
            otherwise.

    Returns:
        ``(resized, offset_x, offset_y)``: the resized image and the offsets
        that center it inside the box.
    """
    if keep_aspect:
        aspect_ratio = image.width / image.height
        if box_width / box_height >= aspect_ratio:
            new_width = int(box_height * aspect_ratio)
            new_height = box_height
        else:
            new_width = box_width
            new_height = int(box_width / aspect_ratio)
    else:
        new_width, new_height = box_width, box_height

    # int() truncation can produce 0 for extreme aspect ratios, and PIL
    # refuses to resize to an empty size.
    new_width = max(1, int(new_width))
    new_height = max(1, int(new_height))

    if resampling is None:  # automatic choice
        if box_width * box_height > image.width * image.height:
            resampling = Image.Resampling.BICUBIC
        else:
            resampling = Image.Resampling.LANCZOS

    resized = image.resize((new_width, new_height), resampling)
    offset_x = int((box_width - new_width) / 2)
    offset_y = int((box_height - new_height) / 2)
    return resized, offset_x, offset_y


def _expand_box(box, box_width, box_height, margin_percent, image_width, image_height):
    """Grow an ``[x0, y0, x1, y1]`` box by ``margin_percent`` and clamp it to the image.

    The horizontal margin is ``margin_percent`` % of ``box_width`` and the
    vertical margin is ``margin_percent`` % of ``box_height``; each is added on
    both sides. Returns ``(box, box_width, box_height)``. The input box is
    returned unchanged when ``margin_percent`` is not positive.
    """
    if margin_percent <= 0:
        return box, box_width, box_height
    h_margin = int(box_width * margin_percent / 100)
    v_margin = int(box_height * margin_percent / 100)
    # PIL crop boxes are (left, upper, right, lower) with right/lower
    # exclusive, so the far edges may reach image_width / image_height
    # (clamping to width-1 / height-1 would drop the last column / row).
    expanded = [
        max(0, box[0] - h_margin),
        max(0, box[1] - v_margin),
        min(image_width, box[2] + h_margin),
        min(image_height, box[3] + v_margin),
    ]
    return expanded, expanded[2] - expanded[0], expanded[3] - expanded[1]


def _parse_color(color):
    """Parse ``"rgba(r, g, b, a)"`` / ``"rgb(r, g, b)"`` / ``"#rgb"`` / ``"#rrggbb"`` into an ``(r, g, b)`` int tuple.

    Raises:
        ValueError: if the string cannot be parsed.
    """
    color = color.strip()
    if color.startswith("#"):
        hex_digits = color[1:]
        if len(hex_digits) in (3, 4):  # short form, e.g. "#fad"
            hex_digits = "".join(ch * 2 for ch in hex_digits)
        return tuple(int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
    parts = color.strip("rgba()").split(",")
    if len(parts) < 3:
        raise ValueError(f"cannot parse color {color!r}")
    return tuple(int(float(part)) for part in parts[:3])


def _replace_face(image, box, box_width, box_height, replace_image,
                  replace_image_need_crop, box_type, fill_color_mode,
                  margin_percent, match_color):
    """Paste ``replace_image`` (or the face cropped from it) over the face box of a copy of ``image``."""
    image_width, image_height = image.size
    if replace_image_need_crop:
        replace_boxes, _, _ = mp_box.mediapipe_to_box(replace_image)
        replace_box, _, _ = select_box(replace_boxes, box_type)

    if fill_color_mode:
        # Intended for images exported by the fill-color crop mode: no margin.
        if replace_image_need_crop:
            patch = replace_image.crop(replace_box)
            patch, off_x, off_y = resize_image_in_box(patch, box_width, box_height, keep_aspect=True)
        else:
            # Same box coordinates on the replace image, pasted unscaled.
            patch = replace_image.crop(box)
            off_x = int((box_width - patch.width) / 2)
            off_y = int((box_height - patch.height) / 2)
    else:  # scale mode
        box, box_width, box_height = _expand_box(
            box, box_width, box_height, margin_percent, image_width, image_height)
        source = replace_image.crop(replace_box) if replace_image_need_crop else replace_image
        patch, off_x, off_y = resize_image_in_box(source, box_width, box_height, keep_aspect=True)

    if match_color:
        patch = color_match(image.crop(box), patch)
    # Paste onto a copy so the caller's image object is not modified.
    result = image.copy()
    result.paste(patch, (box[0] + off_x, box[1] + off_y))
    return result


def _fill_outside_box(image, box, fill_color, custom_color):
    """Return an RGBA image of ``image.size`` in a solid color, with only the ``box`` region of ``image`` pasted back."""
    color_map = {
        "black": (0, 0, 0),
        "white": (255, 255, 255),
        "gray": (127, 127, 127),
        "red": (255, 0, 0),
        "brown": (92, 33, 31),
        "pink": (255, 192, 203),
    }
    if fill_color == "custom":
        try:
            rgb = _parse_color(custom_color)
        except ValueError as err:
            raise gr.Error(f"invalid custom color {custom_color}") from err
    elif fill_color in color_map:
        rgb = color_map[fill_color]
    else:
        raise gr.Error(f"fill color {fill_color} not found")

    canvas = Image.new('RGBA', image.size, rgb)
    canvas.paste(image.crop(box), (box[0], box[1]))
    return canvas


def _crop_and_scale(image, box, image_size, filter_value):
    """Crop ``box`` from ``image``, scale its longer side to ``image_size`` and apply an optional PIL filter."""
    filter_map = {
        "None": None,
        "Blur": ImageFilter.BLUR,
        "Smooth More": ImageFilter.SMOOTH_MORE,
        "Smooth": ImageFilter.SMOOTH,
        "Sharpen": ImageFilter.SHARPEN,
        "Edge Enhance": ImageFilter.EDGE_ENHANCE,
        "Edge Enhance More": ImageFilter.EDGE_ENHANCE_MORE,
    }
    # Validate before doing any image work.
    if filter_value not in filter_map:
        raise gr.Error(f"filter {filter_value} not found")

    resized = resize_image_by_max_dimension(image.crop(box), image_size)
    selected_filter = filter_map[filter_value]
    if selected_filter is not None:
        resized = resized.filter(selected_filter)
    return resized


def process_images(image,replace_image=None,replace_image_need_crop=False,box_type="type-3",fill_color_mode=False,fill_color="black",custom_color="rgba(255,255,255,1)",image_size=1024,margin_percent=0,filter_value="None",match_color=True,progress=gr.Progress(track_tqdm=True)):
    """Crop the detected face out of ``image``, or replace it with another face.

    Modes:
      * Replace mode (``replace_image`` given): ``replace_image`` (or, with
        ``replace_image_need_crop``, the face box detected in it) is fitted
        into the face box of ``image``, optionally color-matched, and pasted
        onto a copy of ``image``. With ``fill_color_mode`` no margin is
        applied and, without ``replace_image_need_crop``, the same box region
        of ``replace_image`` is pasted unscaled.
      * Crop + fill mode (``fill_color_mode``): an RGBA image the size of
        ``image`` filled with the chosen color, keeping only the
        (margin-expanded) face box.
      * Crop + scale mode: the (margin-expanded) face box, scaled so its
        longer side is ``image_size`` and optionally filtered.

    Args:
        image: base PIL image; required.
        replace_image: optional PIL image whose content replaces the face.
        replace_image_need_crop: detect and crop the face in ``replace_image``.
        box_type: which detected box to use (see ``select_box``).
        fill_color_mode: use the fill / no-resize variant of each mode.
        fill_color: one of the named colors or ``"custom"``.
        custom_color: color string used when ``fill_color == "custom"``.
        image_size: longer side of the output in crop + scale mode.
        margin_percent: extra space around the box, in percent of its size.
        filter_value: name of the PIL filter applied in crop + scale mode.
        match_color: histogram-match the pasted face to the base face (replace mode).
        progress: Gradio progress tracker (unused in the body).

    Returns:
        The resulting PIL image.

    Raises:
        gr.Error: if ``image`` is missing, or the fill color / custom color /
            filter name is invalid.
    """
    if image is None:
        raise gr.Error("Need Image")

    image_width, image_height = image.size
    boxes, _, _ = mp_box.mediapipe_to_box(image)
    box, box_width, box_height = select_box(boxes, box_type)

    if replace_image is not None:
        print("replace mode")
        return _replace_face(image, box, box_width, box_height, replace_image,
                             replace_image_need_crop, box_type, fill_color_mode,
                             margin_percent, match_color)

    box, box_width, box_height = _expand_box(
        box, box_width, box_height, margin_percent, image_width, image_height)

    if fill_color_mode:
        return _fill_outside_box(image, box, fill_color, custom_color)
    return _crop_and_scale(image, box, image_size, filter_value)


def resize_image_by_max_dimension(image, max_size, resampling=Image.Resampling.BICUBIC):
    """Scale ``image`` so that its longer side becomes ``max_size`` pixels.

    The aspect ratio is kept; the shorter side is truncated to an int.

    Args:
        image: PIL image to resize.
        max_size: target length of the longer side.
        resampling: PIL resampling filter (BICUBIC by default).

    Returns:
        The resized PIL image.
    """
    image_width, image_height = image.size
    max_dimension = max(image_width, image_height)

    # Multiply before dividing: computing ``max_size / max_dimension`` first
    # and then multiplying introduces float drift (e.g. 1024/49*49 ==
    # 1023.999...), which int() would truncate to max_size - 1.
    new_width = int(image_width * max_size / max_dimension)
    new_height = int(image_height * max_size / max_dimension)

    # Very thin crops can truncate to 0, which PIL cannot resize to.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    return image.resize((new_width, new_height), resampling)
    

def read_file(file_path: str) -> str:
    """Return the entire contents of ``file_path`` decoded as UTF-8 text."""
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()

# Stylesheet passed to gr.Blocks(css=css) when the UI is built.
# NOTE(review): no component in this file sets elem_id "col-left"/"col-right"
# or uses the .grid-container/.image/.text classes; they are presumably
# referenced by the demo_*.html snippets loaded via read_file -- verify.
css="""
#col-left {
    margin: 0 auto;
    max-width: 640px;
}
#col-right {
    margin: 0 auto;
    max-width: 640px;
}
.grid-container {
  display: flex;
  align-items: center;
  justify-content: center;
  gap:10px
}

.image {
  width: 128px; 
  height: 128px; 
  object-fit: cover; 
}

.text {
  font-size: 16px;
}
"""

def update_button_label(image):
    """Toggle the run buttons and option rows when the replace image changes.

    Used as the ``replace_image.change`` handler with outputs
    ``[btn1, btn2, row1, row2]``.

    Args:
        image: current value of the replace-image component; ``None`` when
            it has been cleared.

    Returns:
        Four component updates. With no replace image: "Face Crop" shown,
        "Face Replace" hidden, both option rows shown. With a replace image:
        "Face Crop" hidden, "Face Replace" shown, both option rows hidden.
    """
    if image is None:
        print("none replace")
        # NOTE(review): both rows are shown regardless of the fill-color-mode
        # checkbox, whereas update_visible shows only one of them.
        return gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=True), gr.Row(visible=True)
    return gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=False), gr.Row(visible=False)
    
def update_visible(fill_color_mode,image):
    """Show only the option row that applies to the current mode.

    Args:
        fill_color_mode: state of the "Fill Color Mode/No Resize" checkbox.
        image: current value of the replace-image component.

    Returns:
        Updates for ``(row1, row2)``: both hidden when a replace image is set;
        otherwise ``row2`` (fill-color options) in fill color mode and
        ``row1`` (image size / filter options) in scale mode.
    """
    if image is not None:
        return gr.Row(visible=False), gr.Row(visible=False)

    if fill_color_mode:
        return gr.Row(visible=False), gr.Row(visible=True)
    return gr.Row(visible=True), gr.Row(visible=False)
     
with gr.Blocks(css=css, elem_id="demo-container") as demo:
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            image = gr.Image(sources=['upload','clipboard'],image_mode='RGB',elem_id="image_upload", type="pil", label="Upload")
            box_type = gr.Dropdown(label="box-type",value="type-3",choices=["type-1","type-2","type-3"])
            with gr.Row(elem_id="prompt-container", equal_height=False):
                with gr.Row():
                    # Only one of these is visible at a time; update_button_label
                    # swaps them when a replace image is set or cleared.
                    btn1 = gr.Button("Face Crop", elem_id="run_button",variant="primary")
                    btn2 = gr.Button("Face Replace", elem_id="run_button2",variant="primary",visible=False)

            replace_image = gr.Image(sources=['upload','clipboard'],image_mode='RGB',elem_id="replace_upload", type="pil", label="replace image")
            replace_image_need_crop = gr.Checkbox(label="Replace image need crop",value=False)

            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row(equal_height=True):
                    fill_color_mode = gr.Checkbox(label="Fill Color Mode/No Resize",value=False)
                    match_color = gr.Checkbox(label="Match Color",value=True,info="skimage match_histograms")
                    margin_percent = gr.Slider(
                        label="Margin percent",
                        info="add extra space",
                        minimum=0,
                        maximum=200,
                        step=1,
                        value=0,
                        interactive=True,
                    )

                # row1 holds scale-mode options, row2 fill-color-mode options;
                # update_visible shows the one matching the checkbox.
                row1 = gr.Row(equal_height=True)
                row2 = gr.Row(equal_height=True,visible=False)
                fill_color_mode.change(update_visible,[fill_color_mode,replace_image],[row1,row2])
                with row1:
                    image_size = gr.Slider(
                        label="Image Size",
                        info="cropped face size",
                        minimum=8,
                        maximum=2048,
                        step=1,
                        value=1024,
                        interactive=True,
                    )
                    filter_value = gr.Dropdown(label="Filter",value="None",choices=["Blur","Smooth More","Smooth","None","Sharpen","Edge Enhance","Edge Enhance More"])
                with row2:
                    fill_color = gr.Dropdown(label="fill color",choices=["black","white","gray","red","brown","pink","custom"],value="gray")
                    custom_color = gr.ColorPicker(label="custom color",value="rgba(250, 218, 205, 1)")

            replace_image.change(update_button_label,replace_image,[btn1,btn2,row1,row2])
        # Right column: result.
        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img")

    gr.on(
        [btn1.click, btn2.click],
        fn=process_images,
        inputs=[image,replace_image,replace_image_need_crop,box_type,fill_color_mode,fill_color,custom_color,image_size,margin_percent,filter_value,match_color],
        outputs=[image_out],
        api_name='infer',
    )
    gr.Examples(
        examples=["examples/00004200.jpg","examples/00003245_00.jpg","examples/00005259.jpg","examples/00018022.jpg","examples/img-above.jpg","examples/img-below.jpg","examples/img-side.jpg"],
        inputs=[image],
    )
    gr.HTML(read_file("demo_footer.html"))

# Launch only after the Blocks context has been closed, so the app is
# fully assembled before the server starts.
if __name__ == "__main__":
    demo.launch()