# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add craft model blocks to graph of RASPExpr."""

from typing import Any, Callable, Optional

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.craft.chamber import numerical_mlp
from tracr.craft.chamber import selector_width
from tracr.rasp import rasp


def _transform_fun_to_basis_fun(
    fun: Callable[..., Any],
    output_direction_name: Optional[str] = None) -> Callable[..., Any]:
  """Lifts a function on raw values to a function on basis directions.

  Args:
    fun: Function taking the `.value` of each direction it is applied to.
    output_direction_name: If truthy, the result of `fun` is wrapped in a
      `bases.BasisDirection` carrying this name. Otherwise the raw result of
      `fun` is returned.

  Returns:
    A function that accepts objects exposing a `.value` attribute, applies
    `fun` to those values, and returns either the raw or the wrapped result.
  """

  def bases_fun(*args):
    # Unwrap every direction to the value it carries before calling `fun`.
    result = fun(*(direction.value for direction in args))
    if not output_direction_name:
      return result
    return bases.BasisDirection(output_direction_name, result)

  return bases_fun


def _check_selector_expression(expr, graph):
  """Validates the selector used by an aggregate or selector width.

  Args:
    expr: Expression exposing a `selector` attribute; the selector in turn
      must expose `keys` and `queries`.
    graph: Graph whose `predecessors(label)` yields the labels feeding into
      the node with the given label.

  Raises:
    AssertionError: If the selector is not a predecessor of `expr`, or its
      keys / queries are not predecessors of the selector.
    ValueError: If the selector keys or queries are not categorical.
  """
  selector = expr.selector

  # Structural checks: the selector must feed `expr`, and both of its
  # operands must feed the selector. `predecessors` is queried afresh for
  # every membership test.
  assert selector.label in graph.predecessors(expr.label)
  for operand_name in ("keys", "queries"):
    operand = getattr(selector, operand_name)
    assert operand.label in graph.predecessors(selector.label)

  # Encoding check: both selector operands have to be categorical.
  both_categorical = (
      rasp.is_categorical(selector.queries) and
      rasp.is_categorical(selector.keys))
  if not both_categorical:
    raise ValueError("Selector keys and queries must be categorical.")


def add_craft_components_to_rasp_graph(
    graph: nx.DiGraph,
    bos_dir: bases.BasisDirection = bases.BasisDirection("tokens", "bos"),
    one_dir: bases.BasisDirection = bases.BasisDirection("one"),
    causal: bool = False,
    mlp_exactness: float = 100,
) -> None:
  """Translates expressions to craft blocks and attaches them to the graph.

  Sets the `MODEL_BLOCK` attribute of every node in `graph` whose expression
  is an `rasp.SOp`; nodes holding any other expression are skipped. The
  `rasp.tokens` and `rasp.indices` nodes get `None` as their block.

  Args:
    graph: RASP graph with  `VALUE_SET` but not `MODEL_BLOCK` attributes.
    bos_dir: Basis direction representing beginning of sequence (bos) token.
    one_dir: Auxiliary basis direction that must contain 1.
    causal: Forwarded to the attention block built for aggregates. The
      selector width block is always built with `causal=False`.
    mlp_exactness: Controls the approximation of the MLP layers. Passed as
      `large_number` to the numerical MLPs and as `mlp_large_number` to the
      selector width block.

  Raises:
    ValueError: On invalid input (if `MODEL_BLOCK` is set already, or
      `VALUE_SET` is not set already), or if an aggregate / selector has
      unsupported encodings or defaults.
    TypeError: If an aggregate or selector width uses a selector that is not
      an `rasp.Select`.
    NotImplementedError: If the graph contains an unsupported expression.
  """
  one_space = bases.VectorSpaceWithBasis([one_dir])

  for node_id, node in graph.nodes.items():
    expr = node[nodes.EXPR]

    # Only s-ops are translated to model blocks.
    if not isinstance(expr, rasp.SOp):
      continue

    if node.get(nodes.MODEL_BLOCK):
      raise ValueError("Input graph cannot have model blocks set already.")
    if nodes.VALUE_SET not in node:
      raise ValueError(
          "Craft components can only be added after basis inference.")

    if expr is rasp.tokens or expr is rasp.indices:
      block = None
    elif isinstance(expr, rasp.Map):
      block = _build_map_block(expr, node_id, node, graph, one_space,
                               mlp_exactness)
    elif isinstance(expr, rasp.SequenceMap):
      block = _build_sequence_map_block(expr, node_id, node, graph, one_space)
    elif isinstance(expr, rasp.Aggregate):
      block = _build_aggregate_block(expr, node_id, node, graph, bos_dir,
                                     one_space, causal)
    elif isinstance(expr, rasp.SelectorWidth):
      block = _build_selector_width_block(expr, node, graph, bos_dir,
                                          one_space, mlp_exactness)
    else:
      raise NotImplementedError(f"Expression {expr} cannot be translated to "
                                "a model component.")

    graph.nodes[node_id][nodes.MODEL_BLOCK] = block


def _check_is_plain_select(sel_expr) -> None:
  """Raises TypeError unless `sel_expr` is an `rasp.Select`."""
  if not isinstance(sel_expr, rasp.Select):
    raise TypeError("Compiling composite Selectors is not supported. "
                    f"Got a {sel_expr}.")


def _selector_attn_fn(sel_expr) -> Callable[..., Any]:
  """Wraps the predicate of `sel_expr` as a craft attention function."""
  selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate)

  def attn_basis_fn(query: bases.BasisDirection,
                    key: bases.BasisDirection) -> bool:
    # Argument order is different in craft / transformers than RASP selectors
    return selector_basis_fn(key, query)

  return attn_basis_fn


def _build_map_block(expr, node_id, node, graph: nx.DiGraph, one_space,
                     mlp_exactness: float):
  """Builds the MLP block implementing an `rasp.Map` node."""
  inner_expr = expr.inner
  inner_node = graph.nodes[inner_expr.label]
  assert inner_expr.label in graph.predecessors(node_id)
  input_space = bases.VectorSpaceWithBasis(inner_node[nodes.OUTPUT_BASIS])
  output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

  # The MLP construction depends on the (input, output) encoding pair.
  if rasp.is_categorical(inner_expr) and rasp.is_categorical(expr):
    basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
    return categorical_mlp.map_categorical_mlp(
        input_space=input_space,
        output_space=output_space,
        operation=basis_fun)
  if rasp.is_categorical(inner_expr) and rasp.is_numerical(expr):
    return categorical_mlp.map_categorical_to_numerical_mlp(
        input_space=input_space,
        output_space=output_space,
        operation=expr.f,
    )
  if rasp.is_numerical(inner_expr) and rasp.is_categorical(expr):
    return numerical_mlp.map_numerical_to_categorical_mlp(
        f=expr.f,
        input_space=input_space,
        output_space=output_space,
        input_value_set=inner_node[nodes.VALUE_SET],
        one_space=one_space,
        hidden_name=f"_hidden_{expr.label}_",
        large_number=mlp_exactness)
  if rasp.is_numerical(inner_expr) and rasp.is_numerical(expr):
    return numerical_mlp.map_numerical_mlp(
        f=expr.f,
        input_space=input_space,
        output_space=output_space,
        input_value_set=inner_node[nodes.VALUE_SET],
        one_space=one_space,
        hidden_name=f"_hidden_{expr.label}_",
        large_number=mlp_exactness)

  # Report encodings through rasp.get_encoding, the same accessor used for
  # the aggregate encoding check.
  raise NotImplementedError(
      "Map does not support "
      f"in_type '{rasp.get_encoding(inner_expr)}' and"
      f" out_type '{rasp.get_encoding(expr)}'!")


def _build_sequence_map_block(expr, node_id, node, graph: nx.DiGraph,
                              one_space):
  """Builds the MLP block implementing an `rasp.SequenceMap` node."""
  fst_expr, fst_node = expr.fst, graph.nodes[expr.fst.label]
  snd_expr, snd_node = expr.snd, graph.nodes[expr.snd.label]

  # Check graph structure
  assert fst_expr.label in graph.predecessors(node_id)
  assert snd_expr.label in graph.predecessors(node_id)

  fst_space = bases.VectorSpaceWithBasis(fst_node[nodes.OUTPUT_BASIS])
  snd_space = bases.VectorSpaceWithBasis(snd_node[nodes.OUTPUT_BASIS])
  out_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

  is_linear = isinstance(expr, rasp.LinearSequenceMap)
  operands = (fst_expr, snd_expr, expr)
  if is_linear and not all(rasp.is_numerical(x) for x in operands):
    raise NotImplementedError("Linear SequenceMap only supports numerical "
                              "inputs/outputs.")
  if not is_linear and not all(rasp.is_categorical(x) for x in operands):
    raise NotImplementedError("(Non-linear) SequenceMap only supports "
                              "categorical inputs/outputs.")

  if is_linear:
    assert len(fst_space.basis) == 1
    assert len(snd_space.basis) == 1
    assert len(out_space.basis) == 1
    return numerical_mlp.linear_sequence_map_numerical_mlp(
        input1_basis_direction=fst_space.basis[0],
        input2_basis_direction=snd_space.basis[0],
        output_basis_direction=out_space.basis[0],
        input1_factor=expr.fst_fac,
        input2_factor=expr.snd_fac,
        hidden_name=f"_hidden_{expr.label}_")

  if fst_space == snd_space:
    # Both inputs live in the same space, so the operation is applied as a
    # single-input map that feeds the same value to both arguments of `f`.
    basis_fun = _transform_fun_to_basis_fun(lambda x: expr.f(x, x),
                                            expr.label)
    return categorical_mlp.map_categorical_mlp(
        input_space=fst_space, output_space=out_space, operation=basis_fun)

  basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
  return categorical_mlp.sequence_map_categorical_mlp(
      input1_space=fst_space,
      input2_space=snd_space,
      output_space=out_space,
      operation=basis_fun,
      one_space=one_space,
      hidden_name=f"_hidden_{expr.label}_")


def _build_aggregate_block(agg_expr, node_id, node, graph: nx.DiGraph,
                           bos_dir, one_space, causal: bool):
  """Builds the attention block implementing an `rasp.Aggregate` node."""
  sel_expr = agg_expr.selector
  _check_is_plain_select(sel_expr)

  queries = graph.nodes[sel_expr.queries.label]
  keys = graph.nodes[sel_expr.keys.label]
  sop = graph.nodes[agg_expr.sop.label]

  _check_selector_expression(agg_expr, graph)
  assert agg_expr.sop.label in graph.predecessors(node_id)
  if rasp.get_encoding(agg_expr.sop) != rasp.get_encoding(agg_expr):
    raise ValueError(
        "sop encoding must match output encoding of the aggregate.")
  if rasp.is_categorical(agg_expr) and agg_expr.default is not None:
    raise ValueError("Default for a categorical aggregate must be None. "
                     f"Got {agg_expr.default}")
  if rasp.is_numerical(agg_expr) and agg_expr.default != 0:
    raise ValueError("Default for a numerical aggregate must be 0. "
                     f"Got {agg_expr.default}")

  output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])
  return categorical_attn.categorical_attn(
      query_space=bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS]),
      key_space=bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS]),
      value_space=bases.VectorSpaceWithBasis(sop[nodes.OUTPUT_BASIS]),
      output_space=output_space,
      bos_space=bases.VectorSpaceWithBasis([bos_dir]),
      one_space=one_space,
      attn_fn=_selector_attn_fn(sel_expr),
      default_output=output_space.null_vector(),
      causal=causal,
      always_attend_to_bos=False,
      use_bos_for_default_output=True,
      softmax_coldness=100)


def _build_selector_width_block(expr, node, graph: nx.DiGraph, bos_dir,
                                one_space, mlp_exactness: float):
  """Builds the block implementing an `rasp.SelectorWidth` node."""
  sel_expr = expr.selector
  # Same restriction as for aggregates: fail with a clear error instead of an
  # attribute lookup failure on a selector without keys / queries.
  _check_is_plain_select(sel_expr)

  queries = graph.nodes[sel_expr.queries.label]
  keys = graph.nodes[sel_expr.keys.label]
  _check_selector_expression(expr, graph)

  # NOTE(review): the caller's `causal` flag is not forwarded here, unlike for
  # aggregates -- confirm that `causal=False` is intended.
  return selector_width.selector_width(
      query_space=bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS]),
      key_space=bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS]),
      output_space=bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]),
      bos_space=bases.VectorSpaceWithBasis([bos_dir]),
      one_space=one_space,
      attn_fn=_selector_attn_fn(sel_expr),
      out_value_set=node[nodes.VALUE_SET],
      categorical_output=rasp.is_categorical(expr),
      causal=False,
      softmax_coldness=100,
      mlp_large_number=mlp_exactness,
      label=expr.label)