rocm_jax/tests/vectorize_test.py
2020-06-29 11:08:57 -07:00

146 lines
5.1 KiB
Python

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Vectorize library."""
from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as jnp
from jax import test_util as jtu
from jax.experimental.vectorize import vectorize
from jax.config import config
config.parse_flags_with_absl()
# ``jnp.dot`` wrapped as generalized ufuncs. Each signature names the core
# dimensions consumed and produced by one call; the tests below exercise
# broadcasting over any extra leading dimensions.
matmat = vectorize("(n,m),(m,k)->(n,k)")(jnp.dot)  # matrix @ matrix
matvec = vectorize("(n,m),(m)->(n)")(jnp.dot)  # matrix @ vector
vecmat = vectorize("(m),(m,k)->(k)")(jnp.dot)  # vector @ matrix
vecvec = vectorize("(m),(m)->()")(jnp.dot)  # vector . vector
@vectorize("(n)->()")
def magnitude(x):
  """Return the dot product of the core vector ``x`` with itself.

  Note that this is the squared Euclidean norm; no square root is taken.
  """
  self_dot = jnp.dot(x, x)
  return self_dot
# ``jnp.mean`` wrapped with a vector -> scalar core signature. The tests call
# it both without and with an ``axis`` keyword argument.
mean = vectorize("(n)->()")(jnp.mean)
@vectorize("()->(n)")
def stack_plus_minus(x):
  """Stack the core scalar ``x`` with its negation, giving ``[x, -x]``."""
  negated = -x
  return jnp.stack([x, negated])
@vectorize("(n)->(),(n)")
def center(array):
  """Split a core vector into its mean and the mean-subtracted vector.

  Returns:
    A ``(bias, debiased)`` pair: ``bias`` is ``jnp.mean(array)`` and
    ``debiased`` is ``array - bias``.
  """
  offset = jnp.mean(array)
  return offset, array - offset
class VectorizeTest(jtu.JaxTestCase):
  """Tests for the module-level functions wrapped with ``vectorize``.

  Most tests pass inputs of a given shape through one wrapper and compare
  only the shape of the result with the expected shape; ``test_center``
  additionally checks the computed mean values.
  """

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((2, 3), (3, 4), (2, 4)),
          ((2, 3), (1, 3, 4), (1, 2, 4)),
          ((5, 2, 3), (1, 3, 4), (5, 2, 4)),
          ((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
      ]))
  def test_matmat(self, left_shape, right_shape, result_shape):
    """``matmat`` yields the expected shape for each pair of input shapes."""
    result = matmat(jnp.zeros(left_shape), jnp.zeros(right_shape))
    self.assertEqual(result.shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((2, 3), (3,), (2,)),
          ((2, 3), (1, 3), (1, 2)),
          ((4, 2, 3), (1, 3), (4, 2)),
          ((5, 4, 2, 3), (1, 3), (5, 4, 2)),
      ]))
  def test_matvec(self, left_shape, right_shape, result_shape):
    """``matvec`` yields the expected shape for each pair of input shapes."""
    result = matvec(jnp.zeros(left_shape), jnp.zeros(right_shape))
    self.assertEqual(result.shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((3,), (3, 4), (4,)),
          ((2, 3), (3, 4), (2, 4)),
          ((2, 3), (1, 3, 4), (2, 4)),
          ((3,), (5, 3, 4), (5, 4)),
          ((6, 5, 3), (3, 4), (6, 5, 4)),
      ]))
  def test_vecmat(self, left_shape, right_shape, result_shape):
    """``vecmat`` yields the expected shape for each pair of input shapes.

    Mirrors the other ``jnp.dot`` wrapper tests so that every wrapper
    defined at module level is exercised.
    """
    result = vecmat(jnp.zeros(left_shape), jnp.zeros(right_shape))
    self.assertEqual(result.shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_leftshape={}_rightshape={}".format(left_shape, right_shape),
       "left_shape": left_shape, "right_shape": right_shape, "result_shape": result_shape}
      for left_shape, right_shape, result_shape in [
          ((3,), (3,), ()),
          ((2, 3), (3,), (2,)),
          ((4, 2, 3), (3,), (4, 2)),
      ]))
  def test_vecvec(self, left_shape, right_shape, result_shape):
    """``vecvec`` yields the expected shape for each pair of input shapes."""
    result = vecvec(jnp.zeros(left_shape), jnp.zeros(right_shape))
    self.assertEqual(result.shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((3,), ()),
          ((2, 3,), (2,)),
          ((1, 2, 3,), (1, 2)),
      ]))
  def test_magnitude(self, shape, result_shape):
    """``magnitude`` drops the last axis of an ``arange``-filled input."""
    # Total element count, needed to build an arange that reshapes to `shape`.
    num_elements = 1
    for dim in shape:
      num_elements *= dim
    values = jnp.arange(num_elements).reshape(shape)
    self.assertEqual(magnitude(values).shape, result_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((3,), ()),
          ((2, 3), (2,)),
          ((1, 2, 3, 4), (1, 2, 3)),
      ]))
  def test_mean(self, shape, result_shape):
    """``mean`` without ``axis`` drops the last axis of the input."""
    self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)

  def test_mean_axis(self):
    """``mean`` with ``axis=0`` maps a (2, 3) input to a (3,) result."""
    self.assertEqual(mean(jnp.zeros((2, 3)), axis=0).shape, (3,))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape),
       "shape": shape, "result_shape": result_shape}
      for shape, result_shape in [
          ((), (2,)),
          ((3,), (3,2,)),
      ]))
  def test_stack_plus_minus(self, shape, result_shape):
    """``stack_plus_minus`` appends a trailing axis of length 2."""
    self.assertEqual(stack_plus_minus(jnp.zeros(shape)).shape, result_shape)

  def test_center(self):
    """``center`` returns the mean and a same-shaped debiased array."""
    # 1-D input: the mean of [0, 1, 2] is 1.
    bias, debiased = center(jnp.arange(3))
    self.assertEqual(debiased.shape, (3,))
    self.assertEqual(bias.shape, ())
    # Compared against a Python float, so dtypes are deliberately ignored.
    self.assertAllClose(1.0, bias, check_dtypes=False)

    matrix = jnp.arange(12).reshape((3, 4))

    # Core dimension along axis 1: one mean per row.
    bias, debiased = center(matrix, axis=1)
    self.assertEqual(debiased.shape, (3, 4))
    self.assertEqual(bias.shape, (3,))
    self.assertAllClose(jnp.mean(matrix, axis=1), bias)

    # Core dimension along axis 0: one mean per column.
    bias, debiased = center(matrix, axis=0)
    self.assertEqual(debiased.shape, (3, 4))
    self.assertEqual(bias.shape, (4,))
    self.assertAllClose(jnp.mean(matrix, axis=0), bias)
if __name__ == '__main__':
  # Script entry point: run this module's tests through absl, using the
  # test loader provided by JAX's test utilities.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)