import os
import random
import logging
import gradio as gr
from PIL import Image
from zipfile import ZipFile
from typing import Any, Dict,List
from transformers import pipeline

class Image_classification:
  """Gradio demo that classifies images with Hugging Face image-classification pipelines."""

  # Lower-case file extensions treated as images when collecting example files.
  IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'})

  def __init__(self):
    # Pipelines already built, keyed by model id (None = task default), so a
    # model's weights are loaded once instead of on every button click.
    # Bounded in practice by the dropdown choices, but every cached entry
    # keeps a full model in memory.
    self._pipelines: Dict[Any, Any] = {}

  def unzip_image_data(self, zip_path: str = "image_dataset.zip",
                       directory_path: str = "dataset") -> str:
    """
    Extract an image dataset zip archive into a directory.

    Args:
        zip_path (str): Archive to extract. Defaults to "image_dataset.zip"
            (relative to the current working directory).
        directory_path (str): Destination directory. It is created if missing
            and reused if it already exists, so repeated runs do not fail.

    Returns:
        str: ``directory_path`` on success, or "" if extraction failed
        (the error is logged, not raised).
    """
    try:
      with ZipFile(zip_path, "r") as archive:
        # exist_ok=True: a plain mkdir fails on every run after the first.
        os.makedirs(directory_path, exist_ok=True)
        archive.extractall(directory_path)
      return directory_path
    except Exception as e:
      # Best effort: a missing or corrupt archive only disables the examples.
      logging.exception("An error occurred during extraction: %s", e)
      return ""

  def example_images(self, zip_path: str = "image_dataset.zip",
                     directory_path: str = "dataset") -> List[str]:
    """
    Extract the image dataset and list its image files for the example gallery.

    Only regular files directly inside the extracted directory whose
    extension (case-insensitive) is in IMAGE_EXTENSIONS are returned. Each
    file appears exactly once, sorted by name for a stable gallery order.

    Args:
        zip_path (str): Archive to extract (see ``unzip_image_data``).
        directory_path (str): Directory to extract into.

    Returns:
        List[str]: Paths of the form ``os.path.join(directory_path, name)``,
        or an empty list if extraction or listing failed.
    """
    image_dataset_folder = self.unzip_image_data(zip_path, directory_path)
    if not image_dataset_folder:
      return []
    try:
      names = sorted(os.listdir(image_dataset_folder))
    except OSError as e:
      logging.exception("An error occurred in  example images: %s", e)
      return []
    candidates = (os.path.join(image_dataset_folder, name) for name in names)
    # Filter and collect in a single pass so every image is listed once and
    # non-image entries (text files, subdirectories) are excluded.
    return [
      path for path in candidates
      if os.path.isfile(path)
      and os.path.splitext(path)[1].lower() in self.IMAGE_EXTENSIONS
    ]

  def classify(self, image: "str | Image.Image", model: Any) -> List[Dict[str, Any]]:
    """
    Run an image-classification pipeline on an image.

    Args:
        image: Image file path (what the ``type="filepath"`` Gradio input
            passes) or a PIL image.
        model: Hugging Face model id. A falsy value ("" or None) is passed to
            ``pipeline`` as None, so the task's default model is used.

    Returns:
        List[Dict[str, Any]]: The pipeline's predictions; each entry carries
        "label" and "score" keys (consumed by ``format_the_result``).

    Raises:
        Exception: Any error from building or running the pipeline, re-raised
        after logging.
    """
    model_id = model or None
    try:
      classifier = self._pipelines.get(model_id)
      if classifier is None:
        classifier = pipeline("image-classification", model=model_id)
        self._pipelines[model_id] = classifier
      return classifier(image)
    except Exception as e:
      logging.exception(f"An error occurred during image classification: {e}")
      raise

  @staticmethod
  def _best_scores(predictions: List[Dict[str, Any]]) -> Dict[str, float]:
    """Collapse predictions to one entry per label, keeping that label's highest score."""
    best: Dict[str, float] = {}
    for item in predictions:
      label, score = item['label'], item['score']
      if label not in best or best[label] < score:
        best[label] = score
    return best

  def format_the_result(self, image: "str | Image.Image", model: Any) -> Dict[str, float]:
    """
    Classify an image and format the result for a ``gr.Label`` output.

    Args:
        image: Image file path or PIL image to classify.
        model: Hugging Face model id (see ``classify``).

    Returns:
        Dict[str, float]: Unique labels mapped to the highest score seen for each.

    Raises:
        Exception: Any error from classification or formatting, re-raised after logging.
    """
    try:
      return self._best_scores(self.classify(image, model))
    except Exception as e:
      logging.error(f"An error occurred while formatting the results: {e}")
      raise

  def interface(self):
    """Build the Gradio UI and launch it with ``debug=True``."""
    # NOTE(review): "svelte-90oupt" is a build-generated Gradio class name and
    # may not exist in other Gradio versions — confirm it still matches.
    with gr.Blocks(css="""
    
    .gradio-container {background: #314755;
       background: -webkit-linear-gradient(to right, #26a0da, #314755);
       background: linear-gradient(to right, #26a0da, #314755);}
       .block.svelte-90oupt.padded{background:#314755;
       margin:0;
    padding:0;}""") as demo:

      gr.HTML("""
            <center><h1 style="color:#fff">Image Classification</h1></center>""")

      exam_img = self.example_images()
      with gr.Row():
        # "" selects the task's default model (classify maps it to None).
        model = gr.Dropdown(["facebook/regnet-x-040","google/vit-large-patch16-384","microsoft/resnet-50",""],label="Choose a model")
      with gr.Row():
        image = gr.Image(type="filepath",sources="upload")
        with gr.Column():
          output = gr.Label()
      with gr.Row():
        button = gr.Button()
      button.click(self.format_the_result,[image,model],output)
      # Only build the gallery when extraction produced images; an empty or
      # failed dataset should not break the whole UI.
      if exam_img:
        # With cache_examples=False, clicking an example is expected to just
        # fill the image input (fn also needs the model, which is not an
        # input here) — verify if run_on_click/caching is ever enabled.
        gr.Examples(
          examples=exam_img,
          inputs=[image],
          outputs=output,
          fn=self.format_the_result,
          cache_examples=False,
        )
    demo.launch(debug=True)

if __name__ == "__main__":
  # interface() builds the UI and calls demo.launch(); it returns nothing,
  # so there is no result to keep.
  Image_classification().interface()