# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.craft_graph_to_model."""

from absl.testing import absltest
from absl.testing import parameterized
import networkx as nx
from tracr.compiler import craft_graph_to_model
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.rasp import rasp


class CraftAllocateModulesToLayersTest(parameterized.TestCase):
  """Tests for the layer-allocation helpers in craft_graph_to_model."""

  def _get_dummy_block(self, block_type):
    """Returns a small craft block of the requested type.

    Args:
      block_type: "ATTN" for a categorical attention block, "MLP" for a
        categorical MLP. Any other value (e.g. "INPUT" / "OUTPUT") yields None,
        i.e. a node without a model block.

    Returns:
      The block built by categorical_attn / categorical_mlp, or None.
    """
    if block_type == "ATTN":
      return categorical_attn.categorical_attn(
          query_space=bases.VectorSpaceWithBasis.from_names(["query"]),
          key_space=bases.VectorSpaceWithBasis.from_names(["bos", "key"]),
          value_space=bases.VectorSpaceWithBasis.from_names(["bos", "value"]),
          output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
          bos_space=bases.VectorSpaceWithBasis.from_names(["bos"]),
          one_space=bases.VectorSpaceWithBasis.from_names(["one"]),
          attn_fn=lambda x, y: True,
      )
    elif block_type == "MLP":
      return categorical_mlp.map_categorical_mlp(
          input_space=bases.VectorSpaceWithBasis.from_names(["input"]),
          output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
          operation=lambda x: x,
      )
    else:
      return None

  def _build_graph_with_blocks(self, types, expected_results, edges):
    """Builds a DiGraph with one node per entry of `types`.

    Node i gets ID i, a constant RASP expression, the dummy model block for
    types[i], and expected_results[i] stored under "expected_result".

    Args:
      types: Block type of each node (see _get_dummy_block).
      expected_results: Expected layer allocation of each node.
      edges: Iterable of (source_id, target_id) pairs, added in order.

    Returns:
      The constructed nx.DiGraph.
    """
    # zip() would silently drop trailing nodes if these lengths disagreed.
    self.assertLen(expected_results, len(types))
    graph = nx.DiGraph()
    for node_id, (node_type, res) in enumerate(zip(types, expected_results)):
      graph.add_node(
          node_id, **{
              nodes.ID: node_id,
              nodes.EXPR: rasp.ConstantSOp(1),
              nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
              "expected_result": res
          })
    graph.add_edges_from(edges)
    return graph

  def test_get_longest_path_length_to_node_returns_expected_result(self):
    """Creates a graph and checks the longest path for each node."""

    # Node IDs:
    # 0 -- 1 -- 2 -- 3 ------------  4
    #               /              /
    # 5 -- 6 ---------- 7 -- 8 -- 9
    #
    # 10
    # Expected return values (-1: node unreachable from any source):
    # 0 -- 1 -- 2 -- 3 ------------  5
    #               /              /
    # 0 -- 1 ---------- 2 -- 3 -- 4
    #
    # -1

    graph = nx.DiGraph()
    expected_results = [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1]
    for node_id, res in enumerate(expected_results):
      graph.add_node(
          node_id, **{
              nodes.ID: node_id,
              nodes.EXPR: rasp.ConstantSOp(1),
              "expected_result": res
          })
    graph.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4), (5, 6), (6, 7),
                          (7, 8), (8, 9), (6, 3), (9, 4)])
    sources = [graph.nodes[0], graph.nodes[5]]

    for node_id, node in graph.nodes.items():
      result = craft_graph_to_model._get_longest_path_length_to_node(
          graph, sources, node)
      self.assertEqual(
          result, node["expected_result"], msg=f"node {node_id}")

  def test_allocate_modules_to_layers_returns_expected_result(self):
    """Creates a graph and checks if the correct layer assignment is returned."""

    # Computation Graph:
    # INPUT -- ATTN -- MLP -- ATTN ------ MLP -- OUTPUT
    #           /           /          /
    # INPUT -- MLP --- MLP          ATTN
    #                      \      /
    #                        ATTN
    # Node IDs:
    # 0 -- 1 -- 2 -- 3 -- 4 -- 5
    #         /     /     /
    # 6 -- 7 ---- 8      9
    #               \   /
    #                10
    # Expected layer allocation (-1: node has no model block):
    # -1 -- 0 -- 3 -- 4 -- 7 -- -1
    #         /     /     /
    # -1 -- 1 --- 3      6
    #               \   /
    #                 4

    graph = self._build_graph_with_blocks(
        types=[
            "INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT", "INPUT", "MLP",
            "MLP", "ATTN", "ATTN"
        ],
        expected_results=[-1, 0, 3, 4, 7, -1, -1, 1, 3, 6, 4],
        edges=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7), (7, 2), (7, 8),
               (8, 3), (8, 10), (9, 4), (10, 9)])

    # Node 5 (OUTPUT) is the only node without successors, so it is the sink.
    craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
        graph=graph,
        sink=graph.nodes[5],
        sources=[graph.nodes[0], graph.nodes[6]])

    layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
        craft_graph.graph, craft_graph.sources)
    for node_id, node in graph.nodes.items():
      self.assertEqual(
          layer_allocation[node_id],
          node["expected_result"],
          msg=f"node {node_id}")

  def test_allocate_modules_to_layers_returns_expected_result_for_chain(self):
    """Tests a chain of alternating attention layers and MLPs."""

    # Computation Graph:
    # INPUT -- ATTN -- MLP -- ATTN -- MLP -- OUTPUT
    # Node IDs:
    # 0 -- 1 -- 2 -- 3 -- 4 -- 5
    # Expected layer allocation (-1: node has no model block):
    # -1 -- 0 -- 1 -- 2 -- 3 -- -1

    graph = self._build_graph_with_blocks(
        types=["INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT"],
        expected_results=[-1, 0, 1, 2, 3, -1],
        edges=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)])

    craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
        graph=graph, sink=graph.nodes[5], sources=[graph.nodes[0]])

    layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
        craft_graph.graph, craft_graph.sources)
    for node_id, node in graph.nodes.items():
      self.assertEqual(
          layer_allocation[node_id],
          node["expected_result"],
          msg=f"node {node_id}")


if __name__ == "__main__":
  # Run this module's tests through absl's test runner when executed directly.
  absltest.main()