import copy
import time
from io import BytesIO
from os import path
from urllib.parse import urlparse

import requests
from PIL import Image
from torch.utils.data import Dataset

class TestImageSetOnline(Dataset):
    """Image dataset that downloads each image on demand and preprocesses it
    with a Hugging Face CLIP processor.

    Downloads are retried until they succeed. Between failed attempts the
    dataset sleeps for ``self.timeout`` seconds; that delay is multiplied by
    ``timeout_mul`` after every failure and divided by it again after a
    success (never below ``timeout_base``).
    """

    def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2,
                 request_timeout=30.0):
        """
        Args:
            processor (CLIP preprocessor): callable invoked as
                ``processor(text=None, images=img, return_tensors='pt')`` whose
                result is indexed with ``['pixel_values'][0]``.
            image_list: integer-indexable collection of records; each record is
                read with ``record['coco_url']`` and ``record['id']`` (e.g. a
                list of dicts). Must support ``len()`` and ``+`` for
                concatenation.
            timeout_base (float, optional): initial (and minimum) number of
                seconds to sleep after a failed attempt. Defaults to 0.5.
            timeout_mul (int, optional): factor applied to the sleep delay
                every time a request fails. Defaults to 2.
            request_timeout (float or tuple, optional): ``timeout`` argument
                forwarded to ``requests.get`` so a stalled connection cannot
                block forever. Defaults to 30.0 seconds.
        """
        self.image_list = image_list
        self.processor = processor
        self.timeout_base = timeout_base
        # NOTE: despite the name, ``self.timeout`` is the back-off sleep
        # between retries, not the HTTP timeout (see ``request_timeout``).
        self.timeout = self.timeout_base
        self.timeout_mul = timeout_mul
        self.request_timeout = request_timeout

    def _preprocess(self, img):
        """Convert a PIL image to RGB and run it through the processor.

        Returns:
            tuple: ``(pixel_values, size)`` where ``size`` is the ``img.size``
            of the image as opened (before any mode conversion).
        """
        img_s = img.size
        # Anything that is not already RGB (grayscale, CMYK, RGBA, palette,
        # ...) is converted so the processor always sees three channels.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # The Hugging Face keyword is ``return_tensors`` (plural); the earlier
        # spelling ``return_tensor`` did not select the tensor framework.
        ret = self.processor(text=None, images=img, return_tensors='pt')
        return ret['pixel_values'][0], img_s

    def _fetch(self, _id, url):
        """Download ``url`` and preprocess it, retrying with back-off until it works.

        There is no retry limit: a permanently failing URL keeps this method
        looping (with ever longer sleeps) until interrupted.
        KeyboardInterrupt is not an ``Exception`` subclass, so it is not
        swallowed by the retry handler and propagates to the caller.
        """
        while True:
            try:
                response = requests.get(url, timeout=self.request_timeout)
                # Surface HTTP errors (404, 5xx, ...) directly instead of as an
                # opaque image-decoding failure on the error page body.
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as img:
                    pixels, img_s = self._preprocess(img)
            except Exception as e:
                print(f"{_id} {url}: {str(e)}")
                time.sleep(self.timeout)
                # Back off further before issuing a new request
                self.timeout *= self.timeout_mul
            else:
                # Success: relax the back-off delay again
                if self.timeout > self.timeout_base:
                    self.timeout /= self.timeout_mul
                return _id, url, pixels, img_s

    def __getitem__(self, index):
        """Fetch and preprocess the image described by ``image_list[index]``.

        Returns:
            tuple: ``(id, url, pixel_values, size)`` where ``id`` and ``url``
            are the stringified ``'id'`` and ``'coco_url'`` fields of the
            record and ``size`` is the PIL ``size`` of the downloaded image.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = str(row['id'])
        return self._fetch(_id, url)

    def get(self, url):
        """Fetch and preprocess an arbitrary image URL.

        Returns:
            tuple: ``(url, url, pixel_values, size)`` -- the URL doubles as the
            id so the tuple has the same layout as ``__getitem__``.
        """
        return self._fetch(url, url)

    def __len__(self):
        """Return the number of records in ``image_list``."""
        return len(self.image_list)

    def __add__(self, other):
        """Return a new dataset holding this dataset's records followed by ``other``'s.

        The result is a shallow copy of ``self`` (same processor and settings)
        with a freshly concatenated ``image_list``; neither operand is
        modified.
        """
        merged = copy.copy(self)
        merged.image_list = self.image_list + other.image_list
        return merged
    
class TestImageSet(TestImageSetOnline):
    """Variant of ``TestImageSetOnline`` whose ``__getitem__`` reads images
    from a local directory tree instead of downloading them.

    The file for a record is looked up at ``droot`` joined with the path
    component of the record's ``coco_url`` (for example
    ``http://images.cocodataset.org/val2017/1.jpg`` -> ``<droot>/val2017/1.jpg``).
    Only ``__getitem__`` is overridden here.
    """

    def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            droot (str): root directory that mirrors the URL path layout.
            processor (CLIP preprocessor): callable invoked as
                ``processor(text=None, images=img, return_tensors='pt')``.
            image_list: integer-indexable collection of records read with
                ``record['coco_url']`` and ``record['id']``.
            timeout_base (float, optional): forwarded to the parent class.
                Defaults to 0.5.
            timeout_mul (int, optional): forwarded to the parent class.
                Defaults to 2.
        """
        super().__init__(processor, image_list, timeout_base, timeout_mul)
        self.droot = droot

    def __getitem__(self, index):
        """Load and preprocess the local copy of ``image_list[index]``.

        Returns:
            tuple: ``(id, url, pixel_values, size)`` where ``id`` is
            ``"<parent dir of the URL>_<record id>"``, ``url`` is the record's
            ``'coco_url'`` and ``size`` is the PIL ``size`` of the image as
            stored on disk.

        Raises:
            OSError: if the file is missing or cannot be opened as an image
                (no retry is attempted for local files).
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        # The second-to-last URL segment is used as an id prefix.
        _id = '_'.join([url.split('/')[-2], str(row['id'])])
        # Use the URL's path component rather than splitting on a hard-coded
        # "http://images.cocodataset.org/" prefix, so https or other hosts
        # with the same layout resolve to the same local file.
        rel_path = urlparse(url).path.lstrip('/')
        # Load the image from local storage; keep all pixel access inside the
        # ``with`` block because PIL reads the file lazily.
        with Image.open(path.join(self.droot, rel_path)) as raw:
            img_s = raw.size
            # Anything that is not already RGB (grayscale, CMYK, RGBA,
            # palette, ...) is converted before preprocessing.
            img = raw if raw.mode == 'RGB' else raw.convert('RGB')
            # The Hugging Face keyword is ``return_tensors`` (plural); the
            # earlier spelling ``return_tensor`` did not select the framework.
            ret = self.processor(text=None, images=img, return_tensors='pt')
        return _id, url, ret['pixel_values'][0], img_s