{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "99FBiGH7bsfn"
      },
      "source": [
        "# Compiling \u0026 Visualizing Tracr Models\n",
        "\n",
        "This notebook demonstrates how to compile a tracr model and provides some tools to visualize the model's residual stream or layer outputs for a given input sequence."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qm-PM1PEawCx"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "import jax\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
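        "# On some accelerators (e.g. TPU) the default matmul precision is reduced\n",
        "# (bfloat16), which can lead to discrepancies between outputs of the compiled\n",
        "# model and the RASP program, so request float32 precision explicitly.\n",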
        "jax.config.update('jax_default_matmul_precision', 'float32')\n",
        "\n",
        "from tracr.compiler import compiling\n",
        "from tracr.compiler import lib\n",
        "from tracr.rasp import rasp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HtOAc_yWawFR"
      },
      "outputs": [],
      "source": [
        "#@title Plotting functions\n",
        "def tidy_label(label, value_width=5):\n",
        "  \"\"\"Formats a residual label as `name:value` with a right-aligned value.\n",
        "\n",
        "  Args:\n",
        "    label: Label string, either `name` or `name:value`.\n",
        "    value_width: Minimum width the value part is right-aligned to.\n",
        "\n",
        "  Returns:\n",
        "    The name, a colon, and the value padded on the left to `value_width`\n",
        "    characters. Labels without a value get a blank padded value.\n",
        "  \"\"\"\n",
        "  # Split on the first colon only, so that values which themselves contain a\n",
        "  # colon (e.g. a ':' token in the vocabulary) do not raise on unpacking.\n",
        "  name, _, value = label.partition(':')\n",
        "  return name + f\":{value:\u003e{value_width}}\"\n",
        "\n",
        "\n",
        "def add_residual_ticks(model, value_width=5, x=False, y=True):\n",
        "  \"\"\"Labels the current axes' ticks with the model's residual-space labels.\n",
        "\n",
        "  Args:\n",
        "    model: Object exposing `residual_labels`, a sized iterable of label strings.\n",
        "    value_width: Forwarded to `tidy_label` to align the value part.\n",
        "    x: Whether to label the x axis (labels are rotated by 90 degrees).\n",
        "    y: Whether to label the y axis.\n",
        "  \"\"\"\n",
        "  if not (x or y):\n",
        "    return\n",
        "\n",
        "  tick_labels = [\n",
        "      tidy_label(name, value_width=value_width)\n",
        "      for name in model.residual_labels\n",
        "  ]\n",
        "  # Offset by 0.5 so ticks fall mid-cell on an integer grid (e.g. pcolormesh).\n",
        "  tick_positions = np.arange(len(tick_labels)) + 0.5\n",
        "  text_style = dict(family='monospace', fontsize=20)\n",
        "  if y:\n",
        "    plt.yticks(tick_positions, tick_labels, **text_style)\n",
        "  if x:\n",
        "    plt.xticks(tick_positions, tick_labels, rotation=90, **text_style)\n",
        "\n",
        "\n",
        "def plot_computation_trace(model,\n",
        "                           input_labels,\n",
        "                           residuals_or_outputs,\n",
        "                           add_input_layer=False,\n",
        "                           figsize=(12, 9)):\n",
        "  \"\"\"Plots one heatmap per layer, side by side in a single figure.\n",
        "\n",
        "  Args:\n",
        "    model: Model whose `residual_labels` are used as y tick labels.\n",
        "    input_labels: Input tokens, used as x tick labels of every panel.\n",
        "    residuals_or_outputs: Sequence of per-layer arrays; `layer[0].T` is drawn\n",
        "      for each entry (presumably index 0 selects the single batch element;\n",
        "      TODO confirm).\n",
        "    add_input_layer: If True, the first panel is titled 'Input' and layer\n",
        "      numbering starts at the second panel.\n",
        "    figsize: Figure size passed to `plt.subplots`.\n",
        "  \"\"\"\n",
        "  # squeeze=False always yields a 2D array of axes, so a single layer (one\n",
        "  # column) can be iterated exactly like several layers.\n",
        "  _, axes = plt.subplots(nrows=1, ncols=len(residuals_or_outputs),\n",
        "                         figsize=figsize, sharey=True, squeeze=False)\n",
        "  # default=0 keeps an empty input sequence from raising in max().\n",
        "  value_width = max(map(len, map(str, input_labels)), default=0) + 1\n",
        "  tick_positions = np.arange(len(input_labels)) + 0.5\n",
        "\n",
        "  for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes[0])):\n",
        "    plt.sca(ax)\n",
        "    plt.pcolormesh(layer[0].T, vmin=0, vmax=1)\n",
        "    if i == 0:\n",
        "      # The y axis is shared, so residual labels are only added once.\n",
        "      add_residual_ticks(model, value_width=value_width)\n",
        "    plt.xticks(tick_positions, input_labels, rotation=90, fontsize=20)\n",
        "    if add_input_layer and i == 0:\n",
        "      title = 'Input'\n",
        "    else:\n",
        "      # Panels alternate Attn / MLP; each consecutive pair shares a number.\n",
        "      layer_no = i - 1 if add_input_layer else i\n",
        "      layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'\n",
        "      title = f'{layer_type} {layer_no // 2 + 1}'\n",
        "    plt.title(title, fontsize=20)\n",
        "\n",
        "\n",
        "def plot_residuals_and_input(model, inputs, figsize=(12, 9)):\n",
        "  \"\"\"Applies model to inputs, and plots the residual stream at each layer.\n",
        "\n",
        "  Args:\n",
        "    model: Assembled model; must provide `apply` and `residual_labels`.\n",
        "    inputs: Input token sequence, also used as x tick labels.\n",
        "    figsize: Figure size forwarded to `plot_computation_trace`.\n",
        "  \"\"\"\n",
        "  # Use the `model` argument rather than the notebook-global `assembled_model`,\n",
        "  # so the forward pass and the tick labels come from the same model.\n",
        "  model_out = model.apply(inputs)\n",
        "  # Prepend the input embeddings as an extra leading entry before the residuals.\n",
        "  residuals = np.concatenate([model_out.input_embeddings[None, ...],\n",
        "                              model_out.residuals], axis=0)\n",
        "  plot_computation_trace(\n",
        "      model=model,\n",
        "      input_labels=inputs,\n",
        "      residuals_or_outputs=residuals,\n",
        "      add_input_layer=True,\n",
        "      figsize=figsize)\n",
        "\n",
        "\n",
        "def plot_layer_outputs(model, inputs, figsize=(12, 9)):\n",
        "  \"\"\"Applies model to inputs, and plots the outputs of each layer.\n",
        "\n",
        "  Args:\n",
        "    model: Assembled model; must provide `apply` and `residual_labels`.\n",
        "    inputs: Input token sequence, also used as x tick labels.\n",
        "    figsize: Figure size forwarded to `plot_computation_trace`.\n",
        "  \"\"\"\n",
        "  # Use the `model` argument rather than the notebook-global `assembled_model`,\n",
        "  # so the forward pass and the tick labels come from the same model.\n",
        "  model_out = model.apply(inputs)\n",
        "  plot_computation_trace(\n",
        "      model=model,\n",
        "      input_labels=inputs,\n",
        "      residuals_or_outputs=model_out.layer_outputs,\n",
        "      add_input_layer=False,\n",
        "      figsize=figsize)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "8hV0nv_ISmhM"
      },
      "outputs": [],
      "source": [
        "#@title Define RASP programs\n",
        "def get_program(program_name, max_seq_len):\n",
        "  \"\"\"Looks up a RASP program by name, together with its token vocabulary.\n",
        "\n",
        "  Args:\n",
        "    program_name: Name of the program, e.g. `sort_unique` or `dyck-2`.\n",
        "    max_seq_len: Maximum sequence length; only forwarded to the `sort` and\n",
        "      `sort_freq` programs.\n",
        "\n",
        "  Returns:\n",
        "    A `(program, vocab)` tuple.\n",
        "\n",
        "  Raises:\n",
        "    NotImplementedError: If `program_name` is not a known program.\n",
        "  \"\"\"\n",
        "  # Every entry lazily builds `(vocab, program)`, so only the requested\n",
        "  # program is constructed.\n",
        "  builders = {\n",
        "      \"length\": lambda: (\n",
        "          {\"a\", \"b\", \"c\", \"d\"},\n",
        "          lib.make_length()),\n",
        "      \"frac_prevs\": lambda: (\n",
        "          {\"a\", \"b\", \"c\", \"x\"},\n",
        "          lib.make_frac_prevs((rasp.tokens == \"x\").named(\"is_x\"))),\n",
        "      \"dyck-2\": lambda: (\n",
        "          {\"(\", \")\", \"{\", \"}\"},\n",
        "          lib.make_shuffle_dyck(pairs=[\"()\", \"{}\"])),\n",
        "      \"dyck-3\": lambda: (\n",
        "          {\"(\", \")\", \"{\", \"}\", \"[\", \"]\"},\n",
        "          lib.make_shuffle_dyck(pairs=[\"()\", \"{}\", \"[]\"])),\n",
        "      \"sort\": lambda: (\n",
        "          {1, 2, 3, 4, 5},\n",
        "          lib.make_sort(\n",
        "              rasp.tokens, rasp.tokens, max_seq_len=max_seq_len, min_key=1)),\n",
        "      \"sort_unique\": lambda: (\n",
        "          {1, 2, 3, 4, 5},\n",
        "          lib.make_sort_unique(rasp.tokens, rasp.tokens)),\n",
        "      \"hist\": lambda: (\n",
        "          {\"a\", \"b\", \"c\", \"d\"},\n",
        "          lib.make_hist()),\n",
        "      \"sort_freq\": lambda: (\n",
        "          {\"a\", \"b\", \"c\", \"d\"},\n",
        "          lib.make_sort_freq(max_seq_len=max_seq_len)),\n",
        "      \"pair_balance\": lambda: (\n",
        "          {\"(\", \")\"},\n",
        "          lib.make_pair_balance(\n",
        "              sop=rasp.tokens, open_token=\"(\", close_token=\")\")),\n",
        "  }\n",
        "  if program_name not in builders:\n",
        "    raise NotImplementedError(f\"Program {program_name} not implemented.\")\n",
        "  vocab, program = builders[program_name]()\n",
        "  return program, vocab"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L_m_ufaua9ri"
      },
      "outputs": [],
      "source": [
        "#@title Assemble model\n",
        "program_name = \"sort_unique\"  #@param [\"length\", \"frac_prevs\", \"dyck-2\", \"dyck-3\", \"sort\", \"sort_unique\", \"hist\", \"sort_freq\", \"pair_balance\"]\n",
        "max_seq_len = 5  #@param {label: \"Test\", type: \"integer\"}\n",
        "\n",
        "# Look up the selected RASP program and the vocabulary it operates on.\n",
        "program, vocab = get_program(program_name=program_name,\n",
        "                             max_seq_len=max_seq_len)\n",
        "\n",
        "print(\"Compiling...\")\n",
        "print(f\"   Program: {program_name}\")\n",
        "print(f\"   Input vocabulary: {vocab}\")\n",
        "print(f\"   Context size: {max_seq_len}\")\n",
        "\n",
        "# `assembled_model` is used by the cells below; the example inputs there\n",
        "# start with the \"bos\" token configured here as `compiler_bos`.\n",
        "assembled_model = compiling.compile_rasp_to_model(\n",
        "    program=program,\n",
        "    vocab=vocab,\n",
        "    max_seq_len=max_seq_len,\n",
        "    causal=False,\n",
        "    compiler_bos=\"bos\",\n",
        "    compiler_pad=\"pad\",\n",
        "    mlp_exactness=100)\n",
        "\n",
        "print(\"Done.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wtwiE-JiXF3F"
      },
      "outputs": [],
      "source": [
        "#@title Forward pass\n",
        "# The input starts with the \"bos\" token (the `compiler_bos` used when\n",
        "# compiling). The tokens 3, 4, 1 come from the sort / sort_unique vocabulary,\n",
        "# so replace them if a program with a different vocabulary is selected.\n",
        "assembled_model.apply([\"bos\", 3, 4, 1]).decoded"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RkEkVcEHa2gf"
      },
      "outputs": [],
      "source": [
        "#@title Plot residual stream\n",
        "# Plots the input embeddings followed by the model's residuals, one panel each.\n",
        "# The tokens 3, 4, 1 come from the sort / sort_unique vocabulary, so replace\n",
        "# them if a program with a different vocabulary is selected.\n",
        "plot_residuals_and_input(\n",
        "  model=assembled_model,\n",
        "  inputs=[\"bos\", 3, 4, 1],\n",
        "  figsize=(10, 9)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8c4LakWHa4ey"
      },
      "outputs": [],
      "source": [
        "#@title Plot layer outputs\n",
        "# Plots `layer_outputs` of the forward pass, one panel per layer.\n",
        "# The tokens 3, 4, 1 come from the sort / sort_unique vocabulary, so replace\n",
        "# them if a program with a different vocabulary is selected.\n",
        "plot_layer_outputs(\n",
        "  model=assembled_model,\n",
        "  inputs = [\"bos\", 3, 4, 1],\n",
        "  figsize=(8, 9)\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}