sisr2onnx / scripts /benchmarking /generate_charts.py
Zarxrax's picture
Upload 823 files
62dbcfb verified
raw
history blame
4.34 kB
import csv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def process_table(df: pd.DataFrame) -> pd.DataFrame:
    """Select the columns used for plotting and make the PSNR column numeric.

    Args:
        df: Benchmark rows. When non-empty it must contain the ``name``,
            ``fps``, ``psnr_df2k``, ``vram`` and ``params`` columns.

    Returns:
        A new DataFrame with exactly those five columns, in that order.
        ``psnr_df2k`` is converted with ``pd.to_numeric(errors="coerce")``, so
        unparseable entries (for example ``"-"``) become NaN. The input frame
        is left unmodified. An empty input, such as ``pd.DataFrame([])``
        (which has no columns at all), yields an empty frame that still has
        the five expected columns.

    Raises:
        KeyError: if a non-empty ``df`` lacks any of the required columns.
    """
    columns = ["name", "fps", "psnr_df2k", "vram", "params"]
    if df.empty:
        # pd.DataFrame([]) has no columns, so plain selection would raise.
        selected = df.reindex(columns=columns)
    else:
        # Explicit copy so the coercion below never touches the caller's frame.
        selected = df[columns].copy()
    selected["psnr_df2k"] = pd.to_numeric(selected["psnr_df2k"], errors="coerce")
    return selected
def plot_scatter(df: pd.DataFrame, scale: int, size: str) -> None:
    """Save a PSNR-vs-FPS scatter chart for one group of benchmarked models.

    Each model is drawn as three elements at the same point:

    - a translucent circle whose marker area grows linearly with its VRAM
      usage;
    - an opaque marker;
    - a small text label with the model name.

    Args:
        df: Table as produced by ``process_table``. Reads the ``name``,
            ``fps``, ``psnr_df2k`` and ``vram`` columns.
        scale: Upscaling factor. 1 is titled "Restoration", any other value
            "Upscale".
        size: Group label used in the title (e.g. "Fast") and, lower-cased,
            in the output file name.

    The chart is written to
    ``docs/source/resources/benchmark{scale}x_{size.lower()}.png``. The figure
    is closed afterwards, even if saving fails, so repeated calls do not
    accumulate open figures.
    """
    fig = plt.figure(figsize=(10, 6), dpi=300)
    try:
        # Map VRAM linearly from [min, max] onto marker areas of 10..10000 pt^2.
        vram_min, vram_max = df["vram"].min(), df["vram"].max()
        vram_scaled = np.interp(df["vram"], (vram_min, vram_max), (10, 10000))
        # One color per model, sampled evenly across the tab10 colormap.
        colors = plt.cm.tab10(np.linspace(0, 1, len(df)))  # type: ignore

        # Iterate positionally so the per-row arrays above stay aligned with
        # the rows even when df does not carry a 0..n-1 RangeIndex.
        for name, fps, psnr, area, color in zip(
            df["name"], df["fps"], df["psnr_df2k"], vram_scaled, colors
        ):
            # Translucent background circle sized by VRAM.
            plt.scatter(fps, psnr, s=area, color=color, alpha=0.2, label=name)
            # Opaque marker at the exact data point.
            plt.scatter(fps, psnr, color=color, edgecolors="black", alpha=0.9)
            # Model name placed slightly up-right of the point.
            plt.annotate(
                name,
                (fps, psnr),
                textcoords="offset points",
                xytext=(5, 5),
                ha="center",
                fontsize=5,
            )

        model_type = "Restoration" if scale == 1 else "Upscale"
        plt.title(
            f"{scale}x {model_type} {size} DF2K Urban100 PSNR vs FPS on 640x480 input with RTX 4090",
            fontsize=16,
        )
        plt.xlabel("FPS (PyTorch)", fontsize=12)
        plt.ylabel("DF2K Urban100 PSNR", fontsize=12)
        # Solid major gridlines.
        plt.grid(
            which="major", linestyle="-", linewidth=0.75, color="darkgray", alpha=0.75
        )
        # Thinner, fainter solid gridlines at the minor ticks.
        plt.grid(
            which="minor", linestyle="-", linewidth=0.5, color="gray", alpha=0.25
        )
        plt.minorticks_on()
        fig.savefig(f"docs/source/resources/benchmark{scale}x_{size.lower()}.png")
    finally:
        plt.close(fig)
def plot_single(df: pd.DataFrame, scale: int, key: str) -> None:
    """Display a bar chart of one metric across all models, largest first.

    Args:
        df: Model table with a ``name`` column and a numeric ``key`` column.
        scale: Upscaling factor shown in the chart title.
        key: Column to plot. Also used as the y-axis label.

    The y-axis runs from one unit below the smallest value to one unit above
    the largest. The chart is shown with ``plt.show()`` rather than saved.
    """
    ranked = df.sort_values(by=key, ascending=False)
    values = ranked[key]
    lowest, highest = values.min(), values.max()

    plt.figure(figsize=(20, 6), dpi=300)
    plt.bar(ranked["name"], values, color="purple", edgecolor="black")
    plt.title(f"{scale}x Model {key}", fontsize=16)
    plt.xlabel("Model", fontsize=12)
    plt.ylabel(key, fontsize=12)
    plt.xticks(rotation=45, ha="right")
    plt.grid(True)
    plt.ylim(lowest - 1, highest + 1)
    plt.tight_layout()
    plt.show()
def main() -> None:
    """Generate PSNR-vs-FPS scatter charts for the 4x, 3x and 2x benchmarks.

    For each scale, reads ``docs/source/resources/benchmark{scale}x.csv`` and:

    - drops rows whose ``psnr_df2k`` is ``"-"``;
    - converts ``params`` to int and ``fps``, ``vram`` and ``psnr_df2k`` to
      float;
    - appends the ``variant`` to the model name;
    - saves one chart per speed group (Fast: fps > 24, Medium:
      2 <= fps <= 24, Slow: fps < 2) plus one chart covering every model.

    Groups with no models are skipped instead of crashing the whole run.
    """
    # FPS boundaries between the Slow/Medium and Medium/Fast groups.
    slow_below = 2
    fast_above = 24

    for scale in [4, 3, 2]:
        file_path = f"docs/source/resources/benchmark{scale}x.csv"
        with open(file_path, newline="", encoding="utf-8") as f:
            rows = [row for row in csv.DictReader(f) if row["psnr_df2k"] != "-"]

        # Every remaining row has a PSNR value, so all conversions are direct.
        for row in rows:
            row["params"] = int(row["params"])
            row["fps"] = float(row["fps"])
            row["vram"] = float(row["vram"])
            row["psnr_df2k"] = float(row["psnr_df2k"])
            row["name"] = f"{row['name']} {row['variant']}".strip()

        groups = [
            ([row for row in rows if row["fps"] > fast_above], "Fast"),
            ([row for row in rows if slow_below <= row["fps"] <= fast_above], "Medium"),
            ([row for row in rows if row["fps"] < slow_below], "Slow"),
            (rows, "All"),
        ]
        for group_rows, label in groups:
            if not group_rows:
                # An empty list builds a column-less DataFrame; nothing to plot.
                continue
            plot_scatter(process_table(pd.DataFrame(group_rows)), scale, label)
        # plot_single(process_table(pd.DataFrame(rows)), scale, "psnr_df2k")


if __name__ == "__main__":
    main()