from os import path as osp | |
import msgspec | |
from traiNNer.data.paired_image_dataset import PairedImageDataset | |
from traiNNer.utils.redux_options import DatasetOptions | |
def test_pairedimagedataset() -> None:
    """Test dataset: PairedImageDataset.

    Builds a ``PairedImageDataset`` from a YAML ``DatasetOptions`` config that
    points at the ``datasets/val/dataset1`` HR/LR folders on the ``disk``
    backend. No meta-info file is configured, so samples come from scanning
    the folders. The test checks the dataset length, the normalization
    options and the keys, shapes and paths of one training sample
    (``gt_size`` 128, ``scale`` 4).
    """
    opt_str = r"""
    name: Test
    type: PairedImageDataset
    dataroot_gt: [
        datasets/val/dataset1/hr,
        datasets/val/dataset1/hr2,
    ]
    dataroot_lq: [
        datasets/val/dataset1/lr,
        datasets/val/dataset1/lr2
    ]
    filename_tmpl: '{}'
    io_backend:
        type: disk
    scale: 4
    mean: [0.5, 0.5, 0.5]
    std: [0.5, 0.5, 0.5]
    gt_size: 128
    use_hflip: true
    use_rot: true
    phase: train
    """
    # Numeric stems of the images expected in the first HR/LR folders; the
    # files are named ``{:04d}.png``.
    image_names = (7, 8, 9)
    expected_lq_paths = {
        osp.normpath(f"datasets/val/dataset1/lr/{x:04d}.png") for x in image_names
    }
    expected_gt_paths = {
        osp.normpath(f"datasets/val/dataset1/hr/{x:04d}.png") for x in image_names
    }

    opt = msgspec.yaml.decode(opt_str, type=DatasetOptions, strict=True)
    dataset = PairedImageDataset(opt)
    assert dataset.io_backend_opt["type"] == "disk"  # io backend
    # No meta-info file is configured, so the length comes from the folder scan.
    assert len(dataset) == 3
    assert dataset.mean == [0.5, 0.5, 0.5]

    # ------------------ test scan folder mode -------------------- #
    # Re-assign the backend options before building a fresh dataset.
    # NOTE(review): the dataset exposes this dict as ``io_backend_opt``;
    # presumably it may be mutated once a file client is created, hence the
    # reset. Confirm against PairedImageDataset.
    opt.io_backend = {"type": "disk"}
    dataset = PairedImageDataset(opt)
    assert dataset.io_backend_opt["type"] == "disk"  # io backend
    assert len(dataset) == 3  # whether to correctly scan folders

    # test __getitem__
    result = dataset[0]
    # check returned keys
    expected_keys = {"lq", "gt", "lq_path", "gt_path"}
    assert expected_keys.issubset(result.keys())
    # check shape: GT is cropped to gt_size, LQ to gt_size // scale
    assert result["gt"].shape == (3, 128, 128)
    assert result["lq"].shape == (3, 32, 32)
    # check paths
    assert osp.normpath(result["lq_path"]) in expected_lq_paths
    assert osp.normpath(result["gt_path"]) in expected_gt_paths
    # filename_tmpl is '{}', so both images of a pair share the same basename.
    assert osp.basename(result["lq_path"]) == osp.basename(result["gt_path"])

    # ------------- test lmdb backend and with y channel ------------- #
    # TODO: needs tests/data/gt.lmdb and tests/data/lq.lmdb fixtures.
    # ``opt`` is a msgspec-decoded DatasetOptions struct, so options are set as
    # attributes, not via ``opt[...]`` indexing.
    # NOTE(review): confirm DatasetOptions has a ``color`` field and accepts a
    # plain string for ``dataroot_gt`` / ``dataroot_lq``.
    # opt.dataroot_gt = "tests/data/gt.lmdb"
    # opt.dataroot_lq = "tests/data/lq.lmdb"
    # opt.io_backend = {"type": "lmdb"}
    # opt.color = "y"
    # opt.mean = [0.5]
    # opt.std = [0.5]
    # dataset = PairedImageDataset(opt)
    # assert dataset.io_backend_opt["type"] == "lmdb"  # io backend
    # assert len(dataset) == 2  # whether to read correct meta info
    # assert dataset.std == [0.5]
    # result = dataset[1]
    # assert {"lq", "gt", "lq_path", "gt_path"}.issubset(result.keys())
    # assert result["gt"].shape == (1, 128, 128)
    # assert result["lq"].shape == (1, 32, 32)
    # assert result["lq_path"] == "comic"
    # assert result["gt_path"] == "comic"

    # ------------------ test case: val/test mode -------------------- #
    # TODO: depends on the lmdb fixtures above.
    # opt.phase = "test"
    # opt.io_backend = {"type": "lmdb"}
    # dataset = PairedImageDataset(opt)
    # result = dataset[0]
    # assert {"lq", "gt", "lq_path", "gt_path"}.issubset(result.keys())
    # assert result["gt"].shape == (1, 480, 492)
    # assert result["lq"].shape == (1, 120, 123)
    # assert result["lq_path"] == "baboon"
    # assert result["gt_path"] == "baboon"