lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.

Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.

PiperOrigin-RevId: 514745247
This commit is contained in:
pizzud 2023-03-07 08:49:05 -08:00 committed by jax authors
parent ea68198f37
commit 22cbf95e07
5 changed files with 161 additions and 72 deletions

View File

@ -79,7 +79,9 @@ py_library(
testonly = 1,
srcs = glob(["_src/internal_test_util/*.py"]),
visibility = [":internal"],
deps = [":jax"] + py_deps("numpy"),
deps = [
":jax",
] + py_deps("numpy"),
)
py_library_providing_imports_info(

View File

@ -18,17 +18,22 @@
# cycle.
import collections
import itertools
from typing import Optional, cast
from jax import lax
from jax._src import dtypes
from jax._src import test_util
from jax._src.util import safe_map, safe_zip
import numpy as np
from jax.config import config
config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# For standard unops and binops, we can generate a large number of tests on
# arguments of appropriate shapes and dtypes using the following table.
@ -345,3 +350,28 @@ def lax_ops():
op_record("le", 2, default_dtypes, test_util.rand_small),
op_record("lt", 2, default_dtypes, test_util.rand_small),
]
def all_bdims(*shapes):
bdims = (itertools.chain([cast(Optional[int], None)],
range(len(shape) + 1)) for shape in shapes)
return (t for t in itertools.product(*bdims) if not all(e is None for e in t))
def add_bdim(bdim_size, bdim, shape):
shape = list(shape)
if bdim is not None:
shape.insert(bdim, bdim_size)
return tuple(shape)
def slicer(x, bdim):
if bdim is None:
return lambda _: x
else:
return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)
def args_slicer(args, bdims):
slicers = map(slicer, args, bdims)
return lambda i: [sl(i) for sl in slicers]

View File

@ -470,16 +470,25 @@ jax_test(
jax_test(
name = "lax_vmap_test",
srcs = ["lax_vmap_test.py"],
backend_tags = {
"tpu": ["noasan"], # Test times out.
},
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 40,
"iree": 40,
},
deps = ["//jax:internal_test_util"] + py_deps("numpy"),
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
)
jax_test(
name = "lax_vmap_op_test",
srcs = ["lax_vmap_op_test.py"],
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 40,
"iree": 40,
},
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
)
py_test(

85
tests/lax_vmap_op_test.py Normal file
View File

@ -0,0 +1,85 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.internal_test_util import lax_test_util
from jax._src import lib
from jax._src import util
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
class LaxVmapOpTest(jtu.JaxTestCase):
def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
rtol=None, atol=None, multiple_results=False):
batched_shapes = map(functools.partial(lax_test_util.add_bdim, bdim_size),
bdims, shapes)
args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
args_slice = lax_test_util.args_slicer(args, bdims)
ans = jax.vmap(op, bdims)(*args)
if bdim_size == 0:
args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
out = op(*args)
if not multiple_results:
expected = np.zeros((0,) + out.shape, out.dtype)
else:
expected = [np.zeros((0,) + o.shape, o.dtype) for o in out]
else:
outs = [op(*args_slice(i)) for i in range(bdim_size)]
if not multiple_results:
expected = np.stack(outs)
else:
expected = [np.stack(xs) for xs in zip(*outs)]
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op_name=rec.op, rng_factory=rec.rng_factory, tol=rec.tol)],
[dict(shapes=shapes, bdims=bdims)
for shape_group in lax_test_util.compatible_shapes
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
for bdims in lax_test_util.all_bdims(*shapes)],
dtype=rec.dtypes,
) for rec in lax_test_util.lax_ops()))
def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
# TODO(pizzud): Make this unconditional after the next minimum jaxlib bump.
if lib.xla_extension_version >= 134:
if dtype == np.float64 or any(len(shape) > 2 for shape in shapes):
self.skipTest('Skipping big tests under sanitizers due to slowdown.')
rng = rng_factory(self.rng())
op = getattr(lax, op_name)
self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
atol=tol, rtol=tol)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -20,7 +20,6 @@ from typing import Optional, cast
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
@ -43,34 +42,13 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
def all_bdims(*shapes):
bdims = (itertools.chain([cast(Optional[int], None)],
range(len(shape) + 1)) for shape in shapes)
return (t for t in itertools.product(*bdims) if not all(e is None for e in t))
def add_bdim(bdim_size, bdim, shape):
shape = list(shape)
if bdim is not None:
shape.insert(bdim, bdim_size)
return tuple(shape)
def slicer(x, bdim):
if bdim is None:
return lambda _: x
else:
return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)
def args_slicer(args, bdims):
slicers = map(slicer, args, bdims)
return lambda i: [sl(i) for sl in slicers]
class LaxVmapTest(jtu.JaxTestCase):
def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
rtol=None, atol=None, multiple_results=False):
batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
batched_shapes = map(partial(lax_test_util.add_bdim, bdim_size), bdims, shapes)
args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
args_slice = args_slicer(args, bdims)
args_slice = lax_test_util.args_slicer(args, bdims)
ans = jax.vmap(op, bdims)(*args)
if bdim_size == 0:
args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
@ -87,21 +65,6 @@ class LaxVmapTest(jtu.JaxTestCase):
expected = [np.stack(xs) for xs in zip(*outs)]
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op_name=rec.op, rng_factory=rec.rng_factory, tol=rec.tol)],
[dict(shapes=shapes, bdims=bdims)
for shape_group in lax_test_util.compatible_shapes
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
for bdims in all_bdims(*shapes)],
dtype=rec.dtypes,
) for rec in lax_test_util.lax_ops()))
def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
rng = rng_factory(self.rng())
op = getattr(lax, op_name)
self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
atol=tol, rtol=tol)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape,
batch_group_count=batch_group_count,
@ -157,7 +120,7 @@ class LaxVmapTest(jtu.JaxTestCase):
repeat=2)
],
[dict(shape=shape, bdims=bdims)
for shape in [(2, 3)] for bdims in all_bdims(shape)]
for shape in [(2, 3)] for bdims in lax_test_util.all_bdims(shape)]
)
def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
rng = jtu.rand_default(self.rng())
@ -166,7 +129,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(shape=shape, bdims=bdims)
for shape in [(2, 4)] for bdims in all_bdims(shape)],
for shape in [(2, 4)] for bdims in lax_test_util.all_bdims(shape)],
dtype=lax_test_util.float_dtypes,
nexp=[1, 3, 5],
nmant=[0, 2, 4],
@ -182,7 +145,7 @@ class LaxVmapTest(jtu.JaxTestCase):
repeat=2)
],
[dict(shape=shape, bdims=bdims)
for shape in [(2, 3)] for bdims in all_bdims(shape)]
for shape in [(2, 3)] for bdims in lax_test_util.all_bdims(shape)]
)
def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims,):
rng = jtu.rand_default(self.rng())
@ -198,7 +161,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(), (2, 3), (2, 3)],
[(2, 3), (2, 3), (2, 3)],
]
for bdims in all_bdims(min_shape, operand_shape, max_shape)
for bdims in lax_test_util.all_bdims(min_shape, operand_shape, max_shape)
],
dtype=lax_test_util.default_dtypes,
)
@ -210,7 +173,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, bdims=bdims)
for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
for bdims in all_bdims(lhs_shape, rhs_shape)],
for bdims in lax_test_util.all_bdims(lhs_shape, rhs_shape)],
dtype=lax_test_util.default_dtypes,
)
def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
@ -234,7 +197,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
[(3, 2), (2, 4), [1], [0]],
]
for bdims in all_bdims(lhs_shape, rhs_shape)],
for bdims in lax_test_util.all_bdims(lhs_shape, rhs_shape)],
dtype=lax_test_util.default_dtypes,
)
def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
@ -253,7 +216,7 @@ class LaxVmapTest(jtu.JaxTestCase):
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
for bdims in all_bdims(lhs_shape, rhs_shape)],
for bdims in lax_test_util.all_bdims(lhs_shape, rhs_shape)],
dtype=lax_test_util.default_dtypes)
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
dimension_numbers, bdims):
@ -270,7 +233,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(shape=shape, bdims=bdims)
for shape in [(), (2, 3)] for bdims in all_bdims(shape)],
for shape in [(), (2, 3)] for bdims in lax_test_util.all_bdims(shape)],
dtype=lax_test_util.default_dtypes,
broadcast_sizes=[(), (2,), (1, 2)],
)
@ -288,7 +251,7 @@ class LaxVmapTest(jtu.JaxTestCase):
([2], [2, 3], [0]),
([], [2, 3], []),
]
for bdims in all_bdims(inshape)],
for bdims in lax_test_util.all_bdims(inshape)],
dtype=lax_test_util.default_dtypes,
)
@unittest.skip("this test has failures in some cases") # TODO(mattjj)
@ -309,7 +272,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(2, 1, 3, 1), (3,)],
[(2, 1, 3, 1), (1, -1)],
]
for bdims in all_bdims(arg_shape)],
for bdims in lax_test_util.all_bdims(arg_shape)],
)
def testSqueeze(self, arg_shape, dimensions, bdims):
dtype = np.float32
@ -328,7 +291,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(2, 2, 4), (1, 0, 2), (8, 2)],
[(2, 2, 4), (2, 1, 0), (4, 2, 2)]
]
for bdims in all_bdims(arg_shape)],
for bdims in lax_test_util.all_bdims(arg_shape)],
dtype=lax_test_util.default_dtypes,
)
def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
@ -338,7 +301,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(shape=shape, bdims=bdims)
for shape in [(2, 3)] for bdims in all_bdims(shape, ())],
for shape in [(2, 3)] for bdims in lax_test_util.all_bdims(shape, ())],
dtype=lax_test_util.default_dtypes,
pads=[[(1, 2, 1), (0, 1, 0)]],
)
@ -351,7 +314,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[dict(arg_shape=arg_shape, pred_shape=pred_shape, bdims=bdims)
for arg_shape in [(), (3,), (2, 3)]
for pred_shape in ([(), arg_shape] if arg_shape else [()])
for bdims in all_bdims(pred_shape, arg_shape, arg_shape)],
for bdims in lax_test_util.all_bdims(pred_shape, arg_shape, arg_shape)],
arg_dtype=lax_test_util.default_dtypes,
)
def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
@ -374,7 +337,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(5, 3), (1, 1), (2, 1), (1, 1)],
[(5, 3), (1, 1), (5, 3), (2, 1)],
]
for bdims in all_bdims(shape)
for bdims in lax_test_util.all_bdims(shape)
],
dtype=lax_test_util.default_dtypes,
)
@ -391,7 +354,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(3, 4, 5), (2, 1, 0)],
[(3, 4, 5), (1, 0, 2)],
]
for bdims in all_bdims(shape)
for bdims in lax_test_util.all_bdims(shape)
],
dtype=lax_test_util.default_dtypes,
)
@ -422,7 +385,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
]
for bdims in all_bdims(shape)],
for bdims in lax_test_util.all_bdims(shape)],
)
def testReduce(self, op, init_val, shape, dtype, dims, bdims):
rng = jtu.rand_small(self.rng())
@ -436,7 +399,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
]
for bdims in all_bdims(shape, shape)],
for bdims in lax_test_util.all_bdims(shape, shape)],
dtype=lax_test_util.default_dtypes,
)
def testVariadicReduce(self, shape, dtype, dims, bdims):
@ -453,7 +416,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(shape=shape, bdims=bdims, dim=dim)
for shape in [(3, 4, 5)]
for bdims in all_bdims(shape)
for bdims in lax_test_util.all_bdims(shape)
for dim in range(len(shape))],
op=[lax.argmin, lax.argmax],
dtype=lax_test_util.default_dtypes,
@ -499,7 +462,7 @@ class LaxVmapTest(jtu.JaxTestCase):
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)
for bdims in all_bdims(shape):
for bdims in lax_test_util.all_bdims(shape):
self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
@jtu.sample_product(
@ -512,7 +475,7 @@ class LaxVmapTest(jtu.JaxTestCase):
[dict(shape=shape, bdims=bdims, axis=axis)
for shape in [[10], [3, 4, 5]]
for axis in range(len(shape))
for bdims in all_bdims(shape)],
for bdims in lax_test_util.all_bdims(shape)],
reverse=[False, True],
)
def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
@ -532,7 +495,7 @@ class LaxVmapTest(jtu.JaxTestCase):
itertools.product(
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)]))
for bdims in all_bdims(shape, shape)
for bdims in lax_test_util.all_bdims(shape, shape)
],
dtype=lax_test_util.float_dtypes,
padding=["VALID", "SAME"]
@ -569,14 +532,14 @@ class LaxVmapTest(jtu.JaxTestCase):
x, x, lax.ge_p, dims, strides, pads, ones, ones),
np.ones(shape, dtype)).shape
for bdims in all_bdims(cotangent_shape, shape):
for bdims in lax_test_util.all_bdims(cotangent_shape, shape):
self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
(dtype, dtype), rng)
@jtu.sample_product(
[dict(shape=shape, fft_ndims=fft_ndims, bdims=bdims)
for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
for bdims in all_bdims(shape)
for bdims in lax_test_util.all_bdims(shape)
for fft_ndims in range(0, min(3, len(shape)) + 1)],
)
def testFft(self, fft_ndims, shape, bdims):
@ -605,7 +568,7 @@ class LaxVmapTest(jtu.JaxTestCase):
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]
for bdims in all_bdims(shape, idxs.shape)],
for bdims in lax_test_util.all_bdims(shape, idxs.shape)],
dtype=lax_test_util.all_dtypes
)
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
@ -629,7 +592,7 @@ class LaxVmapTest(jtu.JaxTestCase):
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
for bdims in all_bdims(arg_shape, idxs.shape, update_shape)],
for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape, update_shape)],
dtype=lax_test_util.float_dtypes
)
def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
@ -659,7 +622,7 @@ class LaxVmapTest(jtu.JaxTestCase):
@jtu.sample_product(
[dict(shape=shape, bdims=bdims)
for shape in [(4,), (3, 5, 3)]
for bdims in all_bdims(shape)],
for bdims in lax_test_util.all_bdims(shape)],
k=[1, 3],
dtype=lax_test_util.default_dtypes,
)
@ -680,7 +643,7 @@ class LaxVmapTest(jtu.JaxTestCase):
for shape in [(2, 3)]
for dimension in [0, 1]
for arity in range(3)
for bdims in all_bdims(*((shape,) * arity))
for bdims in lax_test_util.all_bdims(*((shape,) * arity))
],
is_stable=[False, True]
)