lax_test: Create a separate module for lax-specific test utils in a new package.

These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
This commit is contained in:
pizzud 2023-02-16 15:29:12 -08:00 committed by jax authors
parent 47dc01637f
commit 631e4ed7e0
7 changed files with 848 additions and 524 deletions

View File

@ -73,6 +73,14 @@ py_library(
] + py_deps("absl/testing") + py_deps("numpy"),
)
py_library(
name = "internal_test_util",
testonly = 1,
srcs = glob(["_src/internal_test_util/*.py"]),
visibility = [":internal"],
deps = [":jax"] + py_deps("numpy"),
)
py_library_providing_imports_info(
name = "jax",
srcs = glob(
@ -93,6 +101,7 @@ py_library_providing_imports_info(
"_src/test_util.py",
"*_test.py",
"**/*_test.py",
"_src/internal_test_util/**",
],
) + [
# TODO(phawkins): remove global_device_array from the main JAX target.

View File

@ -0,0 +1,13 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,347 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This submodule includes lax-specific private test utilities that are not
# exported to jax.test_util. Functionality appearing here is for internal use
# only, and may be changed or removed at any time and without any deprecation
# cycle.
import collections
from jax import lax
from jax._src import dtypes
from jax._src import test_util
import numpy as np
from jax.config import config
config.parse_flags_with_absl()
# For standard unops and binops, we can generate a large number of tests on
# arguments of appropriate shapes and dtypes using the following table.
float_dtypes = test_util.dtypes.all_floating
complex_elem_dtypes = test_util.dtypes.floating
complex_dtypes = test_util.dtypes.complex
inexact_dtypes = test_util.dtypes.all_inexact
int_dtypes = test_util.dtypes.all_integer
uint_dtypes = test_util.dtypes.all_unsigned
bool_dtypes = test_util.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
all_dtypes = (
float_dtypes + complex_dtypes + int_dtypes + uint_dtypes + bool_dtypes
)
python_scalar_types = [bool, int, float, complex]
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
OpRecord = collections.namedtuple(
"OpRecord", ["op", "nargs", "dtypes", "rng_factory", "tol"]
)
def op_record(op, nargs, dtypes, rng_factory, tol=None):
return OpRecord(op, nargs, dtypes, rng_factory, tol)
ReducerOpRecord = collections.namedtuple(
"ReducerOpRecord", ["op", "reference_op", "init_val", "dtypes", "primitive"]
)
def lax_reduce_ops():
return [
ReducerOpRecord(lax.add, np.add, 0, default_dtypes, lax.reduce_sum_p),
ReducerOpRecord(
lax.mul, np.multiply, 1, default_dtypes, lax.reduce_prod_p
),
ReducerOpRecord(
lax.max, np.maximum, 0, uint_dtypes + bool_dtypes, lax.reduce_max_p
),
ReducerOpRecord(
lax.max, np.maximum, -np.inf, float_dtypes, lax.reduce_max_p
),
ReducerOpRecord(
lax.max,
np.maximum,
dtypes.iinfo(np.int32).min,
[np.int32],
lax.reduce_max_p,
),
ReducerOpRecord(
lax.max,
np.maximum,
dtypes.iinfo(np.int64).min,
[np.int64],
lax.reduce_max_p,
),
ReducerOpRecord(
lax.min, np.minimum, np.inf, float_dtypes, lax.reduce_min_p
),
ReducerOpRecord(
lax.min,
np.minimum,
dtypes.iinfo(np.int32).max,
[np.int32],
lax.reduce_min_p,
),
ReducerOpRecord(
lax.min,
np.minimum,
dtypes.iinfo(np.int64).max,
[np.int64],
lax.reduce_min_p,
),
ReducerOpRecord(
lax.min,
np.minimum,
dtypes.iinfo(np.uint32).max,
[np.uint32],
lax.reduce_min_p,
),
ReducerOpRecord(
lax.min,
np.minimum,
dtypes.iinfo(np.uint64).max,
[np.uint64],
lax.reduce_min_p,
),
ReducerOpRecord(
lax.bitwise_and,
np.bitwise_and,
-1,
int_dtypes + uint_dtypes + bool_dtypes,
lax.reduce_and_p,
),
ReducerOpRecord(
lax.bitwise_or,
np.bitwise_or,
0,
int_dtypes + uint_dtypes + bool_dtypes,
lax.reduce_or_p,
),
ReducerOpRecord(
lax.bitwise_xor,
np.bitwise_xor,
0,
int_dtypes + uint_dtypes + bool_dtypes,
lax.reduce_xor_p,
),
]
def lax_ops():
return [
op_record(
"neg", 1, default_dtypes + complex_dtypes, test_util.rand_small
),
op_record("sign", 1, default_dtypes + uint_dtypes, test_util.rand_small),
op_record("floor", 1, float_dtypes, test_util.rand_small),
op_record("ceil", 1, float_dtypes, test_util.rand_small),
op_record("round", 1, float_dtypes, test_util.rand_default),
op_record(
"nextafter",
2,
[f for f in float_dtypes if f != dtypes.bfloat16],
test_util.rand_default,
tol=0,
),
op_record("is_finite", 1, float_dtypes, test_util.rand_small),
op_record("exp", 1, float_dtypes + complex_dtypes, test_util.rand_small),
# TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
# precision.
op_record(
"expm1",
1,
float_dtypes + complex_dtypes,
test_util.rand_small,
{np.float64: 1e-8},
),
op_record(
"log", 1, float_dtypes + complex_dtypes, test_util.rand_positive
),
op_record(
"log1p", 1, float_dtypes + complex_dtypes, test_util.rand_positive
),
# TODO(b/142975473): on CPU, tanh for complex128 is only accurate to
# ~float32 precision.
# TODO(b/143135720): on GPU, tanh has only ~float32 precision.
op_record(
"tanh",
1,
float_dtypes + complex_dtypes,
test_util.rand_small,
{np.float64: 1e-9, np.complex128: 1e-7},
),
op_record(
"logistic", 1, float_dtypes + complex_dtypes, test_util.rand_default
),
op_record(
"sin", 1, float_dtypes + complex_dtypes, test_util.rand_default
),
op_record(
"cos", 1, float_dtypes + complex_dtypes, test_util.rand_default
),
op_record("atan2", 2, float_dtypes, test_util.rand_default),
op_record("sqrt", 1, float_dtypes, test_util.rand_positive),
op_record("sqrt", 1, complex_dtypes, test_util.rand_default),
op_record("rsqrt", 1, float_dtypes, test_util.rand_positive),
op_record("rsqrt", 1, complex_dtypes, test_util.rand_default),
op_record("cbrt", 1, float_dtypes, test_util.rand_default),
op_record(
"square", 1, float_dtypes + complex_dtypes, test_util.rand_default
),
op_record(
"reciprocal",
1,
float_dtypes + complex_dtypes,
test_util.rand_positive,
),
op_record(
"tan",
1,
float_dtypes + complex_dtypes,
test_util.rand_default,
{np.float32: 3e-5},
),
op_record(
"asin",
1,
float_dtypes + complex_dtypes,
test_util.rand_small,
{np.complex128: 5e-12},
),
op_record("acos", 1, float_dtypes + complex_dtypes, test_util.rand_small),
op_record("atan", 1, float_dtypes + complex_dtypes, test_util.rand_small),
op_record(
"asinh",
1,
float_dtypes + complex_dtypes,
test_util.rand_default,
tol={np.complex64: 1e-4, np.complex128: 1e-5},
),
op_record(
"acosh", 1, float_dtypes + complex_dtypes, test_util.rand_positive
),
# TODO(b/155331781): atanh has only ~float precision
op_record(
"atanh",
1,
float_dtypes + complex_dtypes,
test_util.rand_small,
{np.float64: 1e-9},
),
op_record(
"sinh", 1, float_dtypes + complex_dtypes, test_util.rand_default
),
op_record(
"cosh", 1, float_dtypes + complex_dtypes, test_util.rand_default
),
op_record(
"lgamma",
1,
float_dtypes,
test_util.rand_positive,
{
np.float32: (
1e-3 if test_util.device_under_test() == "tpu" else 1e-5
),
np.float64: 1e-14,
},
),
op_record(
"digamma",
1,
float_dtypes,
test_util.rand_positive,
{np.float64: 1e-14},
),
op_record(
"betainc",
3,
float_dtypes,
test_util.rand_positive,
{np.float64: 1e-14},
),
op_record(
"igamma",
2,
[f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
test_util.rand_positive,
{np.float64: 1e-14},
),
op_record(
"igammac",
2,
[f for f in float_dtypes if f not in [dtypes.bfloat16, np.float16]],
test_util.rand_positive,
{np.float64: 1e-14},
),
op_record("erf", 1, float_dtypes, test_util.rand_small),
op_record("erfc", 1, float_dtypes, test_util.rand_small),
# TODO(b/142976030): the approximation of erfinf used by XLA is only
# accurate to float32 precision.
op_record(
"erf_inv", 1, float_dtypes, test_util.rand_small, {np.float64: 1e-9}
),
op_record("bessel_i0e", 1, float_dtypes, test_util.rand_default),
op_record("bessel_i1e", 1, float_dtypes, test_util.rand_default),
op_record("real", 1, complex_dtypes, test_util.rand_default),
op_record("imag", 1, complex_dtypes, test_util.rand_default),
op_record("complex", 2, complex_elem_dtypes, test_util.rand_default),
op_record(
"conj",
1,
complex_elem_dtypes + complex_dtypes,
test_util.rand_default,
),
op_record(
"abs", 1, default_dtypes + complex_dtypes, test_util.rand_default
),
op_record(
"pow", 2, float_dtypes + complex_dtypes, test_util.rand_positive
),
op_record("bitwise_and", 2, bool_dtypes, test_util.rand_small),
op_record("bitwise_not", 1, bool_dtypes, test_util.rand_small),
op_record("bitwise_or", 2, bool_dtypes, test_util.rand_small),
op_record("bitwise_xor", 2, bool_dtypes, test_util.rand_small),
op_record(
"population_count", 1, int_dtypes + uint_dtypes, test_util.rand_int
),
op_record("clz", 1, int_dtypes + uint_dtypes, test_util.rand_int),
op_record(
"add", 2, default_dtypes + complex_dtypes, test_util.rand_small
),
op_record(
"sub", 2, default_dtypes + complex_dtypes, test_util.rand_small
),
op_record(
"mul", 2, default_dtypes + complex_dtypes, test_util.rand_small
),
op_record(
"div", 2, default_dtypes + complex_dtypes, test_util.rand_nonzero
),
op_record("rem", 2, default_dtypes, test_util.rand_nonzero),
op_record("max", 2, all_dtypes, test_util.rand_small),
op_record("min", 2, all_dtypes, test_util.rand_small),
op_record("eq", 2, all_dtypes, test_util.rand_some_equal),
op_record("ne", 2, all_dtypes, test_util.rand_small),
op_record("ge", 2, default_dtypes, test_util.rand_small),
op_record("gt", 2, default_dtypes, test_util.rand_small),
op_record("le", 2, default_dtypes, test_util.rand_small),
op_record("lt", 2, default_dtypes, test_util.rand_small),
]

View File

@ -60,7 +60,7 @@ setup(
long_description_content_type='text/markdown',
author='JAX team',
author_email='jax-dev@google.com',
packages=find_packages(exclude=["examples"]),
packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]),
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
python_requires='>=3.8',
install_requires=[

View File

@ -18,7 +18,6 @@ load(
"jax_test",
"jax_test_file_visibility",
"py_deps",
"pytype_library",
"pytype_test",
)
@ -430,25 +429,7 @@ jax_test(
"tpu": 30,
"iree": 40,
},
)
pytype_library(
name = "lax_test_lib",
srcs = ["lax_test.py"],
srcs_version = "PY3",
deps = ["//jax"],
)
pytype_library(
name = "lax_vmap_test_lib",
testonly = 1,
srcs = ["lax_vmap_test.py"],
srcs_version = "PY3",
deps = [
":lax_test_lib",
"//jax",
"//jax:test_util",
],
deps = ["//jax:internal_test_util"] + py_deps("numpy"),
)
jax_test(
@ -460,7 +441,6 @@ jax_test(
"tpu": 20,
"iree": 40,
},
deps = [":lax_test_lib"],
)
jax_test(
@ -475,7 +455,7 @@ jax_test(
"tpu": 40,
"iree": 40,
},
deps = [":lax_test_lib"],
deps = ["//jax:internal_test_util"] + py_deps("numpy"),
)
py_test(
@ -587,10 +567,6 @@ jax_test(
"tpu": 30,
},
tags = ["multiaccelerator"],
deps = [
":lax_test_lib",
":lax_vmap_test_lib",
],
)
jax_test(
@ -964,9 +940,6 @@ jax_test(
name = "ann_test",
srcs = ["ann_test.py"],
shard_count = 10,
deps = [
":lax_test_lib",
],
)
py_test(

File diff suppressed because it is too large Load Diff

View File

@ -29,24 +29,18 @@ from jax import dtypes
from jax import lax
from jax._src import test_util as jtu
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.util import prod, safe_map, safe_zip
from lax_test import LAX_OPS
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
float_dtypes = jtu.dtypes.all_floating
default_dtypes = jtu.dtypes.all_floating + jtu.dtypes.integer
all_dtypes = jtu.dtypes.all
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
def all_bdims(*shapes):
bdims = (itertools.chain([cast(Optional[int], None)],
@ -96,11 +90,11 @@ class LaxVmapTest(jtu.JaxTestCase):
jtu.sample_product_testcases(
[dict(op_name=rec.op, rng_factory=rec.rng_factory, tol=rec.tol)],
[dict(shapes=shapes, bdims=bdims)
for shape_group in compatible_shapes
for shape_group in lax_test_util.compatible_shapes
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
for bdims in all_bdims(*shapes)],
dtype=rec.dtypes,
) for rec in LAX_OPS))
) for rec in lax_test_util.lax_ops()))
def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
rng = rng_factory(self.rng())
op = getattr(lax, op_name)
@ -172,7 +166,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(shape=shape, bdims=bdims)
for shape in [(2, 4)] for bdims in all_bdims(shape)],
dtype=float_dtypes,
dtype=lax_test_util.float_dtypes,
nexp=[1, 3, 5],
nmant=[0, 2, 4],
)
@ -205,7 +199,7 @@ class LaxVmapTest(jtu.JaxTestCase):
]
for bdims in all_bdims(min_shape, operand_shape, max_shape)
],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
rng = jtu.rand_default(self.rng())
@ -216,7 +210,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, bdims=bdims)
for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
for bdims in all_bdims(lhs_shape, rhs_shape)],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
rng = jtu.rand_default(self.rng())
@ -240,7 +234,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(3, 2), (2, 4), [1], [0]],
]
for bdims in all_bdims(lhs_shape, rhs_shape)],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
lhs_contracting, rhs_contracting, bdims):
@ -259,7 +253,7 @@ class LaxVmapTest(jtu.JaxTestCase):
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
for bdims in all_bdims(lhs_shape, rhs_shape)],
dtype=default_dtypes)
dtype=lax_test_util.default_dtypes)
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
dimension_numbers, bdims):
rng = jtu.rand_small(self.rng())
@ -276,7 +270,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(shape=shape, bdims=bdims)
for shape in [(), (2, 3)] for bdims in all_bdims(shape)],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
broadcast_sizes=[(), (2,), (1, 2)],
)
def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
@ -294,7 +288,7 @@ class LaxVmapTest(jtu.JaxTestCase):
([], [2, 3], []),
]
for bdims in all_bdims(inshape)],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
@unittest.skip("this test has failures in some cases") # TODO(mattjj)
def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
@ -334,7 +328,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(2, 2, 4), (2, 1, 0), (4, 2, 2)]
]
for bdims in all_bdims(arg_shape)],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
rng = jtu.rand_default(self.rng())
@ -344,7 +338,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(shape=shape, bdims=bdims)
for shape in [(2, 3)] for bdims in all_bdims(shape, ())],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
pads=[[(1, 2, 1), (0, 1, 0)]],
)
def testPad(self, shape, dtype, pads, bdims):
@ -357,7 +351,7 @@ class LaxVmapTest(jtu.JaxTestCase):
for arg_shape in [(), (3,), (2, 3)]
for pred_shape in ([(), arg_shape] if arg_shape else [()])
for bdims in all_bdims(pred_shape, arg_shape, arg_shape)],
arg_dtype=default_dtypes,
arg_dtype=lax_test_util.default_dtypes,
)
def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
rng = jtu.rand_default(self.rng())
@ -381,7 +375,7 @@ class LaxVmapTest(jtu.JaxTestCase):
]
for bdims in all_bdims(shape)
],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
def testSlice(self, shape, dtype, starts, limits, strides, bdims):
rng = jtu.rand_default(self.rng())
@ -398,7 +392,7 @@ class LaxVmapTest(jtu.JaxTestCase):
]
for bdims in all_bdims(shape)
],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
def testTranspose(self, shape, dtype, perm, bdims):
rng = jtu.rand_default(self.rng())
@ -408,14 +402,14 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(init_val=init_val, op=op, dtype=dtype)
for init_val, op, dtypes in [
(0, lax.add, default_dtypes),
(1, lax.mul, default_dtypes),
(0, lax.add, lax_test_util.default_dtypes),
(1, lax.mul, lax_test_util.default_dtypes),
# non-monoidal for everything except unsigned integers
(0, lax.max, all_dtypes),
(-np.inf, lax.max, float_dtypes),
(0, lax.max, lax_test_util.all_dtypes),
(-np.inf, lax.max, lax_test_util.float_dtypes),
(dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
(dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
(np.inf, lax.min, float_dtypes),
(np.inf, lax.min, lax_test_util.float_dtypes),
(dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
(dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
(dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
@ -442,7 +436,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
]
for bdims in all_bdims(shape, shape)],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
def testVariadicReduce(self, shape, dtype, dims, bdims):
def op(a, b):
@ -461,7 +455,7 @@ class LaxVmapTest(jtu.JaxTestCase):
for bdims in all_bdims(shape)
for dim in range(len(shape))],
op=[lax.argmin, lax.argmax],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
def testArgminmax(self, op, shape, dtype, dim, bdims):
rng = jtu.rand_default(self.rng())
@ -539,7 +533,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(1, 2, 2, 1), (1, 1, 1, 1)]))
for bdims in all_bdims(shape, shape)
],
dtype=float_dtypes,
dtype=lax_test_util.float_dtypes,
padding=["VALID", "SAME"]
)
@jtu.ignore_warning(message="Using reduced precision for gradient.*")
@ -554,7 +548,7 @@ class LaxVmapTest(jtu.JaxTestCase):
self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng)
@jtu.sample_product(
dtype=float_dtypes,
dtype=lax_test_util.float_dtypes,
padding=["VALID", "SAME"],
shape=[(3, 2, 4, 6)],
dims=[(1, 1, 2, 1)],
@ -611,7 +605,7 @@ class LaxVmapTest(jtu.JaxTestCase):
(1, 3)),
]
for bdims in all_bdims(shape, idxs.shape)],
dtype=all_dtypes
dtype=lax_test_util.all_dtypes
)
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
@ -620,7 +614,6 @@ class LaxVmapTest(jtu.JaxTestCase):
self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
jtu.rand_default(self.rng()))
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums, bdims=bdims)
@ -636,7 +629,7 @@ class LaxVmapTest(jtu.JaxTestCase):
scatter_dims_to_operand_dims=(0,))),
]
for bdims in all_bdims(arg_shape, idxs.shape, update_shape)],
dtype=float_dtypes
dtype=lax_test_util.float_dtypes
)
def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
fun = partial(lax.scatter_add, dimension_numbers=dnums)
@ -667,7 +660,7 @@ class LaxVmapTest(jtu.JaxTestCase):
for shape in [(4,), (3, 5, 3)]
for bdims in all_bdims(shape)],
k=[1, 3],
dtype=default_dtypes,
dtype=lax_test_util.default_dtypes,
)
# The top_k indices for integer arrays with identical entries won't match between
# vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
@ -704,7 +697,6 @@ class LaxVmapTest(jtu.JaxTestCase):
self._CheckBatching(fun, 5, bdims, (shape,) * arity,
(np.float32,) * arity, rng)
# TODO Concatenate
# TODO Reverse
# TODO DynamicSlice
@ -734,7 +726,6 @@ class LaxVmapTest(jtu.JaxTestCase):
padding=((0, 0),))
return x[idxs]
inpt = jnp.array([
[1.0, 0.0, 1.0],
[2.0, 2.0, 0.0],