# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for nn module."""

import collections
import itertools

from absl.testing import absltest
from absl.testing import parameterized

import numpy as onp

from jax import test_util as jtu
from jax.test_util import check_grads
from jax import nn
from jax import random
import jax
import jax.numpy as np
from jax.config import config

config.parse_flags_with_absl()


class NNFunctionsTest(jtu.JaxTestCase):

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testSoftplusGrad(self):
    # Check gradients at a tiny input, where a naive formulation of
    # softplus loses precision.
    check_grads(nn.softplus, (1e-8,), 4,
                rtol=1e-2 if jtu.device_under_test() == "tpu" else None)

  def testSoftplusValue(self):
    # softplus(89) == 89 only if the implementation avoids overflow:
    # exp(89) does not fit in float32.
    val = nn.softplus(89.)
    self.assertAllClose(val, 89., check_dtypes=False)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testEluGrad(self):
    check_grads(nn.elu, (1e4,), 4, eps=1.)

  def testEluValue(self):
    # elu(x) == x for positive x; a naive implementation would overflow
    # computing exp(1e4).
    val = nn.elu(1e4)
    self.assertAllClose(val, 1e4, check_dtypes=False)

  @parameterized.parameters(*itertools.product(
      (np.float32, np.bfloat16, np.float16),
      (nn.gelu, nn.relu, nn.softplus, nn.sigmoid)))
  def testDtypeMatchesInput(self, dtype, fn):
    if dtype is np.float16 and jtu.device_under_test() == "tpu":
      self.skipTest("float16 not supported on TPU")
    x = np.zeros((), dtype=dtype)
    out = fn(x)
    self.assertEqual(out.dtype, dtype)

  @jtu.skip_on_devices("gpu", "tpu")
  def testEluMemory(self):
    # see https://github.com/google/jax/pull/1640
    jax.make_jaxpr(nn.elu)(np.ones((10 ** 12,)))  # don't oom

  @jtu.skip_on_devices("gpu", "tpu")
  def testHardTanhMemory(self):
    # see https://github.com/google/jax/pull/1640
    jax.make_jaxpr(nn.hard_tanh)(np.ones((10 ** 12,)))  # don't oom

  def testOneHot(self):
    actual = nn.one_hot(np.array([0, 1, 2]), 3)
    expected = np.array([[1., 0., 0.],
                         [0., 1., 0.],
                         [0., 0., 1.]])
    self.assertAllClose(actual, expected, check_dtypes=True)

    actual = nn.one_hot(np.array([1, 2, 0]), 3)
    expected = np.array([[0., 1., 0.],
                         [0., 0., 1.],
                         [1., 0., 0.]])
    self.assertAllClose(actual, expected, check_dtypes=True)

  def testOneHotOutOfBound(self):
    # Indices outside [0, num_classes) encode to all-zero rows.
    actual = nn.one_hot(np.array([-1, 3]), 3)
    expected = np.array([[0., 0., 0.],
                         [0., 0., 0.]])
    self.assertAllClose(actual, expected, check_dtypes=True)

  def testOneHotNonArrayInput(self):
    actual = nn.one_hot([0, 1, 2], 3)
    expected = np.array([[1., 0., 0.],
                         [0., 1., 0.],
                         [0., 0., 1.]])
    self.assertAllClose(actual, expected, check_dtypes=True)

  def testOneHotCustomDtype(self):
    actual = nn.one_hot(np.array([0, 1, 2]), 3, dtype=np.bool_)
    expected = np.array([[True, False, False],
                         [False, True, False],
                         [False, False, True]])
    self.assertAllClose(actual, expected, check_dtypes=True)


InitializerRecord = collections.namedtuple(
    "InitializerRecord",
    ["name", "initializer", "shapes"])

ALL_SHAPES = [(2,), (2, 2), (2, 3), (3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 4, 5)]

def initializer_record(name, initializer, min_dims=2, max_dims=4):
  shapes = [shape for shape in ALL_SHAPES
            if min_dims <= len(shape) <= max_dims]
  return InitializerRecord(name, initializer, shapes)

INITIALIZER_RECS = [
    initializer_record("uniform", nn.initializers.uniform, 1),
    initializer_record("normal", nn.initializers.normal, 1),
    initializer_record("he_normal", nn.initializers.he_normal),
    initializer_record("he_uniform", nn.initializers.he_uniform),
    initializer_record("glorot_normal", nn.initializers.glorot_normal),
    initializer_record("glorot_uniform", nn.initializers.glorot_uniform),
    initializer_record("lecun_normal", nn.initializers.lecun_normal),
    initializer_record("lecun_uniform", nn.initializers.lecun_uniform),
    initializer_record("orthogonal", nn.initializers.orthogonal, 2, 2),
    initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal,
                       4, 4),
]


class NNInitializersTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
           rec.name,
           jtu.format_shape_dtype_string(shape, dtype)),
       "initializer": rec.initializer(),
       "shape": shape, "dtype": dtype}
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in [onp.float32, onp.float64]))
  def testInitializer(self, initializer, shape, dtype):
    rng = random.PRNGKey(0)
    val = initializer(rng, shape, dtype)

    self.assertEqual(shape, np.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), np.dtype(val))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
           rec.name,
           jtu.format_shape_dtype_string(shape, dtype)),
       "initializer_provider": rec.initializer,
       "shape": shape, "dtype": dtype}
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in [onp.float32, onp.float64]))
  def testInitializerProvider(self, initializer_provider, shape, dtype):
    rng = random.PRNGKey(0)
    initializer = initializer_provider(dtype=dtype)
    val = initializer(rng, shape)

    self.assertEqual(shape, np.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), np.dtype(val))


if __name__ == "__main__":
  absltest.main()
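
# A minimal usage sketch (illustrative comment, not a test): the records
# above hold initializer *providers*, i.e. functions that return an init
# function with signature (key, shape[, dtype]) -- the same two-step pattern
# testInitializer and testInitializerProvider exercise:
#
#   init = nn.initializers.glorot_normal()
#   w = init(random.PRNGKey(0), (3, 2))  # array of shape (3, 2)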