# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for vectorspace_fns."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft import vectorspace_fns as vs_fns


class LinearTest(tests_common.VectorFnTestCase):
  """Tests for `vs_fns.Linear`: construction, application and combination."""

  def test_identity_from_matrix(self):
    # An identity matrix should map every basis vector to itself.
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear(vs, vs, np.eye(3))
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_identity_from_action(self):
    # An action sending each basis direction to its own basis vector should
    # produce the identity map as well.
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction)
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_nonidentity(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))

    f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]]))

    # The expectations below encode that applying f to the i-th basis vector
    # yields the i-th ROW of the matrix (row-vector convention).
    self.assertEqual(
        f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7])))
    self.assertEqual(
        f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1])))

  def test_different_vector_spaces(self):
    # With disjoint input/output spaces, an identity matrix maps the k-th
    # input basis vector to the k-th output basis vector.
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    a, b = vs1.basis_vectors()
    c, d = vs2.basis_vectors()

    f = vs_fns.Linear(vs1, vs2, np.eye(2))

    self.assertEqual(f(a), c)
    self.assertEqual(f(b), d)

  def test_combining_linear_functions_with_different_input(self):
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    vs = bases.direct_sum(vs1, vs2)
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    c = vs.vector_from_basis_direction(bases.BasisDirection("c"))
    d = vs.vector_from_basis_direction(bases.BasisDirection("d"))

    f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]]))
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])

    # On the direct sum, f3 is expected to act as f1 on the {a, b} block
    # (swap) and as f2 on the {c, d} block (keep c, zero out d).
    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0])))
    self.assertEqual(
        f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0])))
    self.assertEqual(
        f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0])))

  def test_combining_linear_functions_with_same_input(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))

    f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]]))
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])

    # When the combined functions share a space, their outputs are summed.
    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0])))
    self.assertEqual(f3(a), f1(a) + f2(a))
    self.assertEqual(f3(b), f1(b) + f2(b))


class ProjectionTest(tests_common.VectorFnTestCase):
  """Tests for `vs_fns.project` between spaces whose basis names overlap."""

  def test_projection_to_larger_space(self):
    # Every direction of the source also exists in the target, so each
    # source basis vector should land on its namesake in the target.
    small = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    large = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    src_a, src_b = small.basis_vectors()
    dst_a, dst_b, _, _ = large.basis_vectors()

    projection = vs_fns.project(small, large)

    for source, target in ((src_a, dst_a), (src_b, dst_b)):
      self.assertEqual(projection(source), target)

  def test_projection_to_smaller_space(self):
    # Shared directions survive; directions missing from the target space
    # are expected to map to its null vector.
    large = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    small = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    big_a, big_b, big_c, big_d = large.basis_vectors()
    small_a, small_b = small.basis_vectors()

    projection = vs_fns.project(large, small)

    for source, target in ((big_a, small_a), (big_b, small_b)):
      self.assertEqual(projection(source), target)
    for dropped in (big_c, big_d):
      self.assertEqual(projection(dropped), small.null_vector())


class ScalarBilinearTest(parameterized.TestCase):
  """Tests for `vs_fns.ScalarBilinear` built from a matrix or an action."""

  def _assert_table(self, form, a, b, expected):
    """Asserts form(x, y) over the pairs (a,a), (a,b), (b,a), (b,b) in order."""
    pairs = ((a, a), (a, b), (b, a), (b, b))
    for (x, y), want in zip(pairs, expected):
      self.assertEqual(form(x, y), want)

  def test_identity_matrix(self):
    # The identity matrix acts as a dot product on basis vectors.
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = space.basis_vectors()

    form = vs_fns.ScalarBilinear(space, space, np.eye(2))

    self._assert_table(form, a, b, (1, 0, 0, 1))

  def test_identity_from_action(self):
    # An action returning 1 exactly on equal directions reproduces the
    # identity form.
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = space.basis_vectors()

    form = vs_fns.ScalarBilinear.from_action(space, space,
                                             lambda x, y: int(x == y))

    self._assert_table(form, a, b, (1, 0, 0, 1))

  def test_non_identity(self):
    # This action depends only on the first argument's name.
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = space.basis_vectors()

    def first_is_a(x, y):
      del y  # Unused: the value depends only on the left direction.
      return int(x.name == "a")

    form = vs_fns.ScalarBilinear.from_action(space, space, first_is_a)

    self._assert_table(form, a, b, (1, 1, 0, 0))


if __name__ == "__main__":
  # Run the test cases defined in this module when executed as a script.
  absltest.main()