|
import numpy as np
|
|
import pytest
|
|
from traiNNer.metrics.psnr_ssim import calculate_psnr, calculate_ssim
|
|
|
|
|
|
def test_calculate_psnr() -> None:
    """Exercise calculate_psnr: input validation, return type and the zero-error case."""
    # Images whose shapes disagree must be rejected via an assertion.
    with pytest.raises(AssertionError):
        calculate_psnr(np.ones((16, 16)), np.ones((10, 10)), crop_border=0)

    # An unrecognised channel ordering must raise ValueError.
    square = (16, 16)
    with pytest.raises(ValueError):
        calculate_psnr(
            np.ones(square),
            np.ones(square),
            crop_border=1,
            input_order="WRONG",
        )

    # Comparing differing images on the Y channel yields a plain float.
    rgb_shape = (10, 10, 3)
    ones_img = np.ones(rgb_shape)
    twos_img = np.full(rgb_shape, 2.0)
    psnr_value = calculate_psnr(
        ones_img,
        twos_img,
        crop_border=1,
        test_y_channel=True,
    )
    assert isinstance(psnr_value, float)

    # Identical inputs have zero error, so PSNR is positive infinity.
    psnr_value = calculate_psnr(np.ones(rgb_shape), np.ones(rgb_shape), crop_border=0)
    assert psnr_value == np.inf
|
|
|
|
|
|
def test_calculate_ssim() -> None:
    """Test metric: calculate_ssim.

    Covers input validation (mismatched shapes, unknown ``input_order``),
    the return type for a Y-channel comparison, and the identical-image
    case (mirroring the identical-image check in the PSNR test).
    """
    # Images whose shapes disagree must be rejected via an assertion.
    with pytest.raises(AssertionError):
        calculate_ssim(np.ones((16, 16)), np.ones((10, 10)), crop_border=0)

    # An unrecognised channel ordering must raise ValueError.
    with pytest.raises(ValueError):
        calculate_ssim(
            np.ones((16, 16)), np.ones((16, 16)), crop_border=1, input_order="WRONG"
        )

    # Comparing differing images on the Y channel yields a plain float.
    out = calculate_ssim(
        np.ones((10, 10, 3)),
        np.ones((10, 10, 3)) * 2,
        crop_border=1,
        test_y_channel=True,
    )
    assert isinstance(out, float)

    # SSIM of an image with itself is 1 by definition. A bare isinstance
    # check would also accept NaN, so pin the value here. A 32x32 image is
    # used so that a windowed SSIM still has valid pixels after any border
    # trimming (assumption: the window size is not visible in this file).
    out = calculate_ssim(np.ones((32, 32, 3)), np.ones((32, 32, 3)), crop_border=0)
    assert out == pytest.approx(1.0)
|
|
|