# davidpomerenke's picture
# Upload from GitHub Actions: TruthfulQA translation WIP
# fd102e9 verified
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.special import logit
# Load raw evaluation results, discard chrF rows, and average the score for
# each (task, metric, language) combination.
df = (
    pd.read_json("../results.json")
    .loc[lambda frame: frame["metric"] != "chrf"]
    .groupby(["task", "metric", "bcp_47"])["score"]
    .mean()
    .reset_index()
)
# Logit-transform classification scores (applied row-wise below) to reduce skewness.
def transform_classification_scores(row):
    """Return the row's score, logit-transformed if it is a classification task.

    Classification scores are first clipped to [0.001, 0.999] so the logit
    stays finite (logit is -inf at 0 and +inf at 1). Scores of every other
    task are returned unchanged.

    Args:
        row: A record exposing ``'task'`` and ``'score'`` keys, e.g. a
            DataFrame row supplied by ``DataFrame.apply(..., axis=1)``.

    Returns:
        The (possibly transformed) score.
    """
    if row['task'] != 'classification':
        return row['score']
    bounded = np.clip(row['score'], 0.001, 0.999)
    # logit(p) == log(p / (1 - p))
    return logit(bounded)
# Logit-transform classification scores row by row; other tasks pass through.
df['score'] = df.apply(transform_classification_scores, axis='columns')
# Reshape to one row per language (bcp_47) and one column per task, averaging
# any duplicates that remain (e.g. several metrics recorded for one task).
pivot_df = df.pivot_table(
    index='bcp_47',
    columns='task',
    values='score',
    aggfunc='mean',
)
# Task columns to keep, in the order they should appear in the plots.
ordered_tasks = ['translation_from', 'translation_to', 'classification',
                 'mmlu', 'arc', 'mgsm']
# Keep only the listed tasks that actually have data, in the order above.
# Any other task column (e.g. 'truthfulqa', if present) is dropped.
present_tasks = [name for name in ordered_tasks if name in pivot_df.columns]
pivot_df = pivot_df[present_tasks]
# Correlate tasks across languages: each task column of pivot_df is one
# variable and each language row one observation. DataFrame.corr() defaults
# to Pearson correlation over pairwise-complete (non-NaN) rows.
correlation_matrix = pivot_df.corr()

# The matrix is symmetric with a unit diagonal, so only the strictly-lower
# triangle is informative. Masking the upper triangle *including* the
# diagonal on the full matrix leaves an all-empty first row and last column,
# so trim those away before plotting.
lower_corr = correlation_matrix.iloc[1:, :-1]
# In the trimmed frame, cell (r, c) comes from (r + 1, c) of the full matrix
# and lies below the original diagonal exactly when r >= c; hide c > r.
mask = np.triu(np.ones_like(lower_corr, dtype=bool), k=1)

if lower_corr.empty:
    # Fewer than two tasks: there is no pair of tasks to correlate.
    print("Fewer than two tasks available; skipping correlation heatmap.")
else:
    plt.figure(figsize=(8, 6))
    # Correlations lie in [-1, 1]: pin the color scale there and use a
    # diverging colormap centred on 0. With a sequential map such as 'Blues',
    # center=0 makes seaborn reserve the lighter half of the map for negative
    # values, so positive-only data is drawn with little contrast.
    sns.heatmap(
        lower_corr,
        annot=True,
        cmap='RdBu_r',
        center=0,
        vmin=-1,
        vmax=1,
        square=True,
        mask=mask,
        cbar_kws={"shrink": .8},
        fmt='.3f',
    )
    plt.xlabel('Tasks', fontsize=12)
    plt.ylabel('Tasks', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    # Save the plot
    plt.savefig('task_correlation_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()

# Print the full correlation matrix for reference
print("Correlation Matrix:")
print("Note: Classification scores have been logit-transformed to reduce skewness")
print(correlation_matrix.round(3))
# Also create a scatter plot matrix for pairwise relationships with highlighted languages.
# Single source of truth for highlighted languages: each language code maps to
# its marker color, and the insertion order fixes the legend order.
language_colors = {'en': 'red', 'zh': 'blue', 'hi': 'green', 'es': 'orange', 'ar': 'purple'}
highlighted_languages = list(language_colors)


def get_color_and_label(lang_code):
    """Return ``(color, legend_label)`` for a language code.

    Highlighted languages get their own color and are labelled with their
    code. Every other language is drawn light gray under the label 'Other'.
    Because the highlighted list is derived from ``language_colors``, a
    highlighted language can never lack a color (previously a ``KeyError``).
    """
    color = language_colors.get(lang_code)
    if color is None:
        return 'lightgray', 'Other'
    return color, lang_code
# Custom scatter-plot matrix: a histogram of each task on the diagonal and a
# pairwise scatter of two tasks in every off-diagonal cell.
tasks = pivot_df.columns.tolist()
n_tasks = len(tasks)
# squeeze=False keeps `axes` a 2-D array even when only one task is present,
# so the axes[i, j] indexing below always works.
fig, axes = plt.subplots(n_tasks, n_tasks, figsize=(15, 12), squeeze=False)
fig.suptitle('Pairwise Task Performance', fontsize=16, fontweight='bold')

# Legend proxies: one marker per highlighted language, then a catch-all 'Other'.
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w',
               markerfacecolor=get_color_and_label(lang)[0], markersize=8, label=lang)
    for lang in highlighted_languages
]
legend_elements.append(plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgray', markersize=8, label='Other'))

# Boolean flag per language row of pivot_df, computed once for all panels.
is_highlighted = pivot_df.index.isin(highlighted_languages)

for i, task_y in enumerate(tasks):
    for j, task_x in enumerate(tasks):
        ax = axes[i, j]
        if i == j:
            # Diagonal: distribution of this task's scores across languages.
            task_data = pivot_df[task_y].dropna()
            ax.hist(task_data, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            ax.set_title(f'{task_y}', fontsize=10)
        else:
            # Off-diagonal: only languages scored on both tasks can be placed.
            has_both = pivot_df[[task_x, task_y]].notna().all(axis=1).to_numpy()
            others = pivot_df[has_both & ~is_highlighted]
            featured = pivot_df[has_both & is_highlighted]
            # One vectorized scatter call per group instead of one per language.
            # 'Other' languages are drawn first so highlighted points stay on top.
            if not others.empty:
                ax.scatter(others[task_x], others[task_y], c='lightgray', alpha=0.3, s=20)
            if not featured.empty:
                ax.scatter(featured[task_x], featured[task_y],
                           c=[get_color_and_label(code)[0] for code in featured.index],
                           alpha=0.8, s=50)
        # Axis labels only on the bottom row and the left column.
        if i == n_tasks - 1:
            ax.set_xlabel(task_x, fontsize=10)
        if j == 0:
            ax.set_ylabel(task_y, fontsize=10)
        # Hide interior tick labels without swapping in a FixedFormatter.
        if i != n_tasks - 1:
            ax.tick_params(labelbottom=False)
        if j != 0:
            ax.tick_params(labelleft=False)

# Add a shared legend below the grid
fig.legend(
    handles=legend_elements,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.05),
    ncol=len(legend_elements),
    frameon=False,
    fontsize=10,
    handletextpad=0.5,
    columnspacing=1.0
)
plt.tight_layout()
plt.savefig('task_scatter_matrix.png', dpi=300, bbox_inches='tight')
plt.show()