mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
sparse_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 570882867
This commit is contained in:
parent
465eb21561
commit
1c37f5091c
@ -13,26 +13,55 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Sparse test utilities."""
|
"""Sparse test utilities."""
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Iterable, Iterator, Sequence
|
||||||
import functools
|
import functools
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
from typing import Any, Callable, Union
|
from typing import Any, Callable, Union
|
||||||
|
from typing import NamedTuple
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax._src import test_util as jtu
|
|
||||||
from jax._src.typing import DTypeLike
|
|
||||||
from jax import tree_util
|
from jax import tree_util
|
||||||
from jax.util import safe_zip, split_list
|
from jax._src import test_util as jtu
|
||||||
|
from jax._src.lax.lax import DotDimensionNumbers
|
||||||
|
from jax._src.lib import gpu_sparse
|
||||||
|
from jax._src.typing import DTypeLike
|
||||||
from jax.experimental import sparse
|
from jax.experimental import sparse
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from jax.util import safe_zip, split_list
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
MATMUL_TOL = {
|
||||||
|
np.float32: 1e-5,
|
||||||
|
np.float64: 1e-10,
|
||||||
|
np.complex64: 1e-5,
|
||||||
|
np.complex128: 1e-10,
|
||||||
|
}
|
||||||
|
|
||||||
|
GPU_LOWERING_ENABLED = gpu_sparse and (
|
||||||
|
gpu_sparse.cuda_is_supported or gpu_sparse.rocm_is_supported
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_sparse(x):
|
def is_sparse(x):
|
||||||
return isinstance(x, sparse.JAXSparse)
|
return isinstance(x, sparse.JAXSparse)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedDotGeneralProperties(NamedTuple):
|
||||||
|
lhs_shape: tuple[int, ...]
|
||||||
|
rhs_shape: tuple[int, ...]
|
||||||
|
n_batch: int
|
||||||
|
n_dense: int
|
||||||
|
dimension_numbers: DotDimensionNumbers
|
||||||
|
|
||||||
|
|
||||||
|
class SparseLayout(NamedTuple):
|
||||||
|
n_batch: int
|
||||||
|
n_dense: int
|
||||||
|
n_sparse: int
|
||||||
|
|
||||||
|
|
||||||
class SparseTestCase(jtu.JaxTestCase):
|
class SparseTestCase(jtu.JaxTestCase):
|
||||||
def assertSparseArraysEquivalent(self, x, y, *, check_dtypes=True, atol=None,
|
def assertSparseArraysEquivalent(self, x, y, *, check_dtypes=True, atol=None,
|
||||||
rtol=None, canonicalize_dtypes=True, err_msg=''):
|
rtol=None, canonicalize_dtypes=True, err_msg=''):
|
||||||
@ -170,3 +199,43 @@ def rand_bcsr(rng: np.random.RandomState,
|
|||||||
return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method,
|
return functools.partial(_rand_sparse, rng=rng, rand_method=rand_method,
|
||||||
nse=nse, n_batch=n_batch, n_dense=n_dense,
|
nse=nse, n_batch=n_batch, n_dense=n_dense,
|
||||||
sparse_format='bcsr')
|
sparse_format='bcsr')
|
||||||
|
|
||||||
|
|
||||||
|
def iter_subsets(s: Sequence) -> Iterable[tuple]:
|
||||||
|
"""Return an iterator over all subsets of a sequence s"""
|
||||||
|
return itertools.chain.from_iterable(
|
||||||
|
itertools.combinations(s, n) for n in range(len(s) + 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def iter_sparse_layouts(
|
||||||
|
shape: Sequence[int], min_n_batch=0
|
||||||
|
) -> Iterator[SparseLayout]:
|
||||||
|
for n_batch in range(min_n_batch, len(shape) + 1):
|
||||||
|
for n_dense in range(len(shape) + 1 - n_batch):
|
||||||
|
n_sparse = len(shape) - n_batch - n_dense
|
||||||
|
yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense)
|
||||||
|
|
||||||
|
|
||||||
|
def iter_bcsr_layouts(
|
||||||
|
shape: Sequence[int], min_n_batch=0
|
||||||
|
) -> Iterator[SparseLayout]:
|
||||||
|
n_sparse = 2
|
||||||
|
for n_batch in range(min_n_batch, len(shape) - 1):
|
||||||
|
n_dense = len(shape) - n_sparse - n_batch
|
||||||
|
yield SparseLayout(n_batch=n_batch, n_sparse=n_sparse, n_dense=n_dense)
|
||||||
|
|
||||||
|
|
||||||
|
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
|
||||||
|
def _rand_sparse(shape, dtype, nse=nse):
|
||||||
|
rand = rand_method(rng)
|
||||||
|
size = math.prod(shape)
|
||||||
|
if 0 <= nse < 1:
|
||||||
|
nse = nse * size
|
||||||
|
nse = min(size, int(nse))
|
||||||
|
M = rand(shape, dtype)
|
||||||
|
indices = rng.choice(size, size - nse, replace=False)
|
||||||
|
M.flat[indices] = 0
|
||||||
|
return post(M)
|
||||||
|
|
||||||
|
return _rand_sparse
|
||||||
|
35
tests/BUILD
35
tests/BUILD
@ -834,7 +834,40 @@ jax_test(
|
|||||||
backend_variant_args = {
|
backend_variant_args = {
|
||||||
"cpu": ["--jax_num_generated_cases=40"],
|
"cpu": ["--jax_num_generated_cases=40"],
|
||||||
"cpu_x32": ["--jax_num_generated_cases=40"],
|
"cpu_x32": ["--jax_num_generated_cases=40"],
|
||||||
"cpu_no_jax_array": ["--jax_num_generated_cases=40"],
|
"gpu": ["--jax_num_generated_cases=40"],
|
||||||
|
},
|
||||||
|
shard_count = {
|
||||||
|
"cpu": 50,
|
||||||
|
"gpu": 50,
|
||||||
|
"tpu": 50,
|
||||||
|
"iree": 10,
|
||||||
|
},
|
||||||
|
tags = [
|
||||||
|
"noasan",
|
||||||
|
"nomsan",
|
||||||
|
"notsan",
|
||||||
|
], # Test times out under asan/msan/tsan.
|
||||||
|
deps = [
|
||||||
|
"//jax:experimental_sparse",
|
||||||
|
"//jax:sparse_test_util",
|
||||||
|
] + py_deps("scipy"),
|
||||||
|
)
|
||||||
|
|
||||||
|
jax_test(
|
||||||
|
name = "sparse_bcoo_bcsr_test",
|
||||||
|
srcs = ["sparse_bcoo_bcsr_test.py"],
|
||||||
|
args = ["--jax_bcoo_cusparse_lowering=true"],
|
||||||
|
backend_tags = {
|
||||||
|
"cpu": [
|
||||||
|
"nomsan", # Times out
|
||||||
|
"notsan", # Times out
|
||||||
|
],
|
||||||
|
"tpu": ["optonly"],
|
||||||
|
},
|
||||||
|
# Use fewer cases to prevent timeouts.
|
||||||
|
backend_variant_args = {
|
||||||
|
"cpu": ["--jax_num_generated_cases=40"],
|
||||||
|
"cpu_x32": ["--jax_num_generated_cases=40"],
|
||||||
"gpu": ["--jax_num_generated_cases=40"],
|
"gpu": ["--jax_num_generated_cases=40"],
|
||||||
},
|
},
|
||||||
shard_count = {
|
shard_count = {
|
||||||
|
1958
tests/sparse_bcoo_bcsr_test.py
Normal file
1958
tests/sparse_bcoo_bcsr_test.py
Normal file
File diff suppressed because it is too large
Load Diff
1936
tests/sparse_test.py
1936
tests/sparse_test.py
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user