mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
sparse_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 570882867
This commit is contained in:
parent 465eb21561
commit 1c37f5091c
@@ -13,26 +13,55 @@
# limitations under the License.
"""Sparse test utilities."""

from collections.abc import Sequence
from collections.abc import Iterable, Iterator, Sequence
import functools
import itertools
import math
from typing import Any, Callable, Union

import numpy as np
from typing import NamedTuple

import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.typing import DTypeLike
from jax import tree_util
from jax.util import safe_zip, split_list
from jax._src import test_util as jtu
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lib import gpu_sparse
from jax._src.typing import DTypeLike
from jax.experimental import sparse
import jax.numpy as jnp
from jax.util import safe_zip, split_list
import numpy as np

MATMUL_TOL = {
    np.float32: 1e-5,
    np.float64: 1e-10,
    np.complex64: 1e-5,
    np.complex128: 1e-10,
}

GPU_LOWERING_ENABLED = gpu_sparse and (
    gpu_sparse.cuda_is_supported or gpu_sparse.rocm_is_supported
)


def is_sparse(x):
  return isinstance(x, sparse.JAXSparse)


class BatchedDotGeneralProperties(NamedTuple):
  lhs_shape: tuple[int, ...]
  rhs_shape: tuple[int, ...]
  n_batch: int
  n_dense: int
  dimension_numbers: DotDimensionNumbers


class SparseLayout(NamedTuple):
  n_batch: int
  n_dense: int
  n_sparse: int


class SparseTestCase(jtu.JaxTestCase):
  def assertSparseArraysEquivalent(self, x, y, *, check_dtypes=True, atol=None,
                                   rtol=None, canonicalize_dtypes=True, err_msg=''):
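The new module-level pieces above do three jobs: MATMUL_TOL keys comparison tolerances by dtype, GPU_LOWERING_ENABLED gates cuSPARSE/hipSPARSE-backed cases, and the two NamedTuples describe test parameterizations (SparseLayout splits an array's dimensions into batch/sparse/dense counts; BatchedDotGeneralProperties additionally carries lax.dot_general dimension numbers). A minimal runnable sketch of how such pieces combine in a test body, using only public jax.experimental.sparse API (an illustration, not code from this commit):

import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse

tol = {np.dtype(np.float32): 1e-5}  # same idea as MATMUL_TOL

x = jnp.zeros((4, 5), dtype=np.float32).at[0, 1].set(2.0)
y = jnp.ones((5, 3), dtype=np.float32)

# A SparseLayout-style split of x's two dimensions:
# n_batch=0, n_sparse=2, n_dense=0.
x_sp = sparse.BCOO.fromdense(x, n_batch=0, n_dense=0)
assert isinstance(x_sp, sparse.JAXSparse)  # the check is_sparse() performs

# dimension_numbers for an ordinary 2-D matmul, as a
# BatchedDotGeneralProperties instance would carry them:
dimension_numbers = (((1,), (0,)), ((), ()))
dense_out = lax.dot_general(x, y, dimension_numbers)
sparse_out = x_sp @ y  # sparse matmul through BCOO's operator overload

np.testing.assert_allclose(sparse_out, dense_out,
                           atol=tol[np.dtype(dense_out.dtype)],
                           rtol=tol[np.dtype(dense_out.dtype)])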
@@ -170,3 +199,43 @@ def rand_bcsr(rng: np.random.RandomState,
  return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method,
                           nse=nse, n_batch=n_batch, n_dense=n_dense,
                           sparse_format='bcsr')


def iter_subsets(s: Sequence) -> Iterable[tuple]:
  """Return an iterator over all subsets of a sequence s"""
  return itertools.chain.from_iterable(
      itertools.combinations(s, n) for n in range(len(s) + 1)
  )


def iter_sparse_layouts(
    shape: Sequence[int], min_n_batch=0
) -> Iterator[SparseLayout]:
  for n_batch in range(min_n_batch, len(shape) + 1):
    for n_dense in range(len(shape) + 1 - n_batch):
      n_sparse = len(shape) - n_batch - n_dense
      yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense)


def iter_bcsr_layouts(
    shape: Sequence[int], min_n_batch=0
) -> Iterator[SparseLayout]:
  n_sparse = 2
  for n_batch in range(min_n_batch, len(shape) - 1):
    n_dense = len(shape) - n_sparse - n_batch
    yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense)


def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
  def _rand_sparse(shape, dtype, nse=nse):
    rand = rand_method(rng)
    size = math.prod(shape)
    if 0 <= nse < 1:
      nse = nse * size
    nse = min(size, int(nse))
    M = rand(shape, dtype)
    indices = rng.choice(size, size - nse, replace=False)
    M.flat[indices] = 0
    return post(M)

  return _rand_sparse
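Taken together, the iteration helpers define the test grid: iter_sparse_layouts walks every way to split a shape's dimensions into batch/sparse/dense, iter_bcsr_layouts pins n_sparse=2 as the BCSR format requires, and iter_subsets enumerates all subsets of a sequence. A standalone sketch of the enumeration they produce for a rank-3 shape (re-derived here with plain tuples so it runs without JAX; the helpers above yield SparseLayout namedtuples):

import itertools

# Every (n_batch, n_sparse, n_dense) split of a rank-3 shape with
# n_batch + n_sparse + n_dense == ndim, in iter_sparse_layouts order.
ndim = 3
layouts = [(nb, ndim - nb - nd, nd)
           for nb in range(ndim + 1)
           for nd in range(ndim + 1 - nb)]
print(layouts)
# [(0, 3, 0), (0, 2, 1), (0, 1, 2), (0, 0, 3),
#  (1, 2, 0), (1, 1, 1), (1, 0, 2),
#  (2, 1, 0), (2, 0, 1),
#  (3, 0, 0)]

# iter_subsets enumerates subsets smallest-first:
subsets = list(itertools.chain.from_iterable(
    itertools.combinations("ab", n) for n in range(3)))
print(subsets)  # [(), ('a',), ('b',), ('a', 'b')]

rand_sparse reads nse ("number of specified elements") as a fraction of entries when it lies in [0, 1) and as an absolute count otherwise, then zeroes out the complementary entries. Its density logic, re-derived with NumPy only:

import math
import numpy as np

rng = np.random.RandomState(0)
shape, nse = (4, 6), 0.5
size = math.prod(shape)
if 0 <= nse < 1:           # fractional nse means "fraction of entries"
  nse = nse * size
nse = min(size, int(nse))  # absolute number of specified elements
M = rng.randn(*shape)
# Choose the entries to zero out, leaving nse nonzeros (fewer if M
# already contained zeros at the kept positions).
indices = rng.choice(size, size - nse, replace=False)
M.flat[indices] = 0
assert np.count_nonzero(M) <= nse
print(np.count_nonzero(M), "of", size, "entries kept")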
35 tests/BUILD
@@ -834,7 +834,40 @@ jax_test(
    backend_variant_args = {
        "cpu": ["--jax_num_generated_cases=40"],
        "cpu_x32": ["--jax_num_generated_cases=40"],
        "cpu_no_jax_array": ["--jax_num_generated_cases=40"],
        "gpu": ["--jax_num_generated_cases=40"],
    },
    shard_count = {
        "cpu": 50,
        "gpu": 50,
        "tpu": 50,
        "iree": 10,
    },
    tags = [
        "noasan",
        "nomsan",
        "notsan",
    ],  # Test times out under asan/msan/tsan.
    deps = [
        "//jax:experimental_sparse",
        "//jax:sparse_test_util",
    ] + py_deps("scipy"),
)

jax_test(
    name = "sparse_bcoo_bcsr_test",
    srcs = ["sparse_bcoo_bcsr_test.py"],
    args = ["--jax_bcoo_cusparse_lowering=true"],
    backend_tags = {
        "cpu": [
            "nomsan",  # Times out
            "notsan",  # Times out
        ],
        "tpu": ["optonly"],
    },
    # Use fewer cases to prevent timeouts.
    backend_variant_args = {
        "cpu": ["--jax_num_generated_cases=40"],
        "cpu_x32": ["--jax_num_generated_cases=40"],
        "gpu": ["--jax_num_generated_cases=40"],
    },
    shard_count = {
1958 tests/sparse_bcoo_bcsr_test.py (new file)
File diff suppressed because it is too large
1936 tests/sparse_test.py
File diff suppressed because it is too large
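The BUILD changes above are the mechanics of the commit message: the existing sparse_test target and the new sparse_bcoo_bcsr_test target each get their own shard_count and --jax_num_generated_cases caps so every shard fits a medium timeout, and both depend on the shared //jax:sparse_test_util library whose utilities are diffed above. Outside Bazel, the case cap is conventionally applied through the JAX_NUM_GENERATED_CASES environment variable read by JAX's test harness; a sketch of an assumed local invocation (not part of this change; pytest-xdist's -n auto is optional):

import os
import subprocess

# Cap the number of randomly generated parameterizations per test,
# mirroring --jax_num_generated_cases=40 from the BUILD rules above.
env = dict(os.environ, JAX_NUM_GENERATED_CASES="40")
subprocess.run(
    ["pytest", "-n", "auto", "tests/sparse_bcoo_bcsr_test.py"],
    env=env, check=True)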