import os
import gradio as gr
import numpy as np
import skimage
from skimage import io
import torch
import monai
from monai.transforms import Rotate

# Placeholder for the 3D reconstruction model
class Simple3DReconstructionModel:
    """Placeholder pipeline: 2D image -> 3D volume -> rotated volume -> 2D projection.

    Reconstruction is a stub that returns an all-zero volume; rotation uses
    MONAI's ``Rotate``; projection is a maximum-intensity projection.
    """

    def __init__(self):
        # Load your pre-trained model here
        self.model = None  # replace with actual model loading

    def reconstruct_3d(self, image):
        """Return a placeholder 3D volume for ``image``.

        The input is currently ignored: a zero-filled float array of shape
        (128, 128, 128) is returned until a real model is plugged in.
        """
        return np.zeros((128, 128, 128))

    @staticmethod
    def _angles_to_radians(angles, degrees=True):
        """Return ``angles`` as float(s), converted from degrees if ``degrees``.

        A scalar yields a single float; a sequence yields a tuple of floats.
        """
        values = np.asarray(angles, dtype=np.float64)
        if degrees:
            values = np.deg2rad(values)
        if values.ndim == 0:
            return float(values)
        return tuple(float(v) for v in values.reshape(-1))

    def rotate_3d(self, volume, angles, degrees=True):
        """Rotate a 3D volume about its three spatial axes using MONAI.

        Args:
            volume: array-like 3D volume *without* a channel axis, e.g. the
                (128, 128, 128) array from ``reconstruct_3d``.
            angles: one angle per spatial axis, e.g. ``(xt, yt, zt)``.
            degrees: if True (default), ``angles`` are in degrees, as produced
                by the UI sliders. MONAI's ``Rotate`` itself expects radians.

        Returns:
            The rotated volume as a numpy array without a channel axis.
            MONAI's ``Rotate`` keeps the spatial size by default.
        """
        radians = self._angles_to_radians(angles, degrees)
        # MONAI transforms expect channel-first data. Without this singleton
        # channel a (D0, D1, D2) volume would be treated as a 2D image with
        # D0 channels, and a 3-angle rotation would be rejected.
        channel_first = np.asarray(volume)[np.newaxis]
        rotate = Rotate(radians, mode='bilinear')
        rotated = rotate(channel_first)
        # Rotate returns a torch tensor (MetaTensor); convert back to numpy so
        # downstream numpy reductions behave predictably.
        if isinstance(rotated, torch.Tensor):
            rotated = rotated.detach().cpu().numpy()
        return np.asarray(rotated)[0]

    def project_2d(self, volume):
        """Collapse a 3D volume to 2D by maximum-intensity projection along axis 0."""
        return np.max(np.asarray(volume), axis=0)

# Single module-level model instance, shared by process_image() across all
# Gradio requests.
model: Simple3DReconstructionModel = Simple3DReconstructionModel()

# Gradio helper functions

def process_image(img, xt, yt, zt):
    """Run the full pipeline on one image: reconstruct, rotate, then project.

    Args:
        img: input image, forwarded unchanged to ``model.reconstruct_3d``.
        xt, yt, zt: rotation angles, forwarded together as the tuple
            ``(xt, yt, zt)`` to ``model.rotate_3d``.

    Returns:
        Whatever ``model.project_2d`` produces for the rotated volume.
    """
    pipeline_angles = (xt, yt, zt)
    reconstructed = model.reconstruct_3d(img)
    return model.project_2d(model.rotate_3d(reconstructed, pipeline_angles))

# Piecewise-linear anchor points of matplotlib's 'bone' colormap, given as
# (x positions, channel values) for each RGB channel. Implemented with numpy
# so colorizing needs no matplotlib dependency.
_BONE_CMAP_SEGMENTS = (
    ((0.0, 0.746032, 1.0), (0.0, 0.652778, 1.0)),  # red
    ((0.0, 0.365079, 0.746032, 1.0), (0.0, 0.319444, 0.777778, 1.0)),  # green
    ((0.0, 0.365079, 1.0), (0.0, 0.444444, 1.0)),  # blue
)


def _apply_bone_cmap(image):
    """Map a grayscale image to a uint8 RGB image using the 'bone' colormap.

    Float inputs are clipped to [0, 1]. Integer inputs are first divided by
    their dtype's maximum, so uint8 0..255 spans the whole colormap. NaN maps
    to 0. The result has the input's shape plus a trailing axis of size 3.
    """
    values = np.asarray(image)
    if np.issubdtype(values.dtype, np.integer):
        values = values / np.iinfo(values.dtype).max
    values = np.clip(np.nan_to_num(values.astype(np.float64), nan=0.0), 0.0, 1.0)
    rgb = np.stack(
        [np.interp(values, xs, ys) for xs, ys in _BONE_CMAP_SEGMENTS], axis=-1
    )
    return (rgb * 255).astype(np.uint8)


def rotate_btn_fn(img, xt, yt, zt, add_bone_cmap=False):
    """Gradio callback for the "Rotate!" button.

    Args:
        img: the input image as a numpy array (the input component is
            ``gr.Image(type='numpy')``) or a path to an existing image file.
        xt, yt, zt: rotation angles taken from the three sliders.
        add_bone_cmap: if True, colorize the result with the 'bone' colormap
            and return a uint8 RGB image.

    Returns:
        The rotated 2D projection, or None if any step fails. The error and
        its traceback are printed instead of raised, so the output image is
        simply left empty.
    """
    try:
        angles = (xt, yt, zt)
        print(f"Rotating with angles: {angles}")

        # numpy images are used as-is. Nothing is written to disk, so
        # concurrent requests cannot clobber a shared temp file.
        if isinstance(img, str) and os.path.exists(img):
            img = skimage.io.imread(img)
        elif not isinstance(img, np.ndarray):
            raise ValueError("Invalid input image")

        # Process the image with the model
        out_img = process_image(img, xt, yt, zt)

        if add_bone_cmap:
            out_img = _apply_bone_cmap(out_img)
        return out_img

    except Exception as e:
        import traceback

        print(f"Error in rotate_btn_fn: {e}")
        traceback.print_exc()
        return None

css_style = "./style.css"
# NOTE(review): `callback` is not referenced anywhere else in this file;
# presumably left over from a flagging setup. Kept so nothing else breaks.
callback = gr.CSVLogger()

# Example X-ray images: files in ./data/examples whose names contain "xr".
# Sorted because os.listdir returns entries in arbitrary order, which would
# otherwise make the examples gallery order vary between runs/platforms.
examples_dir = "./data/examples"
xray_examples = sorted(
    os.path.join(examples_dir, f) for f in os.listdir(examples_dir) if "xr" in f
)

with gr.Blocks(css=css_style, title="RadRotator") as app:
    gr.HTML("RadRotator: 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
    gr.HTML("Developed by:<br>Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles<br><a href='https://pouriarouzrokh.github.io/RadRotator'>[Our website]</a>, <a href='https://arxiv.org/abs/2404.13000'>[arXiv Paper]</a>", elem_classes="note")
    gr.HTML("Note: The demo operates on a CPU, and since diffusion models require more computational capacity to function, all predictions are precomputed.", elem_classes="note")

    with gr.TabItem("Demo"):
        with gr.Row():
            input_img = gr.Image(type='numpy', label='Input image', interactive=True, elem_classes='imgs')
            output_img = gr.Image(type='numpy', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            # Gradio's `scale` is an integer (fractional values trigger a
            # warning). 1:4:1 keeps the intended 0.25:1:0.25 proportions, with
            # the empty side columns centering the examples.
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=4):
                gr.Examples(
                    examples=xray_examples,
                    inputs=[input_img],
                    label="Xray Examples",
                    elem_id='examples',
                )
            with gr.Column(scale=1):
                pass
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='x axis (medial/lateral rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                yt = gr.Slider(label='y axis (inlet/outlet rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                zt = gr.Slider(label='z axis (plane rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
        
# Best-effort cleanup before launching: gr.close_all() shuts down Gradio
# servers still running in this process (presumably from an earlier run,
# e.g. a re-executed notebook cell). Failures are printed, not raised.
try:
    app.close()
    gr.close_all()
except Exception as e:
    print(f"Error closing app: {e}")

# Launch settings: at most 4 worker threads, a public share link, no inline
# notebook rendering, API docs hidden, and errors not surfaced in the UI.
_launch_kwargs = {
    "max_threads": 4,
    "share": True,
    "inline": False,
    "show_api": False,
    "show_error": False,
}
demo = app.launch(**_launch_kwargs)