################################################################################
# This file contains OSAIL utilities to read and write files.
################################################################################

import copy
import monai as mn
import numpy as np
import os
import skimage

################################################################################
# -F: pad_to_square

def pad_to_square(image):
    """A function to pad an image to square shape with zero pixels.

    The shorter of the two leading (spatial) axes is padded symmetrically with
    zeros. When the size difference is odd, the extra zero row/column ends up
    at the bottom/right. An image that is already square is returned as-is
    (the very same object).

    Args:
        image (np.ndarray): the input image array of shape ``(H, W)`` or
            ``(H, W, ...)``; any trailing (e.g. channel) dimensions are kept.

    Returns:
        np.ndarray: the padded image array of shape ``(S, S, ...)`` with
        ``S = max(H, W)`` and the same dtype as ``image``.
    """
    height, width = image.shape[:2]
    if height == width:
        return image

    side = max(height, width)
    top = (side - height) // 2
    left = (side - width) // 2
    # zeros_like keeps the input dtype; np.zeros would silently upcast to float64.
    padded_image = np.zeros_like(image, shape=(side, side) + tuple(image.shape[2:]))
    padded_image[top:top + height, left:left + width] = image
    return padded_image

################################################################################
# -F: load_image

def load_image(input_object, pad=False, normalize=True, standardize=False,
               dtype=np.float32, percentile_clip=None, target_shape=None,
               transpose=False, ensure_grayscale=True, LoadImage_args=None,
               LoadImage_kwargs=None):
    """A helper function to load different input types.

    The processing steps run in this order: load -> grayscale -> transpose ->
    percentile clip -> standardize -> normalize -> pad -> resize -> cast.

    Args:
        input_object (Union[np.ndarray, str, os.PathLike]):
            a 2D NumPy array of an X-ray image, or a path to an image file on
            disk (e.g. a DICOM file, a .npy file, or a regular image format)
            readable by ``mn.transforms.LoadImage``.
        pad (bool, optional): whether to zero-pad the image to square shape.
            Defaults to False.
        normalize (bool, optional): whether to min-max normalize the image
            to [0, 1]. Defaults to True.
        standardize (bool, optional): whether to standardize the image to
            zero mean and unit standard deviation. Defaults to False.
        dtype (np.dtype, optional): the data type of the output image. If it
            is uint8, the image is scaled by 255 (so it should be in [0, 1],
            e.g. via ``normalize``) and clipped to [0, 255] before casting.
            Defaults to np.float32.
        percentile_clip (float, optional): the percentile cut from each tail:
            values are clipped to [P(percentile_clip), P(100 - percentile_clip)].
            Defaults to None, which means no clipping.
        target_shape (tuple, optional): the target shape of the output image.
            Defaults to None, which means no resizing.
        transpose (bool, optional): whether to swap the two image axes.
            Defaults to False.
        ensure_grayscale (bool, optional): whether to average RGB(A) channels
            (first or last axis of size 3 or 4) into a single 2D image.
            Defaults to True.
        LoadImage_args (Sequence, optional): positional arguments to pass to
            ``mn.transforms.LoadImage``. Defaults to None (no arguments).
        LoadImage_kwargs (dict, optional): keyword arguments to pass to
            ``mn.transforms.LoadImage`` (``image_only`` is always set to True).
            Defaults to None (no arguments).

    Returns:
        the loaded image array.

    Raises:
        TypeError: if ``input_object`` is neither an array nor a path.
        AssertionError: if the file does not exist, or the image is not 2D
            after grayscale conversion.
    """
    # Accept pathlib.Path and other path-like objects as well as plain strings.
    if isinstance(input_object, os.PathLike):
        input_object = os.fspath(input_object)

    # Load the image.
    if isinstance(input_object, np.ndarray):
        image = input_object
    elif isinstance(input_object, str):
        assert os.path.exists(input_object), f"File not found: {input_object}"
        reader = mn.transforms.LoadImage(
            *(LoadImage_args or ()), image_only=True, **(LoadImage_kwargs or {})
        )
        image = reader(input_object)
    else:
        # Fail loudly instead of continuing with an undefined `image`.
        raise TypeError(
            "input_object must be a NumPy array or a path (str / os.PathLike), "
            f"got {type(input_object).__name__}"
        )

    # Make the image 2D (the branch order decides ambiguous shapes).
    if ensure_grayscale:
        if image.shape[-1] == 3:
            image = np.mean(image, axis=-1)
        elif image.shape[0] == 3:
            image = np.mean(image, axis=0)
        elif image.shape[-1] == 4:
            # Drop the alpha channel before averaging.
            image = np.mean(image[..., :3], axis=-1)
        elif image.shape[0] == 4:
            image = np.mean(image[:3, ...], axis=0)
        assert len(image.shape) == 2, f"Image must be 2D: {image.shape}"

    # Transpose the image.
    if transpose:
        image = np.transpose(image, axes=(1, 0))

    # Clip both intensity tails at the requested percentile.
    if percentile_clip is not None:
        percentile_low = np.percentile(image, percentile_clip)
        percentile_high = np.percentile(image, 100 - percentile_clip)
        image = np.clip(image, percentile_low, percentile_high)

    # Standardize the image (astype copies, so the input is never mutated).
    if standardize:
        image = image.astype(np.float32)
        image -= image.mean()
        image /= (image.std() + 1e-8)

    # Min-max normalize the image to [0, 1]; the epsilon avoids division by zero.
    if normalize:
        image = image.astype(np.float32)
        image -= image.min()
        image /= (image.max() + 1e-8)

    # Pad the image to square shape.
    if pad:
        image = pad_to_square(image)

    # Resize the image.
    if target_shape is not None:
        # Import the submodule explicitly: a bare `import skimage` may not load
        # `skimage.transform` on older scikit-image releases.
        from skimage.transform import resize
        image = resize(image, target_shape, preserve_range=True)

    # Cast the image to the target data type. Compare dtypes by value so that
    # np.uint8, np.dtype("uint8") and "uint8" are all recognized.
    try:
        to_uint8 = np.dtype(dtype) == np.uint8
    except TypeError:
        # Dtypes NumPy cannot interpret are passed straight to astype.
        to_uint8 = False
    if to_uint8:
        # Clip so out-of-range values saturate instead of wrapping around.
        image = (image * 255).clip(0, 255).astype(np.uint8)
    else:
        image = image.astype(dtype)

    return image

################################################################################
# -C: LoadImageD

class LoadImageD(mn.transforms.Transform):
    """A MONAI transform to load input image(s) of a dictionary using load_image.

    Args:
        keys (Union[str, Sequence[str]]): the key(s) of the input dictionary
            whose values are passed to ``load_image``. A single string is
            treated as one key.
        *to_pass_keys: extra positional arguments forwarded to ``load_image``
            after the input object (i.e. ``pad``, ``normalize``, ...). Despite
            the name, these are argument values, not dictionary keys.
        **to_pass_kwargs: keyword arguments forwarded to ``load_image``.
    """
    def __init__(self, keys, *to_pass_keys, **to_pass_kwargs) -> None:
        super().__init__()
        # A bare string would otherwise be iterated character by character.
        self.keys = (keys,) if isinstance(keys, str) else keys
        self.to_pass_keys = to_pass_keys
        self.to_pass_kwargs = to_pass_kwargs

    def __call__(self, data):
        """Return a deep copy of ``data`` with every ``self.keys`` entry loaded.

        The input mapping itself is not modified; the loaded images are
        written into the deep copy.
        """
        data_copy = copy.deepcopy(data)
        for key in self.keys:
            data_copy[key] = load_image(data[key], *self.to_pass_keys, **self.to_pass_kwargs)
        return data_copy