import os
import json
import csv

import gradio as gr
import random
import time

from collections import Counter
from numpy.random import choice
from datasets import load_dataset, Dataset

from PIL import PngImagePlugin, ImageFile
# Raise Pillow's cap on PNG text chunks from its 1 MiB default to 10 MiB so that
# images carrying large metadata chunks can still be decoded.
PngImagePlugin.MAX_TEXT_CHUNK = 10 * 1024 * 1024
# Let Pillow load images whose pixel data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
"""
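This code is designed to read in the ImageNet 1K ILSVRC dataset from the Hugging Face Hub,
then create a new version of this dataset in which {percentage} percent of the rows get random labels based on the observed label frequencies,
then upload this new version to the Hugging Face Hub, in the Data Composition organization:
https://huggingface.co/datasets/datacomp
This script reads that randomized version back from the organization, draws a random subset holding
1/frac of the examples for each frac in FRACTIONS, and uploads each subset as a separate dataset.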
"""

# Flip to True for a quick end-to-end check on a handful of examples; uploads
# then go to "debug-" repositories instead of the real ones.
DEV = False
# Arbitrary small number of dataset examples to look at; only used when DEV is True.
DEV_AMOUNT = 10

# The number of examples/instances in this dataset is copied from the dataset card:
# https://huggingface.co/datasets/ILSVRC/imagenet-1k
NUM_EXAMPLES = DEV_AMOUNT if DEV else 1281167
# One subset is produced per entry; each keeps 1/frac of NUM_EXAMPLES.
FRACTIONS = [2] if DEV else [2, 4, 8, 16, 32, 64]

# Whether to read in the distribution over labels from an external text file.
# NOTE(review): not referenced anywhere in the code shown here.
READ_DISTRO = False
# Hugging Face access token taken from the environment (None if unset); it is
# passed as `token=` to every Hub read and upload.
GATED_IMAGENET = os.environ.get("GATED_IMAGENET")


def create_subset_dataset(dataset, fraction_size, seed=None, num_examples=None):
    """Return a random subset holding ``1/fraction_size`` of the examples.

    The dataset is shuffled, and the first ``int(num_examples / fraction_size)``
    examples of the shuffled result are kept.

    Args:
        dataset: Object providing ``shuffle(seed=...)`` and ``take(n)``, such as a
            Hugging Face ``Dataset`` or ``IterableDataset``.
        fraction_size: Denominator of the fraction to keep (2 keeps half,
            4 keeps a quarter, ...).
        seed: Optional shuffle seed for reproducible subsets. With ``None`` (the
            default) the shuffle is unseeded, as before.
        num_examples: Total number of examples that the fraction is taken of.
            Defaults to the module-level ``NUM_EXAMPLES``.

    Returns:
        The result of ``dataset.shuffle(seed=seed).take(n)``.

    Raises:
        ValueError: If ``fraction_size`` is not positive.
    """
    if fraction_size <= 0:
        raise ValueError(f"fraction_size must be positive, got {fraction_size!r}")
    # Only consult the global when the caller did not supply a total.
    total = NUM_EXAMPLES if num_examples is None else num_examples
    num_samples = int(total / fraction_size)
    # NOTE: for a streaming IterableDataset, shuffle() is approximate. It shuffles
    # the shard order and then samples from a fixed-size buffer (1000 rows by default).
    return dataset.shuffle(seed=seed).take(num_samples)
    
def main(percentage=10):
    """Build and upload random fractional subsets of a label-randomized ImageNet-1k copy.

    Streams ``datacomp/imagenet-1k-random{percentage}-take2`` from the Hugging
    Face Hub. When DEV is True, it instead streams the first DEV_AMOUNT rows of
    ``datacomp/imagenet-1k-random{percentage}``. For each ``frac`` in FRACTIONS,
    a random subset made by ``create_subset_dataset`` is pushed to its own
    dataset repository.

    Args:
        percentage: Identifies which randomized dataset to read. The Gradio text
            box passes it as a string. It is converted to ``float``, so repository
            names contain e.g. ``10.0`` rather than ``10``.

    Returns:
        A message giving the run time in whole seconds. The message is also
        printed, and it is what the Gradio text output displays.

    Raises:
        ValueError: If ``percentage`` cannot be converted to a float.
    """
    # Just for timing how long this takes.
    start = time.time()

    percentage = float(percentage)

    # Debug runs read the original (non "-take2") dataset and only look at a few rows.
    source_repo = f"datacomp/imagenet-1k-random{percentage}"
    if not DEV:
        source_repo += "-take2"
    # NOTE: trust_remote_code=True lets the dataset repository run its own loading
    # code; keep it only for repositories you trust.
    dataset = load_dataset(source_repo, split="train", streaming=True,
                           trust_remote_code=True, token=GATED_IMAGENET)
    if DEV:
        dataset = dataset.take(DEV_AMOUNT)

    # Debug uploads go to "debug-" repositories and omit the "-take2" suffix.
    target_prefix = "datacomp/debug-imagenet-1k-random-" if DEV else "datacomp/imagenet-1k-random-"
    target_suffix = "" if DEV else "-take2"

    for frac in FRACTIONS:
        sampled_dataset = create_subset_dataset(dataset, frac)
        target_repo = f"{target_prefix}{percentage}-frac-1over{frac}{target_suffix}"
        # Materialize the streamed subset into a regular Dataset and upload it
        # (this will take a while).
        Dataset.from_generator(sampled_dataset.__iter__).push_to_hub(
            target_repo, token=GATED_IMAGENET)

    end = time.time()
    message = "That took %d seconds" % (end - start)
    print(message)
    # Returned so that the Gradio "text" output shows something instead of staying empty.
    return message


# Minimal web UI: one text box feeds `percentage` into main(), and one text box
# shows main()'s return value.
demo = gr.Interface(
    fn=main,
    inputs="text",
    outputs="text",
)
demo.launch()