import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import random, string


class JpegTest(nn.Module):
	"""Distort a batch of images with real JPEG compression.

	Each sample is quantized to 8 bits, encoded as a JPEG file with Pillow,
	decoded again and written back into a tensor shaped like the input. The
	round trip goes through NumPy/Pillow, so no gradient flows through it.

	Args:
		Q: value forwarded to Pillow's JPEG ``quality`` option.
		subsample: value forwarded to Pillow's JPEG ``subsampling`` option
			(0 = 4:4:4, 1 = 4:2:2, 2 = 4:2:0).
		path: directory that holds the short-lived temporary ``.jpg`` files.
			It is created if missing; an empty string means the current
			working directory.
	"""

	def __init__(self, Q=50, subsample=0, path="temp/"):
		super().__init__()
		self.Q = Q
		self.subsample = subsample
		self.path = path
		# makedirs(exist_ok=True) copes with nested directories and with several
		# workers constructing the layer concurrently, which an
		# exists()-then-mkdir() sequence does not.
		if path:
			os.makedirs(path, exist_ok=True)
		# ToTensor maps the decoded uint8 HxWxC array back to a float CxHxW
		# tensor in [0, 1]; no further normalization is applied.
		self.transform = transforms.Compose([
			transforms.ToTensor(),
		])

	def get_path(self):
		"""Return a random ``.jpg`` file path located inside ``self.path``.

		The name is 16 distinct characters drawn from ASCII letters and digits.
		"""
		name = ''.join(random.sample(string.ascii_letters + string.digits, 16)) + ".jpg"
		# os.path.join keeps the file inside the directory even when `path` was
		# given without a trailing separator (plain concatenation would turn
		# "temp" + name into a sibling file "tempXXXX.jpg").
		return os.path.join(self.path, name)

	def forward(self, image_cover_mask):
		"""JPEG-compress every sample of the batch.

		Args:
			image_cover_mask: batch tensor indexed as (N, C, H, W). Values are
				clamped to [0, 1] before quantization. Presumably 3-channel
				RGB, since each sample is handed to ``Image.fromarray`` as an
				H x W x C uint8 array -- confirm against callers.

		Returns:
			Tensor with the shape, dtype and device of the input holding the
			decoded JPEG images scaled to [0, 1].
		"""
		image = image_cover_mask

		noised_image = torch.zeros_like(image)

		for i, sample in enumerate(image):
			# CHW float in [0, 1] -> HWC uint8 in [0, 255]; add(0.5) makes the
			# truncating uint8 cast round to nearest.
			single_image = (
				sample.detach().clamp(0, 1).permute(1, 2, 0) * 255
			).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
			im = Image.fromarray(single_image)

			file = self.get_path()
			while os.path.exists(file):
				file = self.get_path()
			try:
				im.save(file, format="JPEG", quality=self.Q, subsampling=self.subsample)
				# The context manager releases the file handle before removal.
				with Image.open(file) as decoded:
					jpeg = np.array(decoded, dtype=np.uint8)
			finally:
				# Never leave temp files behind, even if encoding/decoding failed.
				if os.path.exists(file):
					os.remove(file)

			noised_image[i] = self.transform(jpeg).to(image.device)

		return noised_image