# sisr2onnx/test_scripts/test_precision.py
# NOTE: removed file-hosting page chrome ("Upload 823 files", commit hash,
# file size) that was captured along with the source and broke parsing.
from typing import Any
import torch
from torch import Tensor, autocast, nn
from traiNNer.archs import ARCH_REGISTRY, SPANDREL_REGISTRY
from traiNNer.losses.basic_loss import L1Loss
# Every (name, arch) pair in the main arch registry.
# NOTE(review): appears unused below — presumably kept for interactive
# inspection; confirm before removing.
ALL_REGISTRIES = list(ARCH_REGISTRY)
# Arch names to skip when benchmarking precision. The discriminator/feature
# extractor entries are not SR generators; the reason dat/hat/swinir/lmlt are
# excluded is not visible here — TODO confirm (likely too slow or unsupported).
EXCLUDE_BENCHMARK_ARCHS = {
    "dat",
    "hat",
    "swinir",
    "lmlt",
    "vggstylediscriminator",
    "unetdiscriminatorsn",
    "vggfeatureextractor",
}
# (name, arch) pairs from both registries, minus the excluded names.
FILTERED_REGISTRY = [
    (name, arch)
    for name, arch in list(SPANDREL_REGISTRY) + list(ARCH_REGISTRY)
    if name not in EXCLUDE_BENCHMARK_ARCHS
]
# For archs that have extra parameters, list all combinations that need to be benchmarked.
# Default is an empty list (meaning: benchmark once with no extra kwargs).
EXTRA_ARCH_PARAMS: dict[str, list[dict[str, Any]]] = {
    k: [] for k, _ in FILTERED_REGISTRY
}
# realplksr is benchmarked once per upsampler variant.
EXTRA_ARCH_PARAMS["realplksr"] = [
    {"upsampler": "dysample"},
    {"upsampler": "pixelshuffle"},
]
# A list of tuples in the format of (name, arch, extra_params).
# Archs without extra params get a single entry with an empty dict.
FILTERED_REGISTRIES_PARAMS = [
    (name, arch, extra_params)
    for (name, arch) in FILTERED_REGISTRY
    for extra_params in (EXTRA_ARCH_PARAMS[name] if EXTRA_ARCH_PARAMS[name] else [{}])
]
def format_extra_params(extra_arch_params: dict[str, Any]) -> str:
    """Render extra arch kwargs as a short space-separated label.

    String values are shown bare (e.g. ``"dysample"``); any other value is
    rendered as ``key=value``. Returns ``""`` for an empty dict.
    """
    parts = [
        value if isinstance(value, str) else f"{key}={value}"
        for key, value in extra_arch_params.items()
    ]
    return " ".join(parts)
def _autocast_loss(
    net: nn.Module,
    input_tensor: Tensor,
    criterion: nn.Module,
    reference: Tensor,
    dtype: torch.dtype,
    label: str,
) -> float:
    """Run ``net`` under CUDA autocast at ``dtype`` and return the criterion
    loss of its (float-cast) output against the FP32 ``reference`` output.

    Returns ``inf`` if inference at this precision raises, so callers can
    still rank precisions without special-casing failures.
    """
    try:
        with autocast(dtype=dtype, device_type="cuda"), torch.inference_mode():
            output = net(input_tensor)
            return criterion(output.float(), reference).item()
    except Exception as e:
        # Best-effort benchmark: report the failure and keep going.
        print(f"Error in {label} inference: {e}")
        return float("inf")


def compare_precision(
    net: nn.Module, input_tensor: Tensor, criterion: nn.Module
) -> tuple[float, float]:
    """Compare FP16 and BF16 autocast inference against an FP32 baseline.

    Args:
        net: Model in eval mode (expected on the CUDA device by the caller).
        input_tensor: Input batch fed identically to all three runs.
        criterion: Loss used to measure deviation from the FP32 output.

    Returns:
        ``(fp16_loss, bf16_loss)``; either value is ``inf`` if that
        precision failed to run.
    """
    with torch.inference_mode():
        fp32_output = net(input_tensor)
    fp16_loss = _autocast_loss(
        net, input_tensor, criterion, fp32_output, torch.float16, "FP16"
    )
    bf16_loss = _autocast_loss(
        net, input_tensor, criterion, fp32_output, torch.bfloat16, "BF16"
    )
    return fp16_loss, bf16_loss
if __name__ == "__main__":
    # Benchmark precision at a fixed 4x upscaling factor.
    scale = 4
    for name, arch, extra_arch_params in FILTERED_REGISTRIES_PARAMS:
        label = f"{name} {format_extra_params(extra_arch_params)} {scale}x"
        try:
            # Only this hand-picked subset is compared; everything else is
            # skipped (the registry-wide EXCLUDE list is applied earlier).
            if name not in {
                "rcan",
                "esrgan",
                "compact",
                "span",
                "dat_2",
                "spanplus",
                "realplksr",
            }:
                continue
            # Randomly initialized net on CUDA; weights are NOT pretrained
            # unless the load_state_dict block below is uncommented.
            net: nn.Module = arch(scale=scale, **extra_arch_params).eval().to("cuda")
            # net.load_state_dict(
            #     torch.load(
            #         r"DAT_2_x4.pth",
            #         weights_only=True,
            #     )["params"]
            # )
            input_tensor = torch.randn((2, 3, 192, 192), device="cuda")
            criterion = L1Loss(1.0)
            fp16_loss, bf16_loss = compare_precision(net, input_tensor, criterion)
            # NOTE(review): if both precisions failed (inf, inf), diff is nan;
            # cosmetic only since both branches still print.
            diff = abs(fp16_loss - bf16_loss)
            # Print the better (lower-loss) precision first.
            if fp16_loss < bf16_loss:
                print(
                    f"{label:>30s}: FP16: {fp16_loss:.6f}; BF16: {bf16_loss:.6f}; diff = {diff}"
                )
            else:
                print(
                    f"{label:>30s}: BF16: {bf16_loss:.6f}; FP16: {fp16_loss:.6f}; diff = {diff}"
                )
        except Exception as e:
            # Arch construction can fail (e.g. unsupported scale); skip and move on.
            print(f"skip {label}", e)