File size: 3,909 Bytes
62dbcfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from os import path as osp

import msgspec
from traiNNer.data.paired_image_dataset import PairedImageDataset
from traiNNer.utils.redux_options import DatasetOptions


def test_pairedimagedataset() -> None:
    """Test ``PairedImageDataset`` built from a YAML-decoded ``DatasetOptions``.

    Decodes a training-phase config with two GT and two LQ folders on the
    disk backend and checks the dataset's basic attributes. It then rebuilds
    the dataset after assigning ``opt.io_backend`` and inspects the first
    sample: the returned keys, the shapes (consistent with ``gt_size: 128``
    at ``scale: 4``), and which folder the returned paths come from.

    The dataroots are relative paths, so the image fixtures under
    ``datasets/val/dataset1`` must be reachable from the working directory.
    """
    opt_str = r"""
name: Test
type: PairedImageDataset
dataroot_gt: [datasets/val/dataset1/hr, datasets/val/dataset1/hr2]
dataroot_lq: [datasets/val/dataset1/lr, datasets/val/dataset1/lr2]
filename_tmpl: '{}'
io_backend:
  type: disk
scale: 4
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
gt_size: 128
use_hflip: true
use_rot: true
phase: train
"""
    # Numeric stems of the fixture images; files are named zero-padded, e.g. 0007.png.
    image_names = (7, 8, 9)
    expected_lq_paths = {
        osp.normpath(f"datasets/val/dataset1/lr/{x:04d}.png") for x in image_names
    }
    expected_gt_paths = {
        osp.normpath(f"datasets/val/dataset1/hr/{x:04d}.png") for x in image_names
    }

    opt = msgspec.yaml.decode(opt_str, type=DatasetOptions, strict=True)

    # ------------------ build directly from the decoded config ------------------ #
    dataset = PairedImageDataset(opt)
    assert dataset.io_backend_opt["type"] == "disk"
    assert len(dataset) == 3  # number of samples found in the configured folders
    assert dataset.mean == [0.5, 0.5, 0.5]

    # ------------------ rebuild after assigning io_backend ------------------ #
    # NOTE(review): this assigns the same disk backend the YAML already sets, so
    # the rebuilt dataset sees an unchanged configuration. Confirm whether a
    # different (e.g. folder-scan vs. meta-info) setup was intended here.
    opt.io_backend = {"type": "disk"}
    dataset = PairedImageDataset(opt)
    assert dataset.io_backend_opt["type"] == "disk"
    assert len(dataset) == 3

    result = dataset[0]
    # All required keys must be present; extra keys are allowed.
    assert {"lq", "gt", "lq_path", "gt_path"}.issubset(result.keys())
    # A 128x128 GT crop at scale 4 pairs with a 32x32 LQ crop.
    assert result["gt"].shape == (3, 128, 128)
    assert result["lq"].shape == (3, 32, 32)
    # Assumes index 0 is drawn from the first folder pair (hr / lr), not hr2 / lr2.
    assert osp.normpath(result["lq_path"]) in expected_lq_paths
    assert osp.normpath(result["gt_path"]) in expected_gt_paths

    # TODO: lmdb backend with Y channel. The previous sketch set (field names
    # written attribute-style here; confirm they exist on DatasetOptions):
    #   opt.dataroot_gt = "tests/data/gt.lmdb"; opt.dataroot_lq = "tests/data/lq.lmdb"
    #   opt.io_backend = {"type": "lmdb"}; opt.color = "y"
    #   opt.mean = [0.5]; opt.std = [0.5]
    # and expected: len(dataset) == 2, dataset.std == [0.5], and for index 1
    # gt/lq shapes (1, 128, 128) / (1, 32, 32) with lq_path == gt_path == "comic".

    # TODO: val/test phase. The previous sketch set opt.phase = "test" with the
    # lmdb backend and expected, for index 0, gt/lq shapes (1, 480, 492) /
    # (1, 120, 123) with lq_path == gt_path == "baboon".