mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a library, which is an odd thing to do. The new package and the module within it are not built into the wheel, as these are internal utilities for JAX's tests, not utilities for JAX users writing their own tests. Followup changes will add additional existing internal test utilities to this package. This will allow removing sys.path manipulation from deprecation_module_test and hopefully lazy_loader_test, as well as removing the non-public test_util.py from _src to make it clearer that it should not be used from outside JAX. PiperOrigin-RevId: 510260230
This commit is contained in:
parent
47dc01637f
commit
631e4ed7e0
@ -73,6 +73,14 @@ py_library(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "internal_test_util",
|
||||
testonly = 1,
|
||||
srcs = glob(["_src/internal_test_util/*.py"]),
|
||||
visibility = [":internal"],
|
||||
deps = [":jax"] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "jax",
|
||||
srcs = glob(
|
||||
@ -93,6 +101,7 @@ py_library_providing_imports_info(
|
||||
"_src/test_util.py",
|
||||
"*_test.py",
|
||||
"**/*_test.py",
|
||||
"_src/internal_test_util/**",
|
||||
],
|
||||
) + [
|
||||
# TODO(phawkins): remove global_device_array from the main JAX target.
|
||||
|
13
jax/_src/internal_test_util/__init__.py
Normal file
13
jax/_src/internal_test_util/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
347
jax/_src/internal_test_util/lax_test_util.py
Normal file
347
jax/_src/internal_test_util/lax_test_util.py
Normal file
@ -0,0 +1,347 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This submodule includes lax-specific private test utilities that are not
|
||||
# exported to jax.test_util. Functionality appearing here is for internal use
|
||||
# only, and may be changed or removed at any time and without any deprecation
|
||||
# cycle.
|
||||
|
||||
import collections
|
||||
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
# For standard unops and binops, we can generate a large number of tests on
|
||||
# arguments of appropriate shapes and dtypes using the following table.
|
||||
|
||||
float_dtypes = test_util.dtypes.all_floating
|
||||
complex_elem_dtypes = test_util.dtypes.floating
|
||||
complex_dtypes = test_util.dtypes.complex
|
||||
inexact_dtypes = test_util.dtypes.all_inexact
|
||||
int_dtypes = test_util.dtypes.all_integer
|
||||
uint_dtypes = test_util.dtypes.all_unsigned
|
||||
bool_dtypes = test_util.dtypes.boolean
|
||||
|
||||
default_dtypes = float_dtypes + int_dtypes
|
||||
all_dtypes = (
|
||||
float_dtypes + complex_dtypes + int_dtypes + uint_dtypes + bool_dtypes
|
||||
)
|
||||
python_scalar_types = [bool, int, float, complex]
|
||||
|
||||
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
|
||||
|
||||
OpRecord = collections.namedtuple(
|
||||
"OpRecord", ["op", "nargs", "dtypes", "rng_factory", "tol"]
|
||||
)
|
||||
|
||||
|
||||
def op_record(op, nargs, dtypes, rng_factory, tol=None):
|
||||
return OpRecord(op, nargs, dtypes, rng_factory, tol)
|
||||
|
||||
|
||||
ReducerOpRecord = collections.namedtuple(
|
||||
"ReducerOpRecord", ["op", "reference_op", "init_val", "dtypes", "primitive"]
|
||||
)
|
||||
|
||||
|
||||
def lax_reduce_ops():
|
||||
return [
|
||||
ReducerOpRecord(lax.add, np.add, 0, default_dtypes, lax.reduce_sum_p),
|
||||
ReducerOpRecord(
|
||||
lax.mul, np.multiply, 1, default_dtypes, lax.reduce_prod_p
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.max, np.maximum, 0, uint_dtypes + bool_dtypes, lax.reduce_max_p
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.max, np.maximum, -np.inf, float_dtypes, lax.reduce_max_p
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.max,
|
||||
np.maximum,
|
||||
dtypes.iinfo(np.int32).min,
|
||||
[np.int32],
|
||||
lax.reduce_max_p,
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.max,
|
||||
np.maximum,
|
||||
dtypes.iinfo(np.int64).min,
|
||||
[np.int64],
|
||||
lax.reduce_max_p,
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.min, np.minimum, np.inf, float_dtypes, lax.reduce_min_p
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.min,
|
||||
np.minimum,
|
||||
dtypes.iinfo(np.int32).max,
|
||||
[np.int32],
|
||||
lax.reduce_min_p,
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.min,
|
||||
np.minimum,
|
||||
dtypes.iinfo(np.int64).max,
|
||||
[np.int64],
|
||||
lax.reduce_min_p,
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.min,
|
||||
np.minimum,
|
||||
dtypes.iinfo(np.uint32).max,
|
||||
[np.uint32],
|
||||
lax.reduce_min_p,
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.min,
|
||||
np.minimum,
|
||||
dtypes.iinfo(np.uint64).max,
|
||||
[np.uint64],
|
||||
lax.reduce_min_p,
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.bitwise_and,
|
||||
np.bitwise_and,
|
||||
-1,
|
||||
int_dtypes + uint_dtypes + bool_dtypes,
|
||||
lax.reduce_and_p,
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.bitwise_or,
|
||||
np.bitwise_or,
|
||||
0,
|
||||
int_dtypes + uint_dtypes + bool_dtypes,
|
||||
lax.reduce_or_p,
|
||||
),
|
||||
ReducerOpRecord(
|
||||
lax.bitwise_xor,
|
||||
np.bitwise_xor,
|
||||
0,
|
||||
int_dtypes + uint_dtypes + bool_dtypes,
|
||||
lax.reduce_xor_p,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def lax_ops():
|
||||
return [
|
||||
op_record(
|
||||
"neg", 1, default_dtypes + complex_dtypes, test_util.rand_small
|
||||
),
|
||||
op_record("sign", 1, default_dtypes + uint_dtypes, test_util.rand_small),
|
||||
op_record("floor", 1, float_dtypes, test_util.rand_small),
|
||||
op_record("ceil", 1, float_dtypes, test_util.rand_small),
|
||||
op_record("round", 1, float_dtypes, test_util.rand_default),
|
||||
op_record(
|
||||
"nextafter",
|
||||
2,
|
||||
[f for f in float_dtypes if f != dtypes.bfloat16],
|
||||
test_util.rand_default,
|
||||
tol=0,
|
||||
),
|
||||
op_record("is_finite", 1, float_dtypes, test_util.rand_small),
|
||||
op_record("exp", 1, float_dtypes + complex_dtypes, test_util.rand_small),
|
||||
# TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
|
||||
# precision.
|
||||
op_record(
|
||||
"expm1",
|
||||
1,
|
||||
float_dtypes + complex_dtypes,
|
||||
test_util.rand_small,
|
||||
{np.float64: 1e-8},
|
||||
),
|
||||
op_record(
|
||||
"log", 1, float_dtypes + complex_dtypes, test_util.rand_positive
|
||||
),
|
||||
op_record(
|
||||
"log1p", 1, float_dtypes + complex_dtypes, test_util.rand_positive
|
||||
),
|
||||
# TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
|
||||
# ~float32 precision.
|
||||
# TODO(b/143135720): on GPU, tanh has only ~float32 precision.
|
||||
op_record(
|
||||
"tanh",
|
||||
1,
|
||||
float_dtypes + complex_dtypes,
|
||||
test_util.rand_small,
|
||||
{np.float64: 1e-9, np.complex128: 1e-7},
|
||||
),
|
||||
op_record(
|
||||
"logistic", 1, float_dtypes + complex_dtypes, test_util.rand_default
|
||||
),
|
||||
op_record(
|
||||
"sin", 1, float_dtypes + complex_dtypes, test_util.rand_default
|
||||
),
|
||||
op_record(
|
||||
"cos", 1, float_dtypes + complex_dtypes, test_util.rand_default
|
||||
),
|
||||
op_record("atan2", 2, float_dtypes, test_util.rand_default),
|
||||
op_record("sqrt", 1, float_dtypes, test_util.rand_positive),
|
||||
op_record("sqrt", 1, complex_dtypes, test_util.rand_default),
|
||||
op_record("rsqrt", 1, float_dtypes, test_util.rand_positive),
|
||||
op_record("rsqrt", 1, complex_dtypes, test_util.rand_default),
|
||||
op_record("cbrt", 1, float_dtypes, test_util.rand_default),
|
||||
op_record(
|
||||
"square", 1, float_dtypes + complex_dtypes, test_util.rand_default
|
||||
),
|
||||
op_record(
|
||||
"reciprocal",
|
||||
1,
|
||||
float_dtypes + complex_dtypes,
|
||||
test_util.rand_positive,
|
||||
),
|
||||
op_record(
|
||||
"tan",
|
||||
1,
|
||||
float_dtypes + complex_dtypes,
|
||||
test_util.rand_default,
|
||||
{np.float32: 3e-5},
|
||||
),
|
||||
op_record(
|
||||
"asin",
|
||||
1,
|
||||
float_dtypes + complex_dtypes,
|
||||
test_util.rand_small,
|
||||
{np.complex128: 5e-12},
|
||||
),
|
||||
op_record("acos", 1, float_dtypes + complex_dtypes, test_util.rand_small),
|
||||
op_record("atan", 1, float_dtypes + complex_dtypes, test_util.rand_small),
|
||||
op_record(
|
||||
"asinh",
|
||||
1,
|
||||
float_dtypes + complex_dtypes,
|
||||
test_util.rand_default,
|
||||
tol={np.complex64: 1e-4, np.complex128: 1e-5},
|
||||
),
|
||||
op_record(
|
||||
"acosh", 1, float_dtypes + complex_dtypes, test_util.rand_positive
|
||||
),
|
||||
# TODO(b/155331781): atanh has only ~float precision
|
||||
op_record(
|
||||
"atanh",
|
||||
1,
|
||||
float_dtypes + complex_dtypes,
|
||||
test_util.rand_small,
|
||||
{np.float64: 1e-9},
|
||||
),
|
||||
op_record(
|
||||
"sinh", 1, float_dtypes + complex_dtypes, test_util.rand_default
|
||||
),
|
||||
op_record(
|
||||
"cosh", 1, float_dtypes + complex_dtypes, test_util.rand_default
|
||||
),
|
||||
op_record(
|
||||
"lgamma",
|
||||
1,
|
||||
float_dtypes,
|
||||
test_util.rand_positive,
|
||||
{
|
||||
np.float32: (
|
||||
1e-3 if test_util.device_under_test() == "tpu" else 1e-5
|
||||
),
|
||||
np.float64: 1e-14,
|
||||
},
|
||||
),
|
||||
op_record(
|
||||
"digamma",
|
||||
1,
|
||||
float_dtypes,
|
||||
test_util.rand_positive,
|
||||
{np.float64: 1e-14},
|
||||
),
|
||||
op_record(
|
||||
"betainc",
|
||||
3,
|
||||
float_dtypes,
|
||||
test_util.rand_positive,
|
||||
{np.float64: 1e-14},
|
||||
),
|
||||
op_record(
|
||||
"igamma",
|
||||
2,
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
|
||||
test_util.rand_positive,
|
||||
{np.float64: 1e-14},
|
||||
),
|
||||
op_record(
|
||||
"igammac",
|
||||
2,
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
|
||||
test_util.rand_positive,
|
||||
{np.float64: 1e-14},
|
||||
),
|
||||
op_record("erf", 1, float_dtypes, test_util.rand_small),
|
||||
op_record("erfc", 1, float_dtypes, test_util.rand_small),
|
||||
# TODO(b/142976030): the approximation of erfinf used by XLA is only
|
||||
# accurate to float32 precision.
|
||||
op_record(
|
||||
"erf_inv", 1, float_dtypes, test_util.rand_small, {np.float64: 1e-9}
|
||||
),
|
||||
op_record("bessel_i0e", 1, float_dtypes, test_util.rand_default),
|
||||
op_record("bessel_i1e", 1, float_dtypes, test_util.rand_default),
|
||||
op_record("real", 1, complex_dtypes, test_util.rand_default),
|
||||
op_record("imag", 1, complex_dtypes, test_util.rand_default),
|
||||
op_record("complex", 2, complex_elem_dtypes, test_util.rand_default),
|
||||
op_record(
|
||||
"conj",
|
||||
1,
|
||||
complex_elem_dtypes + complex_dtypes,
|
||||
test_util.rand_default,
|
||||
),
|
||||
op_record(
|
||||
"abs", 1, default_dtypes + complex_dtypes, test_util.rand_default
|
||||
),
|
||||
op_record(
|
||||
"pow", 2, float_dtypes + complex_dtypes, test_util.rand_positive
|
||||
),
|
||||
op_record("bitwise_and", 2, bool_dtypes, test_util.rand_small),
|
||||
op_record("bitwise_not", 1, bool_dtypes, test_util.rand_small),
|
||||
op_record("bitwise_or", 2, bool_dtypes, test_util.rand_small),
|
||||
op_record("bitwise_xor", 2, bool_dtypes, test_util.rand_small),
|
||||
op_record(
|
||||
"population_count", 1, int_dtypes + uint_dtypes, test_util.rand_int
|
||||
),
|
||||
op_record("clz", 1, int_dtypes + uint_dtypes, test_util.rand_int),
|
||||
op_record(
|
||||
"add", 2, default_dtypes + complex_dtypes, test_util.rand_small
|
||||
),
|
||||
op_record(
|
||||
"sub", 2, default_dtypes + complex_dtypes, test_util.rand_small
|
||||
),
|
||||
op_record(
|
||||
"mul", 2, default_dtypes + complex_dtypes, test_util.rand_small
|
||||
),
|
||||
op_record(
|
||||
"div", 2, default_dtypes + complex_dtypes, test_util.rand_nonzero
|
||||
),
|
||||
op_record("rem", 2, default_dtypes, test_util.rand_nonzero),
|
||||
op_record("max", 2, all_dtypes, test_util.rand_small),
|
||||
op_record("min", 2, all_dtypes, test_util.rand_small),
|
||||
op_record("eq", 2, all_dtypes, test_util.rand_some_equal),
|
||||
op_record("ne", 2, all_dtypes, test_util.rand_small),
|
||||
op_record("ge", 2, default_dtypes, test_util.rand_small),
|
||||
op_record("gt", 2, default_dtypes, test_util.rand_small),
|
||||
op_record("le", 2, default_dtypes, test_util.rand_small),
|
||||
op_record("lt", 2, default_dtypes, test_util.rand_small),
|
||||
]
|
2
setup.py
2
setup.py
@ -60,7 +60,7 @@ setup(
|
||||
long_description_content_type='text/markdown',
|
||||
author='JAX team',
|
||||
author_email='jax-dev@google.com',
|
||||
packages=find_packages(exclude=["examples"]),
|
||||
packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]),
|
||||
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
|
||||
python_requires='>=3.8',
|
||||
install_requires=[
|
||||
|
31
tests/BUILD
31
tests/BUILD
@ -18,7 +18,6 @@ load(
|
||||
"jax_test",
|
||||
"jax_test_file_visibility",
|
||||
"py_deps",
|
||||
"pytype_library",
|
||||
"pytype_test",
|
||||
)
|
||||
|
||||
@ -430,25 +429,7 @@ jax_test(
|
||||
"tpu": 30,
|
||||
"iree": 40,
|
||||
},
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "lax_test_lib",
|
||||
srcs = ["lax_test.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = ["//jax"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "lax_vmap_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["lax_vmap_test.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":lax_test_lib",
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
deps = ["//jax:internal_test_util"] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -460,7 +441,6 @@ jax_test(
|
||||
"tpu": 20,
|
||||
"iree": 40,
|
||||
},
|
||||
deps = [":lax_test_lib"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -475,7 +455,7 @@ jax_test(
|
||||
"tpu": 40,
|
||||
"iree": 40,
|
||||
},
|
||||
deps = [":lax_test_lib"],
|
||||
deps = ["//jax:internal_test_util"] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
py_test(
|
||||
@ -587,10 +567,6 @@ jax_test(
|
||||
"tpu": 30,
|
||||
},
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
":lax_test_lib",
|
||||
":lax_vmap_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -964,9 +940,6 @@ jax_test(
|
||||
name = "ann_test",
|
||||
srcs = ["ann_test.py"],
|
||||
shard_count = 10,
|
||||
deps = [
|
||||
":lax_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -29,24 +29,18 @@ from jax import dtypes
|
||||
from jax import lax
|
||||
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src.lax import windowed_reductions as lax_windowed_reductions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import prod, safe_map, safe_zip
|
||||
|
||||
from lax_test import LAX_OPS
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
float_dtypes = jtu.dtypes.all_floating
|
||||
default_dtypes = jtu.dtypes.all_floating + jtu.dtypes.integer
|
||||
all_dtypes = jtu.dtypes.all
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
|
||||
|
||||
def all_bdims(*shapes):
|
||||
bdims = (itertools.chain([cast(Optional[int], None)],
|
||||
@ -96,11 +90,11 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
jtu.sample_product_testcases(
|
||||
[dict(op_name=rec.op, rng_factory=rec.rng_factory, tol=rec.tol)],
|
||||
[dict(shapes=shapes, bdims=bdims)
|
||||
for shape_group in compatible_shapes
|
||||
for shape_group in lax_test_util.compatible_shapes
|
||||
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
|
||||
for bdims in all_bdims(*shapes)],
|
||||
dtype=rec.dtypes,
|
||||
) for rec in LAX_OPS))
|
||||
) for rec in lax_test_util.lax_ops()))
|
||||
def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
|
||||
rng = rng_factory(self.rng())
|
||||
op = getattr(lax, op_name)
|
||||
@ -172,7 +166,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(2, 4)] for bdims in all_bdims(shape)],
|
||||
dtype=float_dtypes,
|
||||
dtype=lax_test_util.float_dtypes,
|
||||
nexp=[1, 3, 5],
|
||||
nmant=[0, 2, 4],
|
||||
)
|
||||
@ -205,7 +199,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
]
|
||||
for bdims in all_bdims(min_shape, operand_shape, max_shape)
|
||||
],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -216,7 +210,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, bdims=bdims)
|
||||
for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
|
||||
for bdims in all_bdims(lhs_shape, rhs_shape)],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -240,7 +234,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(3, 2), (2, 4), [1], [0]],
|
||||
]
|
||||
for bdims in all_bdims(lhs_shape, rhs_shape)],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
|
||||
lhs_contracting, rhs_contracting, bdims):
|
||||
@ -259,7 +253,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
|
||||
]
|
||||
for bdims in all_bdims(lhs_shape, rhs_shape)],
|
||||
dtype=default_dtypes)
|
||||
dtype=lax_test_util.default_dtypes)
|
||||
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
|
||||
dimension_numbers, bdims):
|
||||
rng = jtu.rand_small(self.rng())
|
||||
@ -276,7 +270,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(), (2, 3)] for bdims in all_bdims(shape)],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
broadcast_sizes=[(), (2,), (1, 2)],
|
||||
)
|
||||
def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
|
||||
@ -294,7 +288,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
([], [2, 3], []),
|
||||
]
|
||||
for bdims in all_bdims(inshape)],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
@unittest.skip("this test has failures in some cases") # TODO(mattjj)
|
||||
def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
|
||||
@ -334,7 +328,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(2, 2, 4), (2, 1, 0), (4, 2, 2)]
|
||||
]
|
||||
for bdims in all_bdims(arg_shape)],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -344,7 +338,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(2, 3)] for bdims in all_bdims(shape, ())],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
pads=[[(1, 2, 1), (0, 1, 0)]],
|
||||
)
|
||||
def testPad(self, shape, dtype, pads, bdims):
|
||||
@ -357,7 +351,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for arg_shape in [(), (3,), (2, 3)]
|
||||
for pred_shape in ([(), arg_shape] if arg_shape else [()])
|
||||
for bdims in all_bdims(pred_shape, arg_shape, arg_shape)],
|
||||
arg_dtype=default_dtypes,
|
||||
arg_dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -381,7 +375,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
]
|
||||
for bdims in all_bdims(shape)
|
||||
],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testSlice(self, shape, dtype, starts, limits, strides, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -398,7 +392,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
]
|
||||
for bdims in all_bdims(shape)
|
||||
],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testTranspose(self, shape, dtype, perm, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -408,14 +402,14 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
[dict(init_val=init_val, op=op, dtype=dtype)
|
||||
for init_val, op, dtypes in [
|
||||
(0, lax.add, default_dtypes),
|
||||
(1, lax.mul, default_dtypes),
|
||||
(0, lax.add, lax_test_util.default_dtypes),
|
||||
(1, lax.mul, lax_test_util.default_dtypes),
|
||||
# non-monoidal for everything except unsigned integers
|
||||
(0, lax.max, all_dtypes),
|
||||
(-np.inf, lax.max, float_dtypes),
|
||||
(0, lax.max, lax_test_util.all_dtypes),
|
||||
(-np.inf, lax.max, lax_test_util.float_dtypes),
|
||||
(dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
|
||||
(dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
|
||||
(np.inf, lax.min, float_dtypes),
|
||||
(np.inf, lax.min, lax_test_util.float_dtypes),
|
||||
(dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
|
||||
(dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
|
||||
(dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
|
||||
@ -442,7 +436,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
|
||||
]
|
||||
for bdims in all_bdims(shape, shape)],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testVariadicReduce(self, shape, dtype, dims, bdims):
|
||||
def op(a, b):
|
||||
@ -461,7 +455,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for bdims in all_bdims(shape)
|
||||
for dim in range(len(shape))],
|
||||
op=[lax.argmin, lax.argmax],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testArgminmax(self, op, shape, dtype, dim, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -539,7 +533,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(1, 2, 2, 1), (1, 1, 1, 1)]))
|
||||
for bdims in all_bdims(shape, shape)
|
||||
],
|
||||
dtype=float_dtypes,
|
||||
dtype=lax_test_util.float_dtypes,
|
||||
padding=["VALID", "SAME"]
|
||||
)
|
||||
@jtu.ignore_warning(message="Using reduced precision for gradient.*")
|
||||
@ -554,7 +548,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes,
|
||||
dtype=lax_test_util.float_dtypes,
|
||||
padding=["VALID", "SAME"],
|
||||
shape=[(3, 2, 4, 6)],
|
||||
dims=[(1, 1, 2, 1)],
|
||||
@ -611,7 +605,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
(1, 3)),
|
||||
]
|
||||
for bdims in all_bdims(shape, idxs.shape)],
|
||||
dtype=all_dtypes
|
||||
dtype=lax_test_util.all_dtypes
|
||||
)
|
||||
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
|
||||
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
||||
@ -620,7 +614,6 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
|
||||
dnums=dnums, bdims=bdims)
|
||||
@ -636,7 +629,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]
|
||||
for bdims in all_bdims(arg_shape, idxs.shape, update_shape)],
|
||||
dtype=float_dtypes
|
||||
dtype=lax_test_util.float_dtypes
|
||||
)
|
||||
def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
|
||||
fun = partial(lax.scatter_add, dimension_numbers=dnums)
|
||||
@ -667,7 +660,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for shape in [(4,), (3, 5, 3)]
|
||||
for bdims in all_bdims(shape)],
|
||||
k=[1, 3],
|
||||
dtype=default_dtypes,
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
# The top_k indices for integer arrays with identical entries won't match between
|
||||
# vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
|
||||
@ -704,7 +697,6 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
self._CheckBatching(fun, 5, bdims, (shape,) * arity,
|
||||
(np.float32,) * arity, rng)
|
||||
|
||||
|
||||
# TODO Concatenate
|
||||
# TODO Reverse
|
||||
# TODO DynamicSlice
|
||||
@ -734,7 +726,6 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
padding=((0, 0),))
|
||||
return x[idxs]
|
||||
|
||||
|
||||
inpt = jnp.array([
|
||||
[1.0, 0.0, 1.0],
|
||||
[2.0, 2.0, 0.0],
|
||||
|
Loading…
x
Reference in New Issue
Block a user