sparse_test: Split into two so that each target is small enough to fit within a medium timeout.

PiperOrigin-RevId: 570882867
This commit is contained in:
jax authors 2023-10-04 19:56:04 -07:00
parent 465eb21561
commit 1c37f5091c
4 changed files with 2194 additions and 1816 deletions

View File

@ -13,26 +13,55 @@
# limitations under the License.
"""Sparse test utilities."""
from collections.abc import Sequence
from collections.abc import Iterable, Iterator, Sequence
import functools
import itertools
import math
from typing import Any, Callable, Union
import numpy as np
from typing import NamedTuple
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.typing import DTypeLike
from jax import tree_util
from jax.util import safe_zip, split_list
from jax._src import test_util as jtu
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lib import gpu_sparse
from jax._src.typing import DTypeLike
from jax.experimental import sparse
import jax.numpy as jnp
from jax.util import safe_zip, split_list
import numpy as np
MATMUL_TOL = {
np.float32: 1e-5,
np.float64: 1e-10,
np.complex64: 1e-5,
np.complex128: 1e-10,
}
GPU_LOWERING_ENABLED = gpu_sparse and (
gpu_sparse.cuda_is_supported or gpu_sparse.rocm_is_supported
)
def is_sparse(x):
return isinstance(x, sparse.JAXSparse)
class BatchedDotGeneralProperties(NamedTuple):
lhs_shape: tuple[int, ...]
rhs_shape: tuple[int, ...]
n_batch: int
n_dense: int
dimension_numbers: DotDimensionNumbers
class SparseLayout(NamedTuple):
n_batch: int
n_dense: int
n_sparse: int
class SparseTestCase(jtu.JaxTestCase):
def assertSparseArraysEquivalent(self, x, y, *, check_dtypes=True, atol=None,
rtol=None, canonicalize_dtypes=True, err_msg=''):
@ -170,3 +199,43 @@ def rand_bcsr(rng: np.random.RandomState,
return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method,
nse=nse, n_batch=n_batch, n_dense=n_dense,
sparse_format='bcsr')
def iter_subsets(s: Sequence) -> Iterable[tuple]:
"""Return an iterator over all subsets of a sequence s"""
return itertools.chain.from_iterable(
itertools.combinations(s, n) for n in range(len(s) + 1)
)
def iter_sparse_layouts(
shape: Sequence[int], min_n_batch=0
) -> Iterator[SparseLayout]:
for n_batch in range(min_n_batch, len(shape) + 1):
for n_dense in range(len(shape) + 1 - n_batch):
n_sparse = len(shape) - n_batch - n_dense
yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense)
def iter_bcsr_layouts(
shape: Sequence[int], min_n_batch=0
) -> Iterator[SparseLayout]:
n_sparse = 2
for n_batch in range(min_n_batch, len(shape) - 1):
n_dense = len(shape) - n_sparse - n_batch
yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense)
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
def _rand_sparse(shape, dtype, nse=nse):
rand = rand_method(rng)
size = math.prod(shape)
if 0 <= nse < 1:
nse = nse * size
nse = min(size, int(nse))
M = rand(shape, dtype)
indices = rng.choice(size, size - nse, replace=False)
M.flat[indices] = 0
return post(M)
return _rand_sparse

View File

@ -834,7 +834,40 @@ jax_test(
backend_variant_args = {
"cpu": ["--jax_num_generated_cases=40"],
"cpu_x32": ["--jax_num_generated_cases=40"],
"cpu_no_jax_array": ["--jax_num_generated_cases=40"],
"gpu": ["--jax_num_generated_cases=40"],
},
shard_count = {
"cpu": 50,
"gpu": 50,
"tpu": 50,
"iree": 10,
},
tags = [
"noasan",
"nomsan",
"notsan",
], # Test times out under asan/msan/tsan.
deps = [
"//jax:experimental_sparse",
"//jax:sparse_test_util",
] + py_deps("scipy"),
)
jax_test(
name = "sparse_bcoo_bcsr_test",
srcs = ["sparse_bcoo_bcsr_test.py"],
args = ["--jax_bcoo_cusparse_lowering=true"],
backend_tags = {
"cpu": [
"nomsan", # Times out
"notsan", # Times out
],
"tpu": ["optonly"],
},
# Use fewer cases to prevent timeouts.
backend_variant_args = {
"cpu": ["--jax_num_generated_cases=40"],
"cpu_x32": ["--jax_num_generated_cases=40"],
"gpu": ["--jax_num_generated_cases=40"],
},
shard_count = {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff