# File: sisr2onnx/tests/test_data/test_paired_image_dataset.py
# Uploaded by Zarxrax ("Upload 823 files", commit 62dbcfb, verified); 3.91 kB.
from os import path as osp
import msgspec
from traiNNer.data.paired_image_dataset import PairedImageDataset
from traiNNer.utils.redux_options import DatasetOptions
def test_pairedimagedataset() -> None:
    """Test dataset: PairedImageDataset.

    Decodes a YAML option block into ``DatasetOptions``, builds a
    ``PairedImageDataset`` from it and checks the configured io backend,
    the dataset length, the normalization mean, and the keys, shapes and
    file paths of the first sample.
    """
    # The YAML body is kept at column 0 because the raw string is passed to
    # the decoder verbatim. ``type: disk`` must be nested under
    # ``io_backend``; at the top level it would duplicate the ``type`` key
    # and leave ``io_backend`` empty.
    opt_str = r"""
name: Test
type: PairedImageDataset
dataroot_gt: [
  datasets/val/dataset1/hr,
  datasets/val/dataset1/hr2,
]
dataroot_lq: [
  datasets/val/dataset1/lr,
  datasets/val/dataset1/lr2
]
filename_tmpl: '{}'
io_backend:
  type: disk

scale: 4
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
gt_size: 128
use_hflip: true
use_rot: true

phase: train
"""
    # Accepted file stems; formatted below as 0007.png, 0008.png, 0009.png.
    image_names = (7, 8, 9)

    opt = msgspec.yaml.decode(opt_str, type=DatasetOptions, strict=True)

    dataset = PairedImageDataset(opt)
    assert dataset.io_backend_opt["type"] == "disk"  # io backend
    assert len(dataset) == 3  # whether to read correct meta info
    assert dataset.mean == [0.5, 0.5, 0.5]

    # ------------------ test scan folder mode -------------------- #
    opt.io_backend = {"type": "disk"}
    dataset = PairedImageDataset(opt)
    assert dataset.io_backend_opt["type"] == "disk"  # io backend
    assert len(dataset) == 3  # whether to correctly scan folders

    # test __getitem__ (through the subscript protocol)
    result = dataset[0]

    # check returned keys; one subset check covers all four required keys
    expected_keys = {"lq", "gt", "lq_path", "gt_path"}
    missing_keys = expected_keys - set(result.keys())
    assert not missing_keys, f"missing keys: {sorted(missing_keys)}"

    # check shape and contents: gt_size is 128 and scale is 4, so lq is 32x32
    assert result["gt"].shape == (3, 128, 128)
    assert result["lq"].shape == (3, 32, 32)

    # Paths are normalized on both sides so separators do not matter.
    assert osp.normpath(result["lq_path"]) in {
        osp.normpath(f"datasets/val/dataset1/lr/{x:04d}.png") for x in image_names
    }
    assert osp.normpath(result["gt_path"]) in {
        osp.normpath(f"datasets/val/dataset1/hr/{x:04d}.png") for x in image_names
    }
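# ------------------ test lmdb backend and with y channel -------------------- #
# TODO: re-enable once ported. The disabled code below still indexes `opt` like a
# dict (opt["..."]), whereas the active test sets options as attributes (opt.io_backend).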
# opt["dataroot_gt"] = "tests/data/gt.lmdb"
# opt["dataroot_lq"] = "tests/data/lq.lmdb"
# opt["io_backend"] = {"type": "lmdb"}
# opt["color"] = "y"
# opt["mean"] = [0.5]
# opt["std"] = [0.5]
# dataset = PairedImageDataset(opt)
# assert dataset.io_backend_opt["type"] == "lmdb" # io backend
# assert len(dataset) == 2 # whether to read correct meta info
# assert dataset.std == [0.5]
# # test __getitem__
# result = dataset.__getitem__(1)
# # check returned keys
# expected_keys = ["lq", "gt", "lq_path", "gt_path"]
# assert set(expected_keys).issubset(set(result.keys()))
# # check shape and contents
# assert (
# "gt" in result
# and "lq" in result
# and "lq_path" in result
# and "gt_path" in result
# )
# assert result["gt"].shape == (1, 128, 128)
# assert result["lq"].shape == (1, 32, 32)
# assert result["lq_path"] == "comic"
# assert result["gt_path"] == "comic"
# ------------------ test case: val/test mode -------------------- #
# TODO: re-enable once the dict-style opt["..."] accesses below are ported to
# attribute-style options.
# opt["phase"] = "test"
# opt["io_backend"] = {"type": "lmdb"}
# dataset = PairedImageDataset(opt)
# # test __getitem__
# result = dataset.__getitem__(0)
# # check returned keys
# expected_keys = ["lq", "gt", "lq_path", "gt_path"]
# assert set(expected_keys).issubset(set(result.keys()))
# # check shape and contents
# assert (
# "gt" in result
# and "lq" in result
# and "lq_path" in result
# and "gt_path" in result
# )
# assert result["gt"].shape == (1, 480, 492)
# assert result["lq"].shape == (1, 120, 123)
# assert result["lq_path"] == "baboon"
# assert result["gt_path"] == "baboon"