import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.special import logit

# Scores are clipped to this range before the logit so that exact 0 or 1
# values do not map to -inf / +inf.
LOGIT_CLIP_LOW, LOGIT_CLIP_HIGH = 0.001, 0.999

df = pd.read_json("../results.json")
df = df[df["metric"] != "chrf"]
# Average repeated scores per (task, metric, language) combination.
df = df.groupby(["task", "metric", "bcp_47"]).agg({"score": "mean"}).reset_index()


def transform_classification_scores(row):
    """Return the logit of the score for classification rows, else the raw score.

    The score is clipped to [LOGIT_CLIP_LOW, LOGIT_CLIP_HIGH] before the
    logit (log(p / (1 - p))) so it stays finite. The script below applies
    the same transform in vectorized form; this per-row version is retained
    for backward compatibility.
    """
    if row['task'] == 'classification':
        return logit(np.clip(row['score'], LOGIT_CLIP_LOW, LOGIT_CLIP_HIGH))
    return row['score']


# Apply the logit transformation to classification scores to reduce skewness.
# A boolean mask replaces the row-wise DataFrame.apply, which is much slower.
is_classification = df['task'] == 'classification'
df.loc[is_classification, 'score'] = logit(
    np.clip(df.loc[is_classification, 'score'], LOGIT_CLIP_LOW, LOGIT_CLIP_HIGH)
)

# Create a pivot table with tasks as columns and languages as rows
# (scores of different metrics for the same task are averaged together).
pivot_df = df.pivot_table(
    values='score',
    index='bcp_47',
    columns='task',
    aggfunc='mean',
)

# Display order of the tasks included in the analysis.
ordered_tasks = [
    'translation_from',
    'translation_to',
    'classification',
    'mmlu',
    'arc',
    'mgsm',
]
# Keep only the tasks listed above, in that order. Any other task present in
# the results (e.g. 'truthfulqa', if present) is dropped, and tasks missing
# from the results are skipped.
pivot_df = pivot_df[[task for task in ordered_tasks if task in pivot_df.columns]]

# Pairwise correlation between tasks across languages.
correlation_matrix = pivot_df.corr()

# Create the correlation plot
plt.figure(figsize=(8, 6))

# Mask the upper triangle including the diagonal, so only the lower triangle
# (each distinct task pair once) is shown.
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))

# Create a heatmap
sns.heatmap(
    correlation_matrix,
    annot=True,
    cmap='Blues',
    center=0,
    square=True,
    mask=mask,
    cbar_kws={"shrink": .8},
    fmt='.3f',
)

plt.xlabel('Tasks', fontsize=12)
plt.ylabel('Tasks', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

# Save the plot
plt.savefig('task_correlation_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

# Print correlation values for reference
print("Correlation Matrix:")
print("Note: Classification scores have been logit-transformed to reduce skewness")
print(correlation_matrix.round(3))
# Also create a scatter plot matrix for pairwise relationships with highlighted languages.

# Languages drawn in colour; every other language is drawn in light gray as 'Other'.
LANGUAGE_COLORS = {'en': 'red', 'zh': 'blue', 'hi': 'green', 'es': 'orange', 'ar': 'purple'}
highlighted_languages = list(LANGUAGE_COLORS)


def get_color_and_label(lang_code):
    """Return ``(color, legend_label)`` for a language code.

    Highlighted languages get their own colour and are labelled by their
    code; all other languages map to ``('lightgray', 'Other')``.
    """
    if lang_code in LANGUAGE_COLORS:
        return LANGUAGE_COLORS[lang_code], lang_code
    return 'lightgray', 'Other'


# Create custom scatter plot matrix
tasks = pivot_df.columns.tolist()
n_tasks = len(tasks)
# squeeze=False keeps `axes` a 2-D array even when only one task is present.
fig, axes = plt.subplots(n_tasks, n_tasks, figsize=(15, 12), squeeze=False)
fig.suptitle('Pairwise Task Performance', fontsize=16, fontweight='bold')

# Legend: one marker per highlighted language, plus one for all others.
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w',
               markerfacecolor=get_color_and_label(lang)[0],
               markersize=8, label=lang)
    for lang in highlighted_languages
]
legend_elements.append(plt.Line2D([0], [0], marker='o', color='w',
                                  markerfacecolor='lightgray', markersize=8, label='Other'))

for i, task_y in enumerate(tasks):
    for j, task_x in enumerate(tasks):
        ax = axes[i, j]

        if i == j:
            # Diagonal: histogram of this task's scores across languages.
            # NOTE: the y-axis here is a count, not task_y.
            ax.hist(pivot_df[task_y].dropna(), bins=20, alpha=0.7,
                    color='skyblue', edgecolor='black')
            ax.set_title(f'{task_y}', fontsize=10)
        else:
            # Off-diagonal: one point per language that has both scores.
            pair = pivot_df[[task_x, task_y]].dropna()
            highlight_mask = pair.index.isin(highlighted_languages)
            # Gray background languages go first so highlighted points are drawn
            # on top and are never hidden underneath them.
            for subset, alpha, size in ((pair[~highlight_mask], 0.3, 20),
                                        (pair[highlight_mask], 0.8, 50)):
                if subset.empty:
                    continue
                ax.scatter(subset[task_x], subset[task_y],
                           c=[get_color_and_label(lang)[0] for lang in subset.index],
                           alpha=alpha, s=size)

        # Axis labels and tick labels only on the outer edges of the grid.
        if i == n_tasks - 1:
            ax.set_xlabel(task_x, fontsize=10)
        else:
            ax.tick_params(labelbottom=False)
        if j == 0:
            ax.set_ylabel(task_y, fontsize=10)
        else:
            ax.tick_params(labelleft=False)

# Add legend below the grid; bbox_inches='tight' keeps it in the saved image.
fig.legend(
    handles=legend_elements,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.05),
    ncol=len(legend_elements),
    frameon=False,
    fontsize=10,
    handletextpad=0.5,
    columnspacing=1.0,
)

plt.tight_layout()
plt.savefig('task_scatter_matrix.png', dpi=300, bbox_inches='tight')
plt.show()