mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 16:16:08 +00:00
133 lines
5.1 KiB
Python
133 lines
5.1 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for Stax library."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
import numpy as onp
|
|
|
|
from jax import test_util as jtu
|
|
from jax import random
|
|
from jax.experimental import stax
|
|
|
|
|
|
def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
  """Assert that a layer's declared output shape matches its real output.

  Args:
    test_case: object providing ``assertEqual`` (the test case instance in
      this file) used to report the comparison.
    init_fun: callable taking ``input_shape`` and returning a pair
      ``(output_shape, params)``.
    apply_fun: callable invoked as ``apply_fun(params, inputs, rng_key)``;
      its result must expose a ``.shape`` attribute.
    input_shape: tuple of ints; a float32 batch of this shape, drawn with
      ``randn`` from a NumPy ``RandomState`` seeded with 0, is fed to
      ``apply_fun``.
  """
  expected_shape, layer_params = init_fun(input_shape)
  # Fixed seeds keep both the sampled inputs and the PRNG key deterministic.
  sample = onp.random.RandomState(0).randn(*input_shape)
  sample = sample.astype("float32")
  actual = apply_fun(layer_params, sample, random.PRNGKey(0))
  test_case.assertEqual(actual.shape, expected_shape)
|
|
|
|
|
|
class StaxTest(jtu.JaxTestCase):
  """Shape tests for stax initializers and layers.

  Initializer tests call the initializer on a shape and check the result's
  shape. Layer tests build an ``(init_fun, apply_fun)`` pair and pass it to
  ``_CheckShapeAgreement``, which compares the output shape reported by
  ``init_fun`` with the shape of the array ``apply_fun`` actually returns.
  """

  @parameterized.named_parameters(
      dict(testcase_name="_shape={}".format(shape), shape=shape)
      for shape in [(2, 3), (5,)])
  def testRandnInitShape(self, shape):
    """``stax.randn()`` produces an array with the requested shape."""
    sampled = stax.randn()(shape)
    self.assertEqual(sampled.shape, shape)

  @parameterized.named_parameters(
      dict(testcase_name="_shape={}".format(shape), shape=shape)
      for shape in [(2, 3), (2, 3, 4)])
  def testGlorotInitShape(self, shape):
    """``stax.glorot()`` produces an array with the requested shape."""
    sampled = stax.glorot()(shape)
    self.assertEqual(sampled.shape, shape)

  @parameterized.named_parameters(
      dict(testcase_name=
           "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
           .format(channels, filter_shape, padding, strides, input_shape),
           channels=channels, filter_shape=filter_shape, padding=padding,
           strides=strides, input_shape=input_shape)
      for channels in [2, 3]
      for filter_shape in [(1, 1), (2, 3)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 10, 11, 1)])
  def testConvShape(self, channels, filter_shape, padding, strides,
                    input_shape):
    """``stax.Conv`` reports the same output shape that it computes."""
    conv_init, conv_apply = stax.Conv(channels, filter_shape, strides=strides,
                                      padding=padding)
    _CheckShapeAgreement(self, conv_init, conv_apply, input_shape)

  @parameterized.named_parameters(
      dict(testcase_name="_out_dim={}_input_shape={}".format(out_dim,
                                                             input_shape),
           out_dim=out_dim, input_shape=input_shape)
      for out_dim in [3, 4]
      for input_shape in [(2, 3), (3, 4)])
  def testDenseShape(self, out_dim, input_shape):
    """``stax.Dense`` reports the same output shape that it computes."""
    dense_init, dense_apply = stax.Dense(out_dim)
    _CheckShapeAgreement(self, dense_init, dense_apply, input_shape)

  @parameterized.named_parameters(
      dict(testcase_name="_input_shape={}".format(input_shape),
           input_shape=input_shape)
      for input_shape in [(2, 3), (2, 3, 4)])
  def testReluShape(self, input_shape):
    """``stax.Relu`` reports the same output shape that it computes."""
    relu_init, relu_apply = stax.Relu
    _CheckShapeAgreement(self, relu_init, relu_apply, input_shape)

  @parameterized.named_parameters(
      dict(testcase_name=
           "_window_shape={}_padding={}_strides={}_input_shape={}"
           .format(window_shape, padding, strides, input_shape),
           window_shape=window_shape, padding=padding, strides=strides,
           input_shape=input_shape)
      for window_shape in [(1, 1), (2, 3)]
      for padding in ["VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 5, 6, 1)])
  def testPoolingShape(self, window_shape, padding, strides, input_shape):
    """``stax.MaxPool`` reports the same output shape that it computes."""
    pool_init, pool_apply = stax.MaxPool(window_shape, padding=padding,
                                         strides=strides)
    _CheckShapeAgreement(self, pool_init, pool_apply, input_shape)

  @parameterized.named_parameters(
      dict(testcase_name="_shape={}".format(input_shape),
           input_shape=input_shape)
      for input_shape in [(2, 3), (2, 3, 4)])
  def testFlattenShape(self, input_shape):
    """``stax.Flatten`` reports the same output shape that it computes."""
    flatten_init, flatten_apply = stax.Flatten
    _CheckShapeAgreement(self, flatten_init, flatten_apply, input_shape)

  @parameterized.named_parameters(
      dict(testcase_name="_input_shape={}_spec={}".format(input_shape,
                                                          spec_index),
           input_shape=input_shape, spec=spec)
      for input_shape in [(2, 5, 6, 1)]
      for spec_index, spec in enumerate([
          [stax.Conv(3, (2, 2))],
          [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]]))
  def testSerialComposeLayersShape(self, input_shape, spec):
    """A ``stax.serial`` stack of layers keeps init/apply shapes in sync."""
    model_init, model_apply = stax.serial(*spec)
    _CheckShapeAgreement(self, model_init, model_apply, input_shape)

  @parameterized.named_parameters(
      dict(testcase_name="_input_shape={}".format(input_shape),
           input_shape=input_shape)
      for input_shape in [(3, 4), (2, 5, 6, 1)])
  def testDropoutShape(self, input_shape):
    """``stax.Dropout(0.9)`` reports the same output shape it computes."""
    dropout_init, dropout_apply = stax.Dropout(0.9)
    _CheckShapeAgreement(self, dropout_init, dropout_apply, input_shape)
|
|
|
|
|
|
if __name__ == "__main__":
  # When this file is executed directly as a script, run its tests through
  # absltest.
  absltest.main()
|