File size: 2,294 Bytes
62dbcfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import itertools
import math
import os

import torchvision.utils
from traiNNer.data import build_dataloader, build_dataset
from traiNNer.utils.redux_options import DatasetOptions


def main(mode: str = "folder") -> None:
    """Smoke-test ``PairedImageDataset`` by dumping a few batches to ``tmp/``.

    Builds a DIV2K x4 ``PairedImageDataset`` and a dataloader for it, then for
    each of the first six batches prints the LQ/GT source paths and saves the
    LQ and GT tensors as image grids (``tmp/lq_XXX.png`` / ``tmp/gt_XXX.png``).

    Args:
        mode: Which dataset configuration to exercise. One of:
            - ``"folder"`` (default): disk io backend reading the
              ``*_sub`` image folders.
            - ``"meta_info"``: same folders and disk backend, plus the
              ``meta_info_DIV2K800sub_GT.txt`` meta-info file.
            - ``"lmdb"``: ``lmdb`` io backend reading the ``*.lmdb`` dataroots.

    Raises:
        ValueError: If ``mode`` is not one of the values listed above.
    """
    valid_modes = ("folder", "meta_info", "lmdb")
    if mode not in valid_modes:
        # Reject unknown modes instead of silently falling back to the
        # 'folder' configuration, which would test the wrong backend.
        raise ValueError(f"mode must be one of {valid_modes}, got {mode!r}")

    # Base configuration; this is exactly the 'folder' mode setup.
    opt = DatasetOptions(
        phase="train",
        name="DIV2K",
        type="PairedImageDataset",
        gt_size=128,
        use_hflip=True,
        use_rot=True,
        num_worker_per_gpu=2,
        batch_size_per_gpu=16,
        scale=4,
        dataset_enlarge_ratio=1,
        dataroot_gt=["datasets/DIV2K/DIV2K_train_HR_sub"],
        dataroot_lq=["datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub"],
        filename_tmpl="{}",
        io_backend={"type": "disk"},
    )

    if mode == "meta_info":
        # Dataroots, filename template and disk backend are the same as in
        # 'folder' mode; only the meta-info file is added.
        opt.meta_info = "traiNNer/data/meta_info/meta_info_DIV2K800sub_GT.txt"
    elif mode == "lmdb":
        opt.dataroot_gt = ["datasets/DIV2K/DIV2K_train_HR_sub.lmdb"]
        opt.dataroot_lq = ["datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb"]
        opt.io_backend = {"type": "lmdb"}

    os.makedirs("tmp", exist_ok=True)

    dataset = build_dataset(opt)
    data_loader = build_dataloader(dataset, opt, num_gpu=0, dist=False, sampler=None)

    # Narrows the Optional type for the type checker; the value is set above.
    assert opt.batch_size_per_gpu is not None

    # Images per grid row; isqrt is exact, unlike int(math.sqrt(...)).
    nrow = math.isqrt(opt.batch_size_per_gpu)
    padding = 2 if opt.phase == "train" else 0

    num_batches = 6
    print("start...")
    # islice stops right after the last wanted batch instead of loading one
    # more batch from the dataloader only to discard it.
    for i, data in enumerate(itertools.islice(data_loader, num_batches)):
        print(i)

        lq = data["lq"]
        gt = data["gt"]
        lq_path = data["lq_path"]
        gt_path = data["gt_path"]
        print(lq_path, gt_path)
        torchvision.utils.save_image(
            lq, f"tmp/lq_{i:03d}.png", nrow=nrow, padding=padding, normalize=False
        )
        torchvision.utils.save_image(
            gt, f"tmp/gt_{i:03d}.png", nrow=nrow, padding=padding, normalize=False
        )


if __name__ == "__main__":
    # Script entry point: run the smoke test in the default disk 'folder' mode.
    main(mode="folder")