2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
"""Tests for Stax library."""
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
2020-05-21 18:12:18 -03:00
|
|
|
import numpy as np
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src import test_util as jtu
|
2018-11-17 18:03:33 -08:00
|
|
|
from jax import random
|
2021-10-19 17:30:16 -07:00
|
|
|
from jax.example_libraries import stax
|
2021-08-17 20:41:02 -07:00
|
|
|
from jax import dtypes
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-12 09:00:39 -08:00
|
|
|
from jax.config import config
|
2018-12-06 18:30:59 -05:00
|
|
|
# Configure JAX from absl command-line flags before any test in this module runs.
config.parse_flags_with_absl()
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-30 16:51:32 -08:00
|
|
|
def random_inputs(rng, input_shape):
  """Draws standard-normal test inputs for a shape or a list of shapes.

  Args:
    rng: a NumPy random generator exposing ``randn`` (this file passes the
      object returned by ``test_case.rng()``).
    input_shape: a tuple giving one array shape, or a (possibly nested) list
      of such shapes.

  Returns:
    For a tuple, one NumPy array of that shape cast to the canonical JAX
    float dtype. For a list, a list of such results built recursively, in
    the same order.

  Raises:
    TypeError: if ``input_shape`` is neither exactly a tuple nor a list.
  """
  if type(input_shape) is tuple:
    # np.float64 instead of the np.float_ alias, which NumPy 2.0 removed.
    # canonicalize_dtype narrows it to JAX's default float dtype.
    dtype = dtypes.canonicalize_dtype(np.float64)
    return rng.randn(*input_shape).astype(dtype)
  elif type(input_shape) is list:
    return [random_inputs(rng, shape) for shape in input_shape]
  else:
    raise TypeError(type(input_shape))
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
  """Asserts a stax layer's apply output has the shape its init predicts.

  A fixed PRNG key (seed 0) is split. The second half seeds ``init_fun``,
  and the first half is passed to ``apply_fun`` as its ``rng`` keyword.
  Inputs are drawn with ``random_inputs`` from ``test_case.rng()``.

  Args:
    test_case: the running test case, used for ``rng()`` and assertions.
    init_fun: a stax init function ``(key, input_shape) -> (shape, params)``.
    apply_fun: the matching stax apply function.
    input_shape: a shape tuple, or a list of shapes for multi-input layers.
  """
  apply_key, init_key = random.split(random.PRNGKey(0))
  expected_shape, params = init_fun(init_key, input_shape)
  inputs = random_inputs(test_case.rng(), input_shape)
  actual = apply_fun(params, inputs, rng=apply_key)
  test_case.assertEqual(actual.shape, expected_shape)
|
|
|
|
|
|
|
|
|
|
|
|
class StaxTest(jtu.JaxTestCase):
  """Shape and numerical checks for layers in ``jax.example_libraries.stax``."""

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape), "shape": shape}
      for shape in [(2, 3), (5,)]))
  def testRandnInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.randn()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape), "shape": shape}
      for shape in [(2, 3), (2, 3, 4)]))
  def testGlorotInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.glorot()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape, "padding": padding,
       "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1, 1), (2, 3)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 10, 11, 1)]))
  def testConvShape(self, channels, filter_shape, padding, strides,
                    input_shape):
    init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
                                    padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape, "padding": padding,
       "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1, 1), (2, 3), (3, 3)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (2, 1), (2, 2)]
      for input_shape in [(2, 10, 11, 1)]))
  def testConvTransposeShape(self, channels, filter_shape, padding, strides,
                             input_shape):
    init_fun, apply_fun = stax.ConvTranspose(channels, filter_shape,  # 2D
                                             strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape, "padding": padding,
       "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1,), (2,), (3,)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (1,), (2,)]
      for input_shape in [(2, 10, 1)]))
  def testConv1DTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
    init_fun, apply_fun = stax.Conv1DTranspose(channels, filter_shape,
                                               strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_out_dim={}_input_shape={}"
                        .format(out_dim, input_shape),
       "out_dim": out_dim, "input_shape": input_shape}
      for out_dim in [3, 4]
      for input_shape in [(2, 3), (3, 4)]))
  def testDenseShape(self, out_dim, input_shape):
    init_fun, apply_fun = stax.Dense(out_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}_nonlinear={}"
                        .format(input_shape, nonlinear),
       "input_shape": input_shape, "nonlinear": nonlinear}
      for input_shape in [(2, 3), (2, 3, 4)]
      for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
  def testNonlinearShape(self, input_shape, nonlinear):
    # Nonlinearities are exposed as ready-made (init_fun, apply_fun) pairs.
    init_fun, apply_fun = getattr(stax, nonlinear)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_window_shape={}_padding={}_strides={}_input_shape={}"
                        "_maxpool={}_spec={}"
                        .format(window_shape, padding, strides, input_shape,
                                max_pool, spec),
       "window_shape": window_shape, "padding": padding, "strides": strides,
       "input_shape": input_shape, "max_pool": max_pool, "spec": spec}
      for window_shape in [(1, 1), (2, 3)]
      for padding in ["VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 5, 6, 4)]
      for max_pool in [False, True]
      for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
  def testPoolingShape(self, window_shape, padding, strides, input_shape,
                       max_pool, spec):
    layer = stax.MaxPool if max_pool else stax.AvgPool
    init_fun, apply_fun = layer(window_shape, padding=padding, strides=strides,
                                spec=spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(2, 3), (2, 3, 4)]))
  def testFlattenShape(self, input_shape):
    init_fun, apply_fun = stax.Flatten
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}_spec={}".format(input_shape, i),
       "input_shape": input_shape, "spec": spec}
      for input_shape in [(2, 5, 6, 1)]
      for i, spec in enumerate([
          [stax.Conv(3, (2, 2))],
          [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]])))
  def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testDropoutShape(self, input_shape):
    init_fun, apply_fun = stax.Dropout(0.9)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testFanInSum(self, input_shape):
    init_fun, apply_fun = stax.FanInSum
    # FanInSum takes a list of identically-shaped inputs.
    _CheckShapeAgreement(self, init_fun, apply_fun, [input_shape, input_shape])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshapes={}_axis={}".format(input_shapes, axis),
       "input_shapes": input_shapes, "axis": axis}
      for input_shapes, axis in [
          ([(2, 3), (2, 1)], 1),
          ([(2, 3), (2, 1)], -1),
          ([(1, 2, 4), (1, 1, 4)], 1),
      ]))
  def testFanInConcat(self, input_shapes, axis):
    init_fun, apply_fun = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

  def testIssue182(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.Softmax
    input_shape = (10, 3)
    inputs = np.arange(30.).astype("float32").reshape(input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    # Use unittest/numpy assertions instead of bare `assert`, which is
    # stripped under `python -O` and reports nothing useful on failure.
    self.assertEqual(out_shape, out.shape)
    # Each softmax row must sum to 1. These are np.allclose's default tolerances.
    np.testing.assert_allclose(np.sum(np.asarray(out), -1), 1.,
                               rtol=1e-5, atol=1e-8)

  def testBatchNormNoScaleOrCenter(self):
    key = random.PRNGKey(0)
    axes = (0, 1, 2)
    init_fun, apply_fun = stax.BatchNorm(axis=axes, center=False, scale=False)
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    self.assertEqual(out_shape, out.shape)

    # Without scale/center the output is normalized over the same axes that
    # BatchNorm was given: about zero mean and unit standard deviation.
    means = np.mean(out, axis=axes)
    std_devs = np.std(out, axis=axes)
    # rtol=1e-5 plus the explicit atol reproduces the previous np.allclose check.
    np.testing.assert_allclose(means, np.zeros_like(means),
                               rtol=1e-5, atol=1e-4)
    np.testing.assert_allclose(std_devs, np.ones_like(std_devs),
                               rtol=1e-5, atol=1e-4)

  def testBatchNormShapeNHWC(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    # Parameters span the single axis not normalized over (the last one).
    self.assertEqual(beta.shape, (7,))
    self.assertEqual(gamma.shape, (7,))
    self.assertEqual(out_shape, out.shape)

  def testBatchNormShapeNCHW(self):
    key = random.PRNGKey(0)
    # Regression test for https://github.com/google/jax/issues/461
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    # Parameters span the single axis not normalized over (axis 1).
    self.assertEqual(beta.shape, (5,))
    self.assertEqual(gamma.shape, (5,))
    self.assertEqual(out_shape, out.shape)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
if __name__ == "__main__":
  # Run the absl test main with JAX's test loader.
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)
|