# from PIL import Image
# import blobfile as bf
# from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    AutoTokenizer,
    default_data_collator,
    PreTrainedTokenizerFast,
    PreTrainedTokenizer,
)

# from datasets import load_dataset
import sys, os
import torch

# sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
# from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
from collections import Counter, defaultdict
from functools import partial
from itertools import chain


def load_data_text(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    data_args=None,
    task_mode="roc",
    model=None,
    padding_mode="block",
    split="train",
    load_vocab=None,
):
    """
    Create an endless generator over batches of a tokenized text corpus.

    The corpus is loaded with ``get_corpus_rocstory`` and wrapped in a
    ``TextDataset``. Batches from a ``DataLoader`` over that dataset are
    yielded forever; the loader is restarted after each pass.

    :param data_dir: unused here; kept for interface compatibility.
    :param batch_size: number of examples per yielded batch (a final
                       partial batch is dropped).
    :param image_size: sequences are ``image_size ** 2`` tokens long.
    :param class_cond: unused here; kept for interface compatibility.
    :param deterministic: if True, iterate in a fixed order
                          (``shuffle=False``); otherwise shuffle every pass.
    :param data_args: config object. Reads ``modality``, ``cache_mode`` and
                      ``model_arch`` here and is forwarded to the loader.
    :param task_mode: which corpus to load. Only ``"e2e-tgt"`` currently has
                      enabled loading code.
    :param model: optional embedding module forwarded to the corpus loader.
    :param padding_mode: ``"block"`` or ``"pad"``, forwarded to the loader.
    :param split: split name forwarded to the loader.
    :param load_vocab: optional pre-built vocabulary forwarded to the loader.
    :raises NotImplementedError: on the first ``next()``, for task modes or
        modality/cache combinations whose loading code is disabled. These
        used to fail with ``UnboundLocalError``.
    """
    print("hello loading text data. ")

    # The loaders for every other task mode ("roc", "simple-wiki", "yelp",
    # "commonGen", "e2e", "book", ...) are disabled in this file; previously
    # they fell through and crashed on an unbound ``training_data``.
    if task_mode != "e2e-tgt":
        raise NotImplementedError(
            f"task_mode={task_mode!r} is not supported; only 'e2e-tgt' is enabled"
        )

    # Streaming corpora (cache_mode == "no") need TextDataset_NoCache, which is
    # disabled here; previously this left ``dataset`` unbound. Checked before
    # loading so no expensive work is wasted.
    if (
        data_args.modality
        in ["roc-aug", "roc", "book", "yelp", "commonGen", "commonGen-aug"]
        and data_args.cache_mode == "no"
    ):
        raise NotImplementedError(
            f"modality={data_args.modality!r} with cache_mode='no' requires "
            "TextDataset_NoCache, which is not available"
        )

    print("hello loading e2e-tgt. ")
    training_data, _ = get_corpus_rocstory(
        data_args,
        model,
        image_size,
        padding_mode=padding_mode,
        split=split,
        load_vocab=load_vocab,
    )

    dataset = TextDataset(
        training_data,
        image_size,
        data_args,
        model_arch=data_args.model_arch,
    )

    # One loader for both modes; only the ordering differs. The deterministic
    # branch used to be commented out, leaving ``data_loader`` unbound.
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=not deterministic,
        num_workers=1,
    )
    while True:
        yield from data_loader


def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
    """
    Tokenize, pad and embed ``(source, target)`` word-list pairs.

    Target words are mapped through ``vocab_dict`` (unknown words map to
    ``vocab_dict["UNK"]``), wrapped in ids ``0``/``1`` and padded to
    ``seqlen``. Source words are mapped the same way, without wrapping, and
    padded to ``min(seqlen, longest source)`` together with a mask. Each
    padded target row is then embedded with ``model``.

    :param sentence_lst: list of ``(src_words, tgt_words)`` pairs.
    :param vocab_dict: word -> id dict containing ``"UNK"`` and ``"PAD"``.
    :param model: embedding module. It is called directly on a tensor of ids
        for the ``"random*"`` and ``"glove"`` experiments, and used through
        ``model.transformer.wte`` / ``model.down_proj`` for
        ``"gpt2_pre_compress"``.
    :param seqlen: padded target length and cap on the source length.
    :param data_args: reads ``experiment``, plus ``emb_scale_factor`` for
        ``"gpt2_pre_compress"``.
    :return: list of dicts with keys ``input_ids``, ``hidden_states``,
        ``src_ids`` and ``src_mask``; empty if ``sentence_lst`` is empty.
    :raises ValueError: if ``data_args.experiment`` has no embedding branch.
    """
    if not sentence_lst:
        # max() over the source lengths below would fail on an empty corpus.
        return []

    experiment = data_args.experiment
    result_train_lst = []
    group_lst = defaultdict(list)
    with torch.no_grad():
        for src_ids, input_ids in sentence_lst:
            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
            tokenized_src = [vocab_dict.get(x, vocab_dict["UNK"]) for x in src_ids]
            group_lst["word_ids"].append([0] + tokenized_ + [1])
            group_lst["src_ids"].append(tokenized_src)

        print(group_lst["word_ids"][:2])
        print("padding mode is pad")
        group_lst["word_ids"] = _collate_batch_helper(
            group_lst["word_ids"], vocab_dict["PAD"], seqlen
        )
        max_src_length = max(len(xx) for xx in group_lst["src_ids"])
        print(max_src_length, seqlen)
        # Sources are never padded beyond the target length.
        max_src_length = min(seqlen, max_src_length)
        group_lst["src_ids"], group_lst["src_mask"] = _collate_batch_helper(
            group_lst["src_ids"], vocab_dict["PAD"], max_src_length, return_mask=True
        )

        for input_ids, src_ids, src_mask in zip(
            group_lst["word_ids"], group_lst["src_ids"], group_lst["src_mask"]
        ):
            # "glove" is handled like "random", matching helper_tokenize_encode.
            # Previously it (and any unknown experiment) left hidden_state unbound.
            if experiment.startswith("random") or experiment == "glove":
                hidden_state = model(torch.tensor(input_ids))
            elif experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            else:
                raise ValueError(f"unsupported experiment {experiment!r}")
            result_train_lst.append(
                {
                    "input_ids": input_ids,
                    "hidden_states": hidden_state.cpu().tolist(),
                    "src_ids": src_ids,
                    "src_mask": src_mask,
                }
            )

    return result_train_lst


def helper_tokenize_stream(
    sentence_lst,
    vocab_dict,
    model,
    seqlen,
    data_args,
    padding_mode,
):
    """
    Tokenize word lists into a Hugging Face ``DatasetDict`` for LM training.

    :param sentence_lst: list of word lists, one per example.
    :param vocab_dict: either a word -> id dict, or a Hugging Face tokenizer
        (fast or slow).
        - A dict must contain ``"UNK"``, plus ``"PAD"`` when not in
          ``"block"`` mode; each sequence is wrapped in ids ``0``/``1``.
        - A tokenizer gets the words space-joined and encoded with its
          special tokens.
    :param model: unused; kept for signature parity with the other helpers.
    :param seqlen: block length in ``"block"`` mode, padded length otherwise.
    :param data_args: reads ``preprocessing_num_workers`` and
        ``overwrite_cache`` in ``"block"`` mode.
    :param padding_mode: how examples are shaped.
        - ``"block"`` concatenates all examples, cuts them into
          ``seqlen``-sized chunks and adds a ``labels`` copy of ``input_ids``.
        - Any other value pads each example via ``_collate_batch_helper``.
    :return: ``datasets.DatasetDict`` whose ``"train"`` split holds the result.
    :raises TypeError: if ``vocab_dict`` is neither a dict nor a tokenizer.
    """
    import psutil

    def log_ram():
        # Process.memory_info is expressed in bytes, so convert to megabytes
        print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    # Validate up front: an unsupported type used to surface as an
    # UnboundLocalError inside a multiprocess map() worker.
    is_dict_vocab = isinstance(vocab_dict, dict)
    if not is_dict_vocab and not isinstance(
        vocab_dict, (PreTrainedTokenizerFast, PreTrainedTokenizer)
    ):
        raise TypeError(
            "vocab_dict must be a dict or a Hugging Face tokenizer, "
            f"got {type(vocab_dict).__name__}"
        )

    log_ram()
    from datasets import Dataset as Dataset2, DatasetDict

    raw_datasets = Dataset2.from_dict({"text": sentence_lst})
    print(raw_datasets)
    log_ram()

    def tokenize_function(examples):
        if is_dict_vocab:
            unk_id = vocab_dict["UNK"]
            input_ids = [
                [0] + [vocab_dict.get(x, unk_id) for x in seq] + [1]
                for seq in examples["text"]
            ]
        else:
            texts = [" ".join(seq) for seq in examples["text"]]
            input_ids = vocab_dict(texts, add_special_tokens=True)["input_ids"]
        # clm input could be much much longer than block_size
        return {"input_ids": input_ids}

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=["text"],
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )
    print(tokenized_datasets)
    log_ram()

    if padding_mode == "block":
        block_size = seqlen

        def group_texts(examples):
            concatenated_examples = {
                k: list(chain(*examples[k])) for k in examples.keys()
            }
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # Drop the ragged tail, unless the whole batch is shorter than
            # one block (then it is kept as a single short block).
            if total_length >= block_size:
                total_length = (total_length // block_size) * block_size
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result

        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )
    else:
        pad_id = vocab_dict["PAD"] if is_dict_vocab else vocab_dict.pad_token_id

        def pad_function(group_lst):
            group_lst["input_ids"] = _collate_batch_helper(
                group_lst["input_ids"], pad_id, seqlen
            )
            return group_lst

        log_ram()

        lm_datasets = tokenized_datasets.map(
            pad_function,
            batched=True,
            num_proc=1,
            desc="padding",
        )

    print(lm_datasets, "padded dataset")
    log_ram()

    raw_datasets = DatasetDict()
    raw_datasets["train"] = lm_datasets
    log_ram()
    return raw_datasets


def helper_tokenize_encode(
    sentence_lst,
    vocab_dict,
    model,
    seqlen,
    data_args,
    padding_mode,
):
    """
    Map word lists to ids, group or pad them to ``seqlen``, and embed each row.

    :param sentence_lst: iterable of word lists.
    :param vocab_dict: word -> id dict containing ``"UNK"``, plus ``"PAD"``
        for ``"pad"`` mode. Each sequence is wrapped in ids ``0`` and ``1``.
    :param model: embedding module. It is called directly on a tensor of ids
        for ``"random*"``/``"glove"``, and used through
        ``model.transformer.wte``/``model.down_proj`` for
        ``"gpt2_pre_compress"``.
    :param seqlen: block length (``"block"``) or padded length (``"pad"``).
    :param data_args: reads ``experiment``, plus ``emb_scale_factor`` for
        ``"gpt2_pre_compress"``.
    :param padding_mode: how rows are shaped.
        - ``"block"`` concatenates every sequence and cuts the stream into
          ``seqlen``-sized rows, dropping any remainder.
        - ``"pad"`` pads each sequence via ``_collate_batch_helper``.
        - Any other value leaves the sequences as they are.
    :return: list of ``{"input_ids": [...], "hidden_states": [...]}`` dicts.
    :raises ValueError: if ``data_args.experiment`` has no embedding branch.
    """
    experiment = data_args.experiment
    result_train_lst = []
    group_lst = defaultdict(list)
    with torch.no_grad():
        for input_ids in sentence_lst:
            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
            group_lst["word_ids"].append([0] + tokenized_ + [1])
        print(group_lst["word_ids"][:2])

        if padding_mode == "block":
            print("padding mode is block")
            # chain.from_iterable is linear; the previous sum(lists, []) was
            # quadratic in the corpus size.
            concatenated = list(chain.from_iterable(group_lst["word_ids"]))
            block_size = seqlen
            total_length = (len(concatenated) // block_size) * block_size
            # Split by chunks of max_len.
            group_lst = {
                "word_ids": [
                    concatenated[i : i + block_size]
                    for i in range(0, total_length, block_size)
                ]
            }
        elif padding_mode == "pad":
            print("padding mode is pad")
            group_lst["word_ids"] = _collate_batch_helper(
                group_lst["word_ids"], vocab_dict["PAD"], seqlen
            )

        for input_ids in group_lst["word_ids"]:
            if experiment.startswith("random") or experiment == "glove":
                hidden_state = model(torch.tensor(input_ids))
            elif experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            else:
                # Previously an unknown experiment left hidden_state unbound.
                raise ValueError(f"unsupported experiment {experiment!r}")
            result_train_lst.append(
                {"input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist()}
            )

    return result_train_lst


def load_glove_model(File):
    """
    Read a GloVe text file into a ``{word: vector}`` dict.

    Each non-blank line is a word followed by its whitespace-separated
    vector components.

    :param File: path to the GloVe ``.txt`` file, read as UTF-8.
    :return: dict mapping each word to a 1-D ``float64`` ``torch.Tensor``.
    """
    print("Loading Glove Model")
    glove_model = {}
    # GloVe files are UTF-8; don't depend on the platform's locale encoding.
    with open(File, "r", encoding="utf-8") as f:
        for line in f:
            split_line = line.split()
            if not split_line:
                # Tolerate blank lines instead of failing with IndexError.
                continue
            word, *components = split_line
            glove_model[word] = torch.tensor(np.array(components, dtype=np.float64))
    print(f"{len(glove_model)} words loaded!")
    return glove_model


def load_glove(vocab, path="predictability/glove/glove.6B.50d.txt", dim=50):
    """
    Build a ``torch.nn.Embedding`` for ``vocab`` initialised from GloVe.

    Words found in the GloVe file get their pretrained vector. All other
    words get a standard-normal random vector.

    :param vocab: dict mapping word -> row index. Indices are expected to
        be ``0 .. len(vocab) - 1``.
    :param path: GloVe text file to read (defaults to the original 50-d file).
    :param dim: embedding width; must match the vectors stored in ``path``.
    :return: ``torch.nn.Embedding(len(vocab), dim)`` holding those weights.
    """
    model = torch.nn.Embedding(len(vocab), dim)
    glove_model = load_glove_model(path)
    # Allocate in the embedding's own dtype. GloVe vectors are float64 while
    # the random fallbacks are float32, and the module should stay float32.
    weight = torch.empty(len(vocab), dim, dtype=model.weight.dtype)
    count_ = 0  # number of words NOT found in GloVe (randomly initialised)
    for word, idx in vocab.items():
        if word in glove_model:
            vec = glove_model[word]
        else:
            count_ += 1
            vec = torch.randn(dim)
        # Place each vector at its vocab index, so rows stay aligned even if
        # the dict's iteration order differs from its index order.
        weight[idx] = vec
    print(f"{count_} out of {len(vocab)} is initialized. ")
    print(torch.norm(weight, dim=-1).mean())
    model.weight.data = weight
    return model


def get_corpus_rocstory(
    data_args, model, image_size, padding_mode="block", split="train", load_vocab=None
):
    """
    Load a text corpus, build or reuse a vocabulary, and tokenize/embed it.

    The corpus is selected by ``data_args.experiment_mode`` (``"lm"`` or
    ``"conditional_gen"``) and ``data_args.modality``.

    - Under ``"lm"``, the ``"roc"``, ``"roc-aug"`` and ``"simple-wiki"``
      branches are disabled (their bodies are ``pass``). ``sentence_lst`` is
      never bound for them, so the function later fails with
      ``UnboundLocalError``.
    - Under ``"conditional_gen"``, only ``modality == "e2e"`` is handled.

    :param data_args: config object. Fields read here include
        ``experiment_mode``, ``modality``, ``experiment``, ``cache_mode``,
        ``in_channel`` and ``checkpoint_path``. Per-corpus path fields
        include ``e2e_train``, ``yelp_train``, ``commonGen_train``,
        ``roc_train`` and ``debug_path``.
    :param model: embedding module or None. When None and
        ``data_args.experiment == "random"``, a normally-initialised
        ``torch.nn.Embedding(len(vocab_dict), data_args.in_channel)`` is
        created and its state dict is saved to disk.
    :param image_size: sequences are ``image_size ** 2`` tokens long.
    :param padding_mode: forwarded to the ``"lm"`` tokenize helpers.
    :param split: ``"train"``, ``"valid"`` or ``"test"``. ``"e2e-tgt"`` also
        accepts ``"debug"``, which reads ``data_args.debug_path``.
    :param load_vocab: existing vocabulary (a dict or a
        ``PreTrainedTokenizerFast``). When None, a vocabulary of the four
        special tokens plus every word seen more than 10 times is built and
        written to a JSON file.
    :return: ``(data, model)``.
        - For ``"lm"`` with modality in roc/roc-aug/yelp/commonGen/
          commonGen-aug and ``cache_mode == "no"``, ``data`` is the
          ``DatasetDict`` from ``helper_tokenize_stream``.
        - Otherwise ``data`` is ``{"train": [...]}`` of per-example dicts.
    """
    import csv, torch, json
    from spacy.lang.en import English

    if data_args.experiment_mode == "lm":
        if data_args.modality == "roc":
            pass
            # print('loading dataset from ROCStory')
            # nlp = English()
            # tokenizer = nlp.tokenizer
            # sentence_lst = []
            # print(f'loading from {data_args.roc_train}')
            # if split == 'train':
            #     print('loading form the TRAIN set')
            #     path = f'{data_args.roc_train}/roc_train.json'
            # elif split == 'valid':
            #     print('loading form the VALID set')
            #     path = f'{data_args.roc_train}/roc_valid.json'
            # else:
            #     assert False, "invalid split for ROC dataset"

            # with open(path, 'r') as roc_reader:
            #     for row in roc_reader:
            #         sentences = json.loads(row)[0].strip()
            #         word_lst = [x.text for x in tokenizer(sentences)]
            #         sentence_lst.append(word_lst)

            # # with open(data_args.roc_train, 'r') as csvfile:
            # #     roc_reader = csv.reader(csvfile) #delimiter=' ', quotechar='|')
            # #     for row in roc_reader:
            # #         # tokenize.
            # #         sentences = " ".join(row[2:])
            # #         word_lst = [x.text for x in tokenizer(sentences)]
            # #         sentence_lst.append(word_lst)
            # # sentence_lst = sentence_lst[1:]
            # print(sentence_lst[:2])
        # NOTE(review): this `if` starts a second chain instead of continuing
        # the one above with `elif`; harmless because the modalities compared
        # are mutually exclusive.
        if data_args.modality == "roc-aug":
            pass
            # print('loading dataset from ROCStory')
            # nlp = English()
            # tokenizer = nlp.tokenizer
            # sentence_lst = []
            # if split == 'train':
            #     print('loading form the TRAIN set')
            #     path_lst = [f'{data_args.roc_train}/roc_train.json']
            #     path_lst.append('diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt')
            #     # path_lst.append('diffusion_lm/improved-diffusion/cache/ar_model_augment_roc.json')
            #     # path_lst.append('diffusion_lm/improved-diffusion/cache/ar_model_augment_roc2.json')

            # elif split == 'valid':
            #     print('loading form the VALID set')
            #     path_lst = [f'{data_args.roc_train}/roc_valid.json']
            # else:
            #     assert False, "invalid split for ROC dataset"

            # print(path_lst)
            # for path in path_lst:
            #     if path.endswith('txt'):
            #         with open(path, 'r') as roc_reader:
            #             for row in roc_reader:
            #                 sentences = row.strip()
            #                 word_lst = [x.text for x in tokenizer(sentences)]
            #                 sentence_lst.append(word_lst)
            #     else:
            #         with open(path, 'r') as roc_reader:
            #             for row in roc_reader:
            #                 sentences = json.loads(row)[0].strip()
            #                 word_lst = [x.text for x in tokenizer(sentences)]
            #                 sentence_lst.append(word_lst)
            # print(sentence_lst[:2],sentence_lst[-2:], 'dataset size=',len(sentence_lst))
        elif data_args.modality == "simple-wiki":
            pass
            # print('loading dataset from simple wikipedia')
            # sentence_lst = []
            # with open(data_args.wiki_train, 'r') as ff:
            #     for row in ff:
            #         word_lst = row.lower().split()
            #         sentence_lst.append(word_lst)
            # print(sentence_lst[:2])
        elif data_args.modality == "e2e-tgt":
            print("loading dataset from simple e2e dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                # NOTE(review): hard-coded, machine-specific absolute path; the
                # data_args-based path is commented out below.
                path = (
                    "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt"
                )
                # path = f'../{data_args.e2e_train}/src1_train.txt'
            elif split == "valid":
                print("loading form the VALID set")
                path = f"../{data_args.e2e_train}/src1_valid.txt"
                # NOTE(review): the relative path above is immediately
                # overwritten by this hard-coded absolute path.
                path = (
                    "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt"
                )
            elif split == "test":
                print("loading form the TEST set")
                path = f"../{data_args.e2e_train}/src1_test.txt"
                # NOTE(review): overwritten by a hard-coded absolute path, as
                # for "valid".
                path = "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt"
            elif split == "debug":
                print("loading form the DEBUG set")
                path = data_args.debug_path
                import json

                # Each line is JSON; element 0 is split on single spaces.
                with open(path, "r") as ff:
                    for line in ff:
                        sentence_lst.append(json.loads(line)[0].split(" "))
                # The debug data is duplicated: every sentence appears twice.
                sentence_lst = sentence_lst + sentence_lst
            if split in ["train", "valid", "test"]:
                # Each row is "field0||field1..."; only the second
                # "||"-separated field is tokenized and kept.
                with open(path, "r") as ff:
                    for row in ff:
                        word_lst = row.split("||")[1]
                        word_lst = [x.text for x in tokenizer(word_lst)]
                        sentence_lst.append(word_lst)
            print(sentence_lst[:2])

        elif data_args.modality == "yelp":
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.yelp_train}/yelpnlg-train.csv"
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.yelp_train}/yelpnlg-dev.csv"
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.yelp_train}/yelpnlg-test.csv"
            if split in ["train", "valid", "test"]:

                # Column 1 of each CSV row is tokenized.
                with open(path, "r") as csvfile:
                    yelp_reader = csv.reader(csvfile)  # delimiter=' ', quotechar='|')
                    for row in yelp_reader:
                        sentences = row[1]
                        word_lst = [x.text for x in tokenizer(sentences)]
                        sentence_lst.append(word_lst)
                # Drops the first parsed row (presumably the CSV header).
                sentence_lst = sentence_lst[1:]
            print(sentence_lst[:2])

        elif data_args.modality == "commonGen":
            # NOTE(review): the log message below says "YelpNLG" but this
            # branch loads CommonGen (copy-paste).
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.commonGen_train}/commongen.train.jsonl"
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.commonGen_train}/commongen.test.jsonl"
            if split in ["train", "valid", "test"]:
                # Every string in each JSON line's "scene" list becomes one
                # example.
                with open(path, "r") as ff:
                    for line in ff:
                        line = json.loads(line)
                        for sentences in line["scene"]:
                            word_lst = [x.text for x in tokenizer(sentences)]
                            sentence_lst.append(word_lst)
            print(sentence_lst[:2])

        elif data_args.modality == "commonGen-aug":
            # NOTE(review): the "YelpNLG" log message is a copy-paste.
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.commonGen_train}/commongen.train.jsonl"
                # For training, ROCStories files are appended as extra
                # sentence-level data.
                path_lst = [f"{data_args.roc_train}/roc_train.json"]
                path_lst.append(
                    "diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt"
                )
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
                path_lst = []
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.commonGen_train}/commongen.test.jsonl"
                path_lst = []

            if split in ["train", "valid", "test"]:
                with open(path, "r") as ff:
                    for line in ff:
                        line = json.loads(line)
                        for sentences in line["scene"]:
                            word_lst = [x.text for x in tokenizer(sentences)]
                            sentence_lst.append(word_lst)
            print(sentence_lst[:2])
            import itertools

            # Split each story into sentences at "." tokens (the "." stays
            # with its sentence). spl[:-1] drops the last chunk: it is empty
            # when the text ends with ".", otherwise it is an unterminated
            # trailing fragment that is discarded.
            for path in path_lst:
                if path.endswith("txt"):
                    with open(path, "r") as roc_reader:
                        for row in roc_reader:
                            sentences = row.strip()
                            word_lst = [x.text for x in tokenizer(sentences)]
                            spl = [[]]
                            for x, y in itertools.groupby(word_lst, lambda z: z == "."):
                                spl[-1].extend(y)
                                if x:
                                    spl.append([])
                            sentence_lst.extend(spl[:-1])
                else:
                    with open(path, "r") as roc_reader:
                        for row in roc_reader:
                            sentences = json.loads(row)[0].strip()
                            word_lst = [x.text for x in tokenizer(sentences)]
                            spl = [[]]
                            for x, y in itertools.groupby(word_lst, lambda z: z == "."):
                                spl[-1].extend(y)
                                if x:
                                    spl.append([])
                            sentence_lst.extend(spl[:-1])

            print(sentence_lst[-2:])

        # get tokenizer.
        # Word frequencies, used below to build the vocabulary.
        if load_vocab is None:
            counter = Counter()
            for input_ids in sentence_lst:
                counter.update(input_ids)

    if data_args.experiment_mode == "conditional_gen":
        if data_args.modality == "e2e":
            print("loading dataset from simple e2e dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                # Each line is "src||tgt"; both sides are tokenized and kept
                # as a (src, tgt) pair.
                path = f"{data_args.e2e_train}/src1_train.txt"
                with open(path, "r") as ff:
                    for row in ff:
                        src_lst, word_lst = row.split("||")
                        word_lst = [x.text for x in tokenizer(word_lst)]
                        src_lst = [x.text for x in tokenizer(src_lst)]
                        sentence_lst.append((src_lst, word_lst))
            elif split == "valid":
                # read_e2e_files keeps one (src, first tgt) pair per unique
                # src and also writes gold/src files under args.out_dir.
                path = f"{data_args.e2e_train}/src1_valid.txt"
                sentence_lst = read_e2e_files(path, data_args, tokenizer)
            print(sentence_lst[:2])
        # get tokenizer.
        if load_vocab is None:
            counter = Counter()
            for src_ids, input_ids in sentence_lst:
                counter.update(input_ids)
                counter.update(src_ids)

    # NOTE(review): for an experiment_mode other than "lm"/"conditional_gen",
    # `counter` and `sentence_lst` are never bound and this fails below.
    if load_vocab is None:
        # Special tokens take ids 0-3; any word seen more than 10 times gets
        # the next free id.
        vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3}
        for k, v in counter.items():
            if v > 10:
                vocab_dict[k] = len(vocab_dict)
        print(len(counter), len(vocab_dict))

        # NOTE(review): hard-coded, machine-specific output path; it is not
        # derived from data_args.checkpoint_path.
        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
        print(f"save the vocab to {path_save_vocab}")
        with open(path_save_vocab, "w") as f:
            json.dump(vocab_dict, f)
    else:
        vocab_dict = load_vocab
        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
        # Persist the provided vocabulary only if nothing exists at that path.
        # NOTE(review): a tokenizer is saved to data_args.checkpoint_path, not
        # to path_save_vocab, so this existence check does not cover it.
        if not os.path.exists(path_save_vocab):
            print(f"save the vocab to {path_save_vocab}")
            if isinstance(vocab_dict, dict):
                with open(path_save_vocab, "w") as f:
                    json.dump(vocab_dict, f)
                assert vocab_dict["START"] == 0
            elif isinstance(vocab_dict, PreTrainedTokenizerFast):
                vocab_dict.save_pretrained(data_args.checkpoint_path)
            else:
                assert False, "invalid type of vocab_dict"

    if model is None and data_args.experiment == "random":
        model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
        print("initializing the random embeddings", model)
        torch.nn.init.normal_(model.weight)
        # NOTE(review): the message below names data_args.checkpoint_path, but
        # the weights are saved to the hard-coded path_save instead.
        path_save = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch"
        print(
            f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
        )
        torch.save(model.state_dict(), path_save)

    # path_save = f'{data_args.checkpoint_path}/random_emb.torch'
    # if not os.path.exists(path_save) and data_args.experiment == 'random':
    #     torch.save(model.state_dict(), path_save)

    # Dispatch to a helper:
    # - streaming (datasets-backed) tokenization for no-cache LM corpora;
    # - eager tokenize+embed for other LM corpora;
    # - source/target pairs for conditional_gen.
    if (
        data_args.experiment_mode == "lm"
        and data_args.modality
        in ["roc-aug", "roc", "yelp", "commonGen", "commonGen-aug"]
        and data_args.cache_mode == "no"
    ):
        train_dataset = helper_tokenize_stream(
            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
        )
        return train_dataset, model
    elif data_args.experiment_mode == "lm":
        result_train_lst = helper_tokenize_encode(
            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
        )
    elif data_args.experiment_mode == "conditional_gen":
        result_train_lst = helper_tokenize_encode_cond(
            sentence_lst, vocab_dict, model, image_size**2, data_args
        )
    return {"train": result_train_lst}, model


def write_e2e_corr(prompt_lst, file_dict, corr_path):
    """
    Write the reference sentences for each prompt to ``corr_path``.

    For each prompt, every token sequence in ``file_dict[prompt]`` is written
    space-joined on its own line. A blank line then separates that prompt's
    references from the next prompt's.

    :param prompt_lst: prompts (keys of ``file_dict``) in output order.
    :param file_dict: mapping prompt -> list of token sequences.
    :param corr_path: output path; overwritten and written as UTF-8.
    """
    print(len(prompt_lst))
    # Explicit UTF-8 so non-ASCII tokens don't depend on the locale encoding.
    with open(corr_path, "w", encoding="utf-8") as f:
        for x in prompt_lst:
            for line in file_dict[x]:
                print(" ".join(line), file=f)
            print("", file=f)


def write_e2e_src(prompt_lst, corr_path):
    """
    Write each prompt, space-joined, on its own line of ``corr_path``.

    :param prompt_lst: sequences of tokens, one per output line.
    :param corr_path: output path; overwritten and written as UTF-8.
    """
    # Explicit UTF-8 so non-ASCII tokens don't depend on the locale encoding.
    with open(corr_path, "w", encoding="utf-8") as f:
        for x in prompt_lst:
            print(" ".join(x), file=f)


def read_e2e_files(path, args, tokenizer):
    """
    Read an E2E ``src||tgt`` file, group targets by source and write side files.

    Each non-blank line is split on ``"||"``. Both halves are tokenized with
    ``tokenizer`` (only each token's ``.text`` is kept), and targets are
    grouped under their source in first-seen order. Two files are then
    written under ``args.out_dir``:
    - ``1_{args.split}_gold``: all targets per source;
    - ``1_{args.split}_src``: one source per line.

    :param path: input file path, read as UTF-8.
    :param args: reads ``out_dir`` and ``split``.
    :param tokenizer: callable returning tokens with a ``.text`` attribute.
    :return: list of ``(src_tokens, first_tgt_tokens)`` tuples, one per
             unique source.
    :raises ValueError: if a non-blank line does not contain exactly one
             ``"||"``.
    """
    file_dict = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing one) used to fail the unpack.
                continue
            src_lst, word_lst = line.split("||")
            tgt = tuple(x.text for x in tokenizer(word_lst))
            src = tuple(x.text for x in tokenizer(src_lst))
            file_dict.setdefault(src, []).append(tgt)
    temp = "1"
    prompt_text_lst = list(file_dict)
    gold_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "gold"))
    print("gold dir", gold_dir)
    write_e2e_corr(prompt_text_lst, file_dict, gold_dir)
    src_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "src"))
    write_e2e_src(prompt_text_lst, src_dir)
    return [(src, file_dict[src][0]) for src in prompt_text_lst]


def get_corpus_book(
    data_args,
    tokenizer,
    model,
    image_size,
    padding_mode="block",
    split="train",
):
    """
    Load BookCorpus, tokenize it and chunk it into fixed-length blocks.

    :param data_args: reads ``preprocessing_num_workers``, ``training_mode``,
        ``in_channel`` and ``checkpoint_path``.
    :param tokenizer: Hugging Face tokenizer, applied without special tokens.
    :param model: embedding module or None. When None, a normally-initialised
        ``nn.Embedding`` is created and its state dict is saved to
        ``{checkpoint_path}/random_emb.torch``. Its width is 1 when
        ``training_mode`` starts with ``"e2e"``, else ``in_channel``.
    :param image_size: blocks are ``image_size ** 2`` tokens long.
    :param padding_mode: only ``"block"`` is supported.
    :param split: ``"train"`` returns the datasets as built. Any other value
        replaces the ``"train"`` entry with ``"validation"``, which is carved
        from the first 1% of train when the corpus has no validation split.
    :return: ``(lm_datasets, model)``.
    :raises ValueError: if ``padding_mode`` is not ``"block"``.
    """
    # The module-level ``load_dataset`` import is commented out, so without
    # this local import the function fails with NameError.
    from datasets import load_dataset

    max_length = image_size**2

    if padding_mode != "block":
        raise ValueError(f"only padding_mode='block' is supported, got {padding_mode!r}")
    raw_datasets = load_dataset("bookcorpus")
    if "validation" not in raw_datasets.keys():
        raw_datasets["validation"] = load_dataset(
            "bookcorpus",
            split="train[:1%]",
        )
        raw_datasets["train"] = load_dataset(
            "bookcorpus",
            split="train[1%:]",
        )
    print(raw_datasets)
    column_names = raw_datasets["train"].column_names

    def tokenize_function(examples):
        output = tokenizer(examples["text"], add_special_tokens=False)
        return output

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
    )

    print(tokenized_datasets)

    block_size = max_length

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Drop the ragged tail unless the batch is shorter than one block.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )

    print(lm_datasets)

    if model is None:
        if data_args.training_mode.startswith("e2e"):
            print("since its e2e, initialize a dummy embedding")
            model = torch.nn.Embedding(len(tokenizer), 1)
        else:
            model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
        print("initializing the random embeddings", model)
        torch.nn.init.normal_(model.weight)
        path_save = f"{data_args.checkpoint_path}/random_emb.torch"
        print(
            f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
        )
        torch.save(model.state_dict(), path_save)

    if split != "train":
        # Callers read the "train" key, so expose validation data under it.
        lm_datasets["train"] = lm_datasets["validation"]
    return lm_datasets, model


class TextDataset(Dataset):
    """Dataset over pre-computed ("cached") hidden states stored in ``text_datasets``.

    Each item is read from ``text_datasets["train"][idx]``:

    * ``"hidden_states"`` becomes a float32 array, optionally projected with
      ``eigen_transform`` and optionally perturbed with Gaussian noise;
    * ``"input_ids"`` (plus ``"src_ids"`` / ``"src_mask"`` when
      ``data_args.experiment_mode == "conditional_gen"``) go into ``out_dict``.

    Only the flat (non-UNet) layout is implemented for cached data. For
    ``model_arch="conv-unet"`` (the default) or ``"1d-unet"``, ``__getitem__``
    raises ``NotImplementedError``.

    :param text_datasets: mapping whose ``"train"`` entry is an indexable sequence
        of dict-like rows.
    :param resolution: stored on the instance and printed; not used to lay out items.
    :param data_args: config object; ``experiment_mode`` is read, ``noise_level``
        is read if present.
    :param model_arch: network architecture name; selects the item layout.
    :param classes: accepted for interface compatibility; unused.
    :param shard: accepted for interface compatibility; unused.
    :param num_shards: accepted for interface compatibility; unused.
    :param eigen_transform: optional dict with ``"mean"`` and ``"map"`` arrays;
        items are flattened, centred by ``mean``, multiplied by ``map`` and
        reshaped back.
    :param mapping_func: stored on the instance; unused here.
    :param model_emb: stored on the instance; unused here.
    """

    def __init__(
        self,
        text_datasets,
        resolution,
        data_args,
        model_arch="conv-unet",
        classes=None,
        shard=0,
        num_shards=1,
        eigen_transform=None,
        mapping_func=None,
        model_emb=None,
    ):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets["train"])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb

    def __len__(self):
        return self.length

    def _apply_transforms(self, arr):
        """Apply the optional eigen projection, then optional additive Gaussian noise."""
        if self.eigen_transform is not None:
            old_shape = arr.shape
            arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
            arr = arr @ self.eigen_transform["map"]
            arr = arr.reshape(old_shape)
        if hasattr(self.data_args, "noise_level") and self.data_args.noise_level > 0:
            arr = arr + self.data_args.noise_level * np.random.randn(
                *arr.shape
            ).astype(arr.dtype)
        return arr

    def __getitem__(self, idx):
        """Return ``(hidden_state_array, out_dict)`` for row ``idx`` of the train split."""
        if self.model_arch in ("conv-unet", "1d-unet"):
            # Without this guard these architectures fall through and return None,
            # which only surfaces later (e.g. when a DataLoader collates the batch).
            raise NotImplementedError(
                f"TextDataset does not support model_arch={self.model_arch!r} "
                "for cached hidden states; use TextDataset_NoCache instead."
            )

        row = self.text_datasets["train"][idx]
        arr = self._apply_transforms(np.array(row["hidden_states"], dtype=np.float32))

        out_dict = {"input_ids": np.array(row["input_ids"])}
        if self.data_args.experiment_mode == "conditional_gen":
            out_dict["src_ids"] = np.array(row["src_ids"])
            out_dict["src_mask"] = np.array(row["src_mask"])
        return arr, out_dict


class TextDataset_NoCache(Dataset):
    """Dataset that embeds token ids on the fly instead of reading cached hidden states.

    For each item, ``text_datasets["train"][idx]["input_ids"]`` is embedded with
    ``model_emb`` under ``torch.no_grad()``. The embedding then goes through the
    following steps:

    * it is converted to a float32 numpy array;
    * it is laid out according to ``model_arch``;
    * it is optionally projected with ``eigen_transform``;
    * it is optionally perturbed with Gaussian noise.

    The result is returned together with an ``out_dict`` of id arrays.

    :param text_datasets: mapping whose ``"train"`` entry is an indexable sequence
        of dict-like rows with at least ``"input_ids"``.
    :param resolution: side length used to reshape embeddings for ``"conv-unet"``.
    :param data_args: config object. The following attributes are read:

        * ``experiment`` — always;
        * ``experiment_mode`` — for the flat layout;
        * ``noise_level`` — if present;
        * ``emb_scale_factor`` — for ``experiment == "gpt2_pre_compress"``.
    :param model_arch: ``"conv-unet"`` -> ``(dim, res, res)``, ``"1d-unet"`` ->
        ``(dim, seqlen)``, anything else -> ``(seqlen, dim)``. These layouts assume
        the embedding model returns ``(seqlen, dim)`` for a 1-D id tensor (TODO
        confirm for custom ``model_emb``).
    :param classes: accepted for interface compatibility; unused.
    :param shard: accepted for interface compatibility; unused.
    :param num_shards: accepted for interface compatibility; unused.
    :param eigen_transform: optional dict with ``"mean"`` and ``"map"`` arrays
        applied to the flattened item.
    :param mapping_func: stored on the instance; unused here.
    :param model_emb: embedding module. It is called directly for ``"random*"``
        experiments. For ``"gpt2_pre_compress"`` it must expose ``device``,
        ``transformer.wte`` and ``down_proj``.
    :raises ValueError: from ``__getitem__`` if ``data_args.experiment`` is not
        one of the supported experiments.
    """

    def __init__(
        self,
        text_datasets,
        resolution,
        data_args,
        model_arch="conv-unet",
        classes=None,
        shard=0,
        num_shards=1,
        eigen_transform=None,
        mapping_func=None,
        model_emb=None,
    ):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets["train"])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb

    def __len__(self):
        return self.length

    def _embed(self, input_ids):
        """Embed ``input_ids`` with ``self.model_emb`` according to ``data_args.experiment``."""
        model = self.model_emb
        experiment = self.data_args.experiment
        if experiment.startswith("random"):
            return model(torch.tensor(input_ids))
        if experiment == "gpt2_pre_compress":
            ids = torch.tensor(input_ids).to(model.device)
            input_embs = model.transformer.wte(ids)
            hidden_state = model.down_proj(input_embs)
            return hidden_state * self.data_args.emb_scale_factor
        # Previously this case left `hidden_state` unbound (UnboundLocalError).
        raise ValueError(
            f"TextDataset_NoCache does not support experiment={experiment!r}"
        )

    def _apply_transforms(self, arr):
        """Apply the optional eigen projection, then optional additive Gaussian noise."""
        if self.eigen_transform is not None:
            old_shape = arr.shape
            arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
            arr = arr @ self.eigen_transform["map"]
            arr = arr.reshape(old_shape)
        if hasattr(self.data_args, "noise_level") and self.data_args.noise_level > 0:
            arr = arr + self.data_args.noise_level * np.random.randn(
                *arr.shape
            ).astype(arr.dtype)
        return arr

    def __getitem__(self, idx):
        """Return ``(embedding_array, out_dict)`` for row ``idx`` of the train split."""
        row = self.text_datasets["train"][idx]
        with torch.no_grad():
            hidden_state = self._embed(row["input_ids"])
        # Move to CPU first: np.array() on an accelerator tensor raises.
        arr = hidden_state.detach().cpu().numpy().astype(np.float32)
        out_dict = {"input_ids": np.array(row["input_ids"])}

        if self.model_arch == "conv-unet":
            arr = self._apply_transforms(
                arr.reshape(self.resolution, self.resolution, -1)
            )
            return np.transpose(arr, [2, 0, 1]), out_dict

        if self.model_arch == "1d-unet":
            arr = self._apply_transforms(arr)
            return np.transpose(arr, [1, 0]), out_dict

        arr = self._apply_transforms(arr)
        if self.data_args.experiment_mode == "conditional_gen":
            out_dict["src_ids"] = np.array(row["src_ids"])
            out_dict["src_mask"] = np.array(row["src_mask"])
        return arr, out_dict


def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
    result = torch.full(
        [len(examples), max_length], pad_token_id, dtype=torch.int64
    ).tolist()
    mask_ = torch.full(
        [len(examples), max_length], pad_token_id, dtype=torch.int64
    ).tolist()
    for i, example in enumerate(examples):
        curr_len = min(len(example), max_length)
        result[i][:curr_len] = example[:curr_len]
        mask_[i][:curr_len] = [1] * curr_len
    if return_mask:
        return result, mask_
    return result


def _torch_collate_batch(examples, pad_token_id, max_length):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    import numpy as np
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    # length_of_first = examples[0].size(0)
    # Check if padding is necessary.
    # are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    # if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
    #     return torch.stack(examples, dim=0)
    # Creating the full tensor and filling it with our data.
    # max_length = max(x.size(0) for x in examples)
    # if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
    #     max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], pad_token_id)
    for i, example in enumerate(examples):
        if True:
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result