mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
Unfortunately we can't conditionally change the timeout, as size and timeout are both non-configurable even if jax_test supported setting the size. PiperOrigin-RevId: 514745247
This commit is contained in:
parent
ea68198f37
commit
22cbf95e07
@ -79,7 +79,9 @@ py_library(
|
||||
testonly = 1,
|
||||
srcs = glob(["_src/internal_test_util/*.py"]),
|
||||
visibility = [":internal"],
|
||||
deps = [":jax"] + py_deps("numpy"),
|
||||
deps = [
|
||||
":jax",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
|
@ -18,17 +18,22 @@
|
||||
# cycle.
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
from typing import Optional, cast
|
||||
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
|
||||
# For standard unops and binops, we can generate a large number of tests on
|
||||
# arguments of appropriate shapes and dtypes using the following table.
|
||||
@ -345,3 +350,28 @@ def lax_ops():
|
||||
op_record("le", 2, default_dtypes, test_util.rand_small),
|
||||
op_record("lt", 2, default_dtypes, test_util.rand_small),
|
||||
]
|
||||
|
||||
|
||||
def all_bdims(*shapes):
|
||||
bdims = (itertools.chain([cast(Optional[int], None)],
|
||||
range(len(shape) + 1)) for shape in shapes)
|
||||
return (t for t in itertools.product(*bdims) if not all(e is None for e in t))
|
||||
|
||||
|
||||
def add_bdim(bdim_size, bdim, shape):
|
||||
shape = list(shape)
|
||||
if bdim is not None:
|
||||
shape.insert(bdim, bdim_size)
|
||||
return tuple(shape)
|
||||
|
||||
|
||||
def slicer(x, bdim):
|
||||
if bdim is None:
|
||||
return lambda _: x
|
||||
else:
|
||||
return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)
|
||||
|
||||
|
||||
def args_slicer(args, bdims):
|
||||
slicers = map(slicer, args, bdims)
|
||||
return lambda i: [sl(i) for sl in slicers]
|
||||
|
17
tests/BUILD
17
tests/BUILD
@ -470,16 +470,25 @@ jax_test(
|
||||
jax_test(
|
||||
name = "lax_vmap_test",
|
||||
srcs = ["lax_vmap_test.py"],
|
||||
backend_tags = {
|
||||
"tpu": ["noasan"], # Test times out.
|
||||
},
|
||||
shard_count = {
|
||||
"cpu": 40,
|
||||
"gpu": 40,
|
||||
"tpu": 40,
|
||||
"iree": 40,
|
||||
},
|
||||
deps = ["//jax:internal_test_util"] + py_deps("numpy"),
|
||||
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lax_vmap_op_test",
|
||||
srcs = ["lax_vmap_op_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 40,
|
||||
"gpu": 40,
|
||||
"tpu": 40,
|
||||
"iree": 40,
|
||||
},
|
||||
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
85
tests/lax_vmap_op_test.py
Normal file
85
tests/lax_vmap_op_test.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import lax
|
||||
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src import lib
|
||||
from jax._src import util
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
|
||||
class LaxVmapOpTest(jtu.JaxTestCase):
|
||||
|
||||
def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
|
||||
rtol=None, atol=None, multiple_results=False):
|
||||
batched_shapes = map(functools.partial(lax_test_util.add_bdim, bdim_size),
|
||||
bdims, shapes)
|
||||
args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
|
||||
args_slice = lax_test_util.args_slicer(args, bdims)
|
||||
ans = jax.vmap(op, bdims)(*args)
|
||||
if bdim_size == 0:
|
||||
args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
out = op(*args)
|
||||
if not multiple_results:
|
||||
expected = np.zeros((0,) + out.shape, out.dtype)
|
||||
else:
|
||||
expected = [np.zeros((0,) + o.shape, o.dtype) for o in out]
|
||||
else:
|
||||
outs = [op(*args_slice(i)) for i in range(bdim_size)]
|
||||
if not multiple_results:
|
||||
expected = np.stack(outs)
|
||||
else:
|
||||
expected = [np.stack(xs) for xs in zip(*outs)]
|
||||
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
|
||||
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(op_name=rec.op, rng_factory=rec.rng_factory, tol=rec.tol)],
|
||||
[dict(shapes=shapes, bdims=bdims)
|
||||
for shape_group in lax_test_util.compatible_shapes
|
||||
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
|
||||
for bdims in lax_test_util.all_bdims(*shapes)],
|
||||
dtype=rec.dtypes,
|
||||
) for rec in lax_test_util.lax_ops()))
|
||||
def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
|
||||
# TODO(pizzud): Make this unconditional after the next minimum jaxlib bump.
|
||||
if lib.xla_extension_version >= 134:
|
||||
if dtype == np.float64 or any(len(shape) > 2 for shape in shapes):
|
||||
self.skipTest('Skipping big tests under sanitizers due to slowdown.')
|
||||
|
||||
rng = rng_factory(self.rng())
|
||||
op = getattr(lax, op_name)
|
||||
self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -20,7 +20,6 @@ from typing import Optional, cast
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -43,34 +42,13 @@ map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
|
||||
def all_bdims(*shapes):
|
||||
bdims = (itertools.chain([cast(Optional[int], None)],
|
||||
range(len(shape) + 1)) for shape in shapes)
|
||||
return (t for t in itertools.product(*bdims) if not all(e is None for e in t))
|
||||
|
||||
def add_bdim(bdim_size, bdim, shape):
|
||||
shape = list(shape)
|
||||
if bdim is not None:
|
||||
shape.insert(bdim, bdim_size)
|
||||
return tuple(shape)
|
||||
|
||||
def slicer(x, bdim):
|
||||
if bdim is None:
|
||||
return lambda _: x
|
||||
else:
|
||||
return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)
|
||||
|
||||
def args_slicer(args, bdims):
|
||||
slicers = map(slicer, args, bdims)
|
||||
return lambda i: [sl(i) for sl in slicers]
|
||||
|
||||
class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
|
||||
rtol=None, atol=None, multiple_results=False):
|
||||
batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
|
||||
batched_shapes = map(partial(lax_test_util.add_bdim, bdim_size), bdims, shapes)
|
||||
args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
|
||||
args_slice = args_slicer(args, bdims)
|
||||
args_slice = lax_test_util.args_slicer(args, bdims)
|
||||
ans = jax.vmap(op, bdims)(*args)
|
||||
if bdim_size == 0:
|
||||
args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
@ -87,21 +65,6 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
expected = [np.stack(xs) for xs in zip(*outs)]
|
||||
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
|
||||
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(op_name=rec.op, rng_factory=rec.rng_factory, tol=rec.tol)],
|
||||
[dict(shapes=shapes, bdims=bdims)
|
||||
for shape_group in lax_test_util.compatible_shapes
|
||||
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
|
||||
for bdims in all_bdims(*shapes)],
|
||||
dtype=rec.dtypes,
|
||||
) for rec in lax_test_util.lax_ops()))
|
||||
def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
|
||||
rng = rng_factory(self.rng())
|
||||
op = getattr(lax, op_name)
|
||||
self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape,
|
||||
batch_group_count=batch_group_count,
|
||||
@ -157,7 +120,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
repeat=2)
|
||||
],
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(2, 3)] for bdims in all_bdims(shape)]
|
||||
for shape in [(2, 3)] for bdims in lax_test_util.all_bdims(shape)]
|
||||
)
|
||||
def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -166,7 +129,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(2, 4)] for bdims in all_bdims(shape)],
|
||||
for shape in [(2, 4)] for bdims in lax_test_util.all_bdims(shape)],
|
||||
dtype=lax_test_util.float_dtypes,
|
||||
nexp=[1, 3, 5],
|
||||
nmant=[0, 2, 4],
|
||||
@ -182,7 +145,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
repeat=2)
|
||||
],
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(2, 3)] for bdims in all_bdims(shape)]
|
||||
for shape in [(2, 3)] for bdims in lax_test_util.all_bdims(shape)]
|
||||
)
|
||||
def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims,):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -198,7 +161,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(), (2, 3), (2, 3)],
|
||||
[(2, 3), (2, 3), (2, 3)],
|
||||
]
|
||||
for bdims in all_bdims(min_shape, operand_shape, max_shape)
|
||||
for bdims in lax_test_util.all_bdims(min_shape, operand_shape, max_shape)
|
||||
],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
@ -210,7 +173,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, bdims=bdims)
|
||||
for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
|
||||
for bdims in all_bdims(lhs_shape, rhs_shape)],
|
||||
for bdims in lax_test_util.all_bdims(lhs_shape, rhs_shape)],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
|
||||
@ -234,7 +197,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
|
||||
[(3, 2), (2, 4), [1], [0]],
|
||||
]
|
||||
for bdims in all_bdims(lhs_shape, rhs_shape)],
|
||||
for bdims in lax_test_util.all_bdims(lhs_shape, rhs_shape)],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
|
||||
@ -253,7 +216,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
|
||||
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
|
||||
]
|
||||
for bdims in all_bdims(lhs_shape, rhs_shape)],
|
||||
for bdims in lax_test_util.all_bdims(lhs_shape, rhs_shape)],
|
||||
dtype=lax_test_util.default_dtypes)
|
||||
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
|
||||
dimension_numbers, bdims):
|
||||
@ -270,7 +233,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(), (2, 3)] for bdims in all_bdims(shape)],
|
||||
for shape in [(), (2, 3)] for bdims in lax_test_util.all_bdims(shape)],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
broadcast_sizes=[(), (2,), (1, 2)],
|
||||
)
|
||||
@ -288,7 +251,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
([2], [2, 3], [0]),
|
||||
([], [2, 3], []),
|
||||
]
|
||||
for bdims in all_bdims(inshape)],
|
||||
for bdims in lax_test_util.all_bdims(inshape)],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
@unittest.skip("this test has failures in some cases") # TODO(mattjj)
|
||||
@ -309,7 +272,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(2, 1, 3, 1), (3,)],
|
||||
[(2, 1, 3, 1), (1, -1)],
|
||||
]
|
||||
for bdims in all_bdims(arg_shape)],
|
||||
for bdims in lax_test_util.all_bdims(arg_shape)],
|
||||
)
|
||||
def testSqueeze(self, arg_shape, dimensions, bdims):
|
||||
dtype = np.float32
|
||||
@ -328,7 +291,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(2, 2, 4), (1, 0, 2), (8, 2)],
|
||||
[(2, 2, 4), (2, 1, 0), (4, 2, 2)]
|
||||
]
|
||||
for bdims in all_bdims(arg_shape)],
|
||||
for bdims in lax_test_util.all_bdims(arg_shape)],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
|
||||
@ -338,7 +301,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(2, 3)] for bdims in all_bdims(shape, ())],
|
||||
for shape in [(2, 3)] for bdims in lax_test_util.all_bdims(shape, ())],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
pads=[[(1, 2, 1), (0, 1, 0)]],
|
||||
)
|
||||
@ -351,7 +314,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[dict(arg_shape=arg_shape, pred_shape=pred_shape, bdims=bdims)
|
||||
for arg_shape in [(), (3,), (2, 3)]
|
||||
for pred_shape in ([(), arg_shape] if arg_shape else [()])
|
||||
for bdims in all_bdims(pred_shape, arg_shape, arg_shape)],
|
||||
for bdims in lax_test_util.all_bdims(pred_shape, arg_shape, arg_shape)],
|
||||
arg_dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
|
||||
@ -374,7 +337,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(5, 3), (1, 1), (2, 1), (1, 1)],
|
||||
[(5, 3), (1, 1), (5, 3), (2, 1)],
|
||||
]
|
||||
for bdims in all_bdims(shape)
|
||||
for bdims in lax_test_util.all_bdims(shape)
|
||||
],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
@ -391,7 +354,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(3, 4, 5), (2, 1, 0)],
|
||||
[(3, 4, 5), (1, 0, 2)],
|
||||
]
|
||||
for bdims in all_bdims(shape)
|
||||
for bdims in lax_test_util.all_bdims(shape)
|
||||
],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
@ -422,7 +385,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
|
||||
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
|
||||
]
|
||||
for bdims in all_bdims(shape)],
|
||||
for bdims in lax_test_util.all_bdims(shape)],
|
||||
)
|
||||
def testReduce(self, op, init_val, shape, dtype, dims, bdims):
|
||||
rng = jtu.rand_small(self.rng())
|
||||
@ -436,7 +399,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
|
||||
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
|
||||
]
|
||||
for bdims in all_bdims(shape, shape)],
|
||||
for bdims in lax_test_util.all_bdims(shape, shape)],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
def testVariadicReduce(self, shape, dtype, dims, bdims):
|
||||
@ -453,7 +416,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, bdims=bdims, dim=dim)
|
||||
for shape in [(3, 4, 5)]
|
||||
for bdims in all_bdims(shape)
|
||||
for bdims in lax_test_util.all_bdims(shape)
|
||||
for dim in range(len(shape))],
|
||||
op=[lax.argmin, lax.argmax],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
@ -499,7 +462,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
|
||||
base_dilation, window_dilation)
|
||||
|
||||
for bdims in all_bdims(shape):
|
||||
for bdims in lax_test_util.all_bdims(shape):
|
||||
self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -512,7 +475,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
[dict(shape=shape, bdims=bdims, axis=axis)
|
||||
for shape in [[10], [3, 4, 5]]
|
||||
for axis in range(len(shape))
|
||||
for bdims in all_bdims(shape)],
|
||||
for bdims in lax_test_util.all_bdims(shape)],
|
||||
reverse=[False, True],
|
||||
)
|
||||
def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
|
||||
@ -532,7 +495,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
itertools.product(
|
||||
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
|
||||
[(1, 2, 2, 1), (1, 1, 1, 1)]))
|
||||
for bdims in all_bdims(shape, shape)
|
||||
for bdims in lax_test_util.all_bdims(shape, shape)
|
||||
],
|
||||
dtype=lax_test_util.float_dtypes,
|
||||
padding=["VALID", "SAME"]
|
||||
@ -569,14 +532,14 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
x, x, lax.ge_p, dims, strides, pads, ones, ones),
|
||||
np.ones(shape, dtype)).shape
|
||||
|
||||
for bdims in all_bdims(cotangent_shape, shape):
|
||||
for bdims in lax_test_util.all_bdims(cotangent_shape, shape):
|
||||
self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
|
||||
(dtype, dtype), rng)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, fft_ndims=fft_ndims, bdims=bdims)
|
||||
for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
|
||||
for bdims in all_bdims(shape)
|
||||
for bdims in lax_test_util.all_bdims(shape)
|
||||
for fft_ndims in range(0, min(3, len(shape)) + 1)],
|
||||
)
|
||||
def testFft(self, fft_ndims, shape, bdims):
|
||||
@ -605,7 +568,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
|
||||
(1, 3)),
|
||||
]
|
||||
for bdims in all_bdims(shape, idxs.shape)],
|
||||
for bdims in lax_test_util.all_bdims(shape, idxs.shape)],
|
||||
dtype=lax_test_util.all_dtypes
|
||||
)
|
||||
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
|
||||
@ -629,7 +592,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]
|
||||
for bdims in all_bdims(arg_shape, idxs.shape, update_shape)],
|
||||
for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape, update_shape)],
|
||||
dtype=lax_test_util.float_dtypes
|
||||
)
|
||||
def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
|
||||
@ -659,7 +622,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, bdims=bdims)
|
||||
for shape in [(4,), (3, 5, 3)]
|
||||
for bdims in all_bdims(shape)],
|
||||
for bdims in lax_test_util.all_bdims(shape)],
|
||||
k=[1, 3],
|
||||
dtype=lax_test_util.default_dtypes,
|
||||
)
|
||||
@ -680,7 +643,7 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
for shape in [(2, 3)]
|
||||
for dimension in [0, 1]
|
||||
for arity in range(3)
|
||||
for bdims in all_bdims(*((shape,) * arity))
|
||||
for bdims in lax_test_util.all_bdims(*((shape,) * arity))
|
||||
],
|
||||
is_stable=[False, True]
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user