File size: 1,461 Bytes
62dbcfb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import numpy as np
import pytest
from traiNNer.metrics.psnr_ssim import calculate_psnr, calculate_ssim
def test_calculate_psnr() -> None:
    """Exercise ``calculate_psnr``: input validation, return type, identical inputs.

    Verifies that:
    * images of different shapes are rejected with ``AssertionError``;
    * an unknown ``input_order`` string is rejected with ``ValueError``;
    * a Y-channel comparison of two differing 3-channel images returns a
      ``float`` instance;
    * two identical images score an infinite PSNR.
    """
    # Shape mismatch between the two images must be rejected.
    larger = np.ones((16, 16))
    smaller = np.ones((10, 10))
    with pytest.raises(AssertionError):
        calculate_psnr(larger, smaller, crop_border=0)

    # An unrecognised input_order string must be rejected.
    first = np.ones((16, 16))
    second = np.ones((16, 16))
    with pytest.raises(ValueError):
        calculate_psnr(first, second, crop_border=1, input_order="WRONG")

    # Y-channel evaluation of two differing 3-channel images yields a float.
    reference = np.ones((10, 10, 3))
    doubled = np.full((10, 10, 3), 2.0)
    score = calculate_psnr(reference, doubled, crop_border=1, test_y_channel=True)
    assert isinstance(score, float)

    # Identical images have zero error, so PSNR is unbounded.
    score = calculate_psnr(np.ones((10, 10, 3)), np.ones((10, 10, 3)), crop_border=0)
    assert score == np.inf
def test_calculate_ssim() -> None:
    """Exercise ``calculate_ssim``: input validation and return type.

    Verifies that:
    * images of different shapes are rejected with ``AssertionError``;
    * an unknown ``input_order`` string is rejected with ``ValueError``;
    * a Y-channel comparison of two differing 3-channel images returns a
      ``float`` instance.
    """
    # Shape mismatch between the two images must be rejected.
    larger = np.ones((16, 16))
    smaller = np.ones((10, 10))
    with pytest.raises(AssertionError):
        calculate_ssim(larger, smaller, crop_border=0)

    # An unrecognised input_order string must be rejected.
    first = np.ones((16, 16))
    second = np.ones((16, 16))
    with pytest.raises(ValueError):
        calculate_ssim(first, second, crop_border=1, input_order="WRONG")

    # Y-channel evaluation of two differing 3-channel images yields a float.
    reference = np.ones((10, 10, 3))
    doubled = np.full((10, 10, 3), 2.0)
    score = calculate_ssim(reference, doubled, crop_border=1, test_y_channel=True)
    assert isinstance(score, float)
|