mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13869 from jakevdp:bcoo-extract-api
PiperOrigin-RevId: 501842123
This commit is contained in:
commit
d1593289a0
@ -85,12 +85,13 @@ def _todense_transpose(ct, *bufs, tree):
|
||||
|
||||
standin = object()
|
||||
obj = tree_util.tree_unflatten(tree, [standin] * len(bufs))
|
||||
from jax.experimental.sparse import BCOO, bcoo_extract
|
||||
from jax.experimental.sparse import BCOO
|
||||
from jax.experimental.sparse.bcoo import _bcoo_extract
|
||||
if obj is standin:
|
||||
return (ct,)
|
||||
elif isinstance(obj, BCOO):
|
||||
_, indices = bufs
|
||||
return bcoo_extract(indices, ct), indices
|
||||
return _bcoo_extract(indices, ct), indices
|
||||
elif isinstance(obj, COO):
|
||||
_, row, col = bufs
|
||||
return _coo_extract(row, col, ct), row, col
|
||||
|
@ -236,7 +236,7 @@ def _bcoo_todense_transpose(ct, data, indices, *, spinfo):
|
||||
raise ValueError("Cannot transpose with respect to sparse indices")
|
||||
assert ct.shape == shape
|
||||
assert ct.dtype == data.aval.dtype
|
||||
return bcoo_extract(indices, ct), indices
|
||||
return _bcoo_extract(indices, ct), indices
|
||||
|
||||
def _bcoo_todense_batching_rule(batched_args, batch_dims, *, spinfo):
|
||||
data, indices, spinfo = _bcoo_batch_dims_to_front(batched_args, batch_dims, spinfo)
|
||||
@ -320,7 +320,7 @@ def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
|
||||
indices = jnp.zeros(mask.shape[:n_batch] + (nse, 0), index_dtype)
|
||||
else:
|
||||
indices = jnp.moveaxis(jnp.array(indices, index_dtype), 0, n_batch + 1)
|
||||
data = bcoo_extract(indices, mat)
|
||||
data = _bcoo_extract(indices, mat)
|
||||
|
||||
true_nse = mask.sum(list(range(n_batch, mask.ndim)))[..., None]
|
||||
true_nonzeros = lax.broadcasted_iota(true_nse.dtype, (1,) * n_batch + (nse,), n_batch) < true_nse
|
||||
@ -346,7 +346,7 @@ def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype
|
||||
if type(Mdot) is ad.Zero:
|
||||
data_dot = ad.Zero.from_value(data)
|
||||
else:
|
||||
data_dot = bcoo_extract(indices, Mdot)
|
||||
data_dot = _bcoo_extract(indices, Mdot)
|
||||
|
||||
tangents_out = (data_dot, ad.Zero.from_value(indices))
|
||||
|
||||
@ -381,44 +381,70 @@ mlir.register_lowering(bcoo_fromdense_p, mlir.lower_fun(
|
||||
|
||||
bcoo_extract_p = core.Primitive('bcoo_extract')
|
||||
|
||||
def bcoo_extract(indices: Array, mat: Array, *, assume_unique=True) -> Array:
|
||||
"""Extract BCOO data values from a dense matrix at given BCOO indices.
|
||||
|
||||
def bcoo_extract(sparr: BCOO, arr: ArrayLike, *, assume_unique: Optional[bool] = None) -> BCOO:
  """Sample a dense array at the index positions stored in a BCOO array.

  Args:
    sparr : BCOO array whose indices define the sparsity pattern of the output.
    arr : ArrayLike whose shape must equal ``sparr.shape``; it is converted
      with ``jnp.asarray`` before use.
    assume_unique : optional bool. When None (the default), the value of
      ``sparr.unique_indices`` is used instead. The resolved flag is forwarded
      unchanged to ``_bcoo_extract``: if True, a value is extracted for every
      stored index; if False, indices are de-duplicated first so that repeated
      indices yield zeros for the duplicate entries.

  Returns:
    extracted : a BCOO array that reuses ``sparr.indices`` and the metadata in
      ``sparr._info``, with data taken from ``arr``.

  Raises:
    ValueError: if ``sparr`` is not a BCOO instance, or if the shape of ``arr``
      differs from ``sparr.shape``.
  """
  if not isinstance(sparr, BCOO):
    raise ValueError(f"First argument to bcoo_extract should be a BCOO array. Got {type(sparr)=}")
  arr = jnp.asarray(arr)
  if arr.shape != sparr.shape:
    raise ValueError(f"shape mismatch: {sparr.shape=} {arr.shape=}")
  # None means "defer to the uniqueness flag already recorded on the sparse input".
  unique = sparr.unique_indices if assume_unique is None else assume_unique
  sparse_indices = sparr.indices
  extracted = _bcoo_extract(sparse_indices, arr, assume_unique=unique)
  # Rebuild with the original indices and metadata; only the data buffer is new.
  return BCOO((extracted, sparse_indices), **sparr._info._asdict())
|
||||
|
||||
|
||||
def _bcoo_extract(indices: Array, arr: Array, *, assume_unique=True) -> Array:
|
||||
"""Extract BCOO data values from a dense array at given BCOO indices.
|
||||
|
||||
Args:
|
||||
indices: An ndarray; see BCOO indices.
|
||||
mat: A dense matrix.
|
||||
arr: A dense array.
|
||||
assume_unique: bool, default=True
|
||||
If True, then indices will be assumed unique and a value will be extracted
|
||||
from mat for each index. Otherwise, extra work will be done to de-duplicate
|
||||
from arr for each index. Otherwise, extra work will be done to de-duplicate
|
||||
indices to zero-out duplicate extracted values.
|
||||
|
||||
Returns:
|
||||
An ndarray; see BCOO data.
|
||||
"""
|
||||
return bcoo_extract_p.bind(indices, mat, assume_unique=assume_unique)
|
||||
return bcoo_extract_p.bind(indices, arr, assume_unique=assume_unique)
|
||||
|
||||
@bcoo_extract_p.def_impl
|
||||
def _bcoo_extract_impl(indices, mat, *, assume_unique):
|
||||
mat = jnp.asarray(mat)
|
||||
props = _validate_bcoo_indices(indices, mat.shape)
|
||||
def _bcoo_extract_impl(indices, arr, *, assume_unique):
|
||||
arr = jnp.asarray(arr)
|
||||
props = _validate_bcoo_indices(indices, arr.shape)
|
||||
if not assume_unique:
|
||||
indices, sort_ind = _unique_indices(indices, shape=mat.shape, return_index=True)
|
||||
indices, sort_ind = _unique_indices(indices, shape=arr.shape, return_index=True)
|
||||
original_props = props
|
||||
props = _validate_bcoo_indices(indices, mat.shape)
|
||||
props = _validate_bcoo_indices(indices, arr.shape)
|
||||
|
||||
ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
|
||||
for s, i_s in zip(mat.shape[:props.n_batch], indices.shape[:props.n_batch]))
|
||||
for s, i_s in zip(arr.shape[:props.n_batch], indices.shape[:props.n_batch]))
|
||||
grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
|
||||
sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(props.n_sparse))
|
||||
|
||||
batch_slices = tuple(np.arange(s) for s in mat.shape[:props.n_batch])
|
||||
batch_slices = tuple(np.arange(s) for s in arr.shape[:props.n_batch])
|
||||
grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
|
||||
batch_ind = tuple(grid)[:-1]
|
||||
|
||||
if not sparse_ind + batch_ind:
|
||||
result = mat[None]
|
||||
result = arr[None]
|
||||
else:
|
||||
result = mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
|
||||
result = arr.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
|
||||
if props.n_sparse == 0 and props.nse != 1:
|
||||
if assume_unique:
|
||||
result = lax.broadcast_in_dim(
|
||||
@ -437,27 +463,27 @@ def _bcoo_extract_impl(indices, mat, *, assume_unique):
|
||||
return result
|
||||
|
||||
@bcoo_extract_p.def_abstract_eval
|
||||
def _bcoo_extract_abstract_eval(indices, mat, *, assume_unique):
|
||||
def _bcoo_extract_abstract_eval(indices, arr, *, assume_unique):
|
||||
_ = bool(assume_unique)
|
||||
n_batch, _, n_dense, nse = _validate_bcoo_indices(indices, mat.shape)
|
||||
out_shape = mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]
|
||||
return core.ShapedArray(out_shape, mat.dtype)
|
||||
n_batch, _, n_dense, nse = _validate_bcoo_indices(indices, arr.shape)
|
||||
out_shape = arr.shape[:n_batch] + (nse,) + arr.shape[arr.ndim - n_dense:]
|
||||
return core.ShapedArray(out_shape, arr.dtype)
|
||||
|
||||
def _bcoo_extract_jvp(mat_dot, indices, mat, *, assume_unique):
|
||||
assert mat_dot.shape == mat.shape
|
||||
return bcoo_extract(indices, mat_dot, assume_unique=assume_unique)
|
||||
def _bcoo_extract_jvp(arr_dot, indices, arr, *, assume_unique):
  """JVP rule for ``bcoo_extract_p`` with respect to the dense operand.

  The tangent is obtained by extracting from ``arr_dot`` at the same
  ``indices``, with the same ``assume_unique`` setting.

  Args:
    arr_dot: tangent of ``arr``; must have the same shape as ``arr``.
    indices: BCOO indices passed through to ``_bcoo_extract``.
    arr: the primal dense operand (only its shape is inspected here).
    assume_unique: forwarded unchanged to ``_bcoo_extract``.

  Returns:
    The result of ``_bcoo_extract(indices, arr_dot, assume_unique=assume_unique)``.
  """
  tangent_shape, primal_shape = arr_dot.shape, arr.shape
  # Internal invariant: the tangent must be shaped like its primal.
  assert tangent_shape == primal_shape
  return _bcoo_extract(indices, arr_dot, assume_unique=assume_unique)
|
||||
|
||||
def _bcoo_extract_transpose(ct, indices, mat, *, assume_unique):
|
||||
def _bcoo_extract_transpose(ct, indices, arr, *, assume_unique):
|
||||
if not assume_unique:
|
||||
raise NotImplementedError("transpose of bcoo_extract with assume_unique=False")
|
||||
assert ad.is_undefined_primal(mat)
|
||||
assert ad.is_undefined_primal(arr)
|
||||
if ad.is_undefined_primal(indices):
|
||||
raise ValueError("Cannot transpose with respect to sparse indices")
|
||||
assert ct.dtype == mat.aval.dtype
|
||||
return indices, _bcoo_todense(ct, indices, spinfo=SparseInfo(mat.aval.shape))
|
||||
assert ct.dtype == arr.aval.dtype
|
||||
return indices, _bcoo_todense(ct, indices, spinfo=SparseInfo(arr.aval.shape))
|
||||
|
||||
def _bcoo_extract_batching_rule(batched_args, batch_dims, *, assume_unique):
|
||||
indices, mat = batched_args
|
||||
indices, arr = batched_args
|
||||
assert any(b is not None for b in batch_dims)
|
||||
if batch_dims[0] is None:
|
||||
bdim = batch_dims[1]
|
||||
@ -465,9 +491,9 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims, *, assume_unique):
|
||||
elif batch_dims[1] is None:
|
||||
# TODO(jakevdp) can we handle this case without explicit broadcasting?
|
||||
bdim = batch_dims[0]
|
||||
result_shape = list(mat.shape)
|
||||
result_shape = list(arr.shape)
|
||||
result_shape.insert(bdim, indices.shape[bdim])
|
||||
mat = lax.broadcast_in_dim(mat, result_shape, (bdim,))
|
||||
arr = lax.broadcast_in_dim(arr, result_shape, (bdim,))
|
||||
else:
|
||||
if batch_dims[0] != batch_dims[1]:
|
||||
raise NotImplementedError("bcoo_extract with unequal batch dimensions.")
|
||||
@ -475,7 +501,7 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims, *, assume_unique):
|
||||
n_batch = indices.ndim - 2
|
||||
if bdim >= n_batch:
|
||||
raise ValueError(f"{batch_dims=} out of range for indices with {n_batch=}")
|
||||
return bcoo_extract(indices, mat, assume_unique=assume_unique), bdim
|
||||
return _bcoo_extract(indices, arr, assume_unique=assume_unique), bdim
|
||||
|
||||
ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
|
||||
ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
|
||||
@ -1004,7 +1030,7 @@ def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_num
|
||||
# Fallback to direct approach when above is not possible.
|
||||
out_dense_T = lax.dot_general(ct, rhs, dimension_numbers=dims)
|
||||
out_dense = lax.transpose(out_dense_T, out_axes)
|
||||
result = bcoo_extract(lhs_indices, out_dense)
|
||||
result = _bcoo_extract(lhs_indices, out_dense)
|
||||
return result, lhs_indices, rhs
|
||||
else:
|
||||
dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch)) # type: ignore[assignment]
|
||||
@ -1078,12 +1104,12 @@ def bcoo_dot_general_sampled(A: Array, B: Array, indices: Array, *, dimension_nu
|
||||
def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):
|
||||
# TODO(jakevdp): use a more efficient implementation that avoids the full dot product.
|
||||
dense_result = lax.dot_general(A, B, dimension_numbers=dimension_numbers)
|
||||
return bcoo_extract(indices, dense_result)
|
||||
return _bcoo_extract(indices, dense_result)
|
||||
|
||||
@bcoo_dot_general_sampled_p.def_abstract_eval
|
||||
def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
|
||||
dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B)
|
||||
sparse_result, = pe.abstract_eval_fun(lambda *args: [bcoo_extract(*args)], indices, dense_result)
|
||||
sparse_result, = pe.abstract_eval_fun(lambda *args: [_bcoo_extract(*args)], indices, dense_result)
|
||||
return sparse_result
|
||||
|
||||
def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers):
|
||||
@ -2221,7 +2247,7 @@ def _bcoo_multiply_dense(data: Array, indices: Array, v: Array, *, spinfo: Spars
|
||||
return lax.mul(data, v)
|
||||
if shape == v.shape:
|
||||
# Note: due to distributive property, no deduplication necessary!
|
||||
return lax.mul(data, bcoo_extract(indices, v))
|
||||
return lax.mul(data, _bcoo_extract(indices, v))
|
||||
|
||||
if lax.broadcast_shapes(v.shape, shape) != shape:
|
||||
raise NotImplementedError(
|
||||
|
@ -291,7 +291,7 @@ def bcsr_extract(indices: ArrayLike, indptr: ArrayLike, mat: ArrayLike) -> Array
|
||||
def _bcsr_extract_impl(indices, indptr, mat):
|
||||
mat = jnp.asarray(mat)
|
||||
bcoo_indices = _bcsr_to_bcoo(indices, indptr, shape=mat.shape)
|
||||
return bcoo.bcoo_extract(bcoo_indices, mat)
|
||||
return bcoo._bcoo_extract(bcoo_indices, mat)
|
||||
|
||||
|
||||
@bcsr_extract_p.def_abstract_eval
|
||||
|
@ -804,21 +804,24 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
for layout in iter_sparse_layouts(shape)],
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
assume_unique=[True, False]
|
||||
assume_unique=[True, False, None]
|
||||
)
|
||||
def test_bcoo_extract(self, shape, dtype, n_batch, n_dense, assume_unique):
|
||||
rng = rand_sparse(self.rng())
|
||||
M = rng(shape, dtype)
|
||||
nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
|
||||
n_dense=n_dense)
|
||||
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse)
|
||||
bcoo_extract = partial(sparse.bcoo_extract, assume_unique=assume_unique)
|
||||
|
||||
data2 = bcoo_extract(indices, M)
|
||||
self.assertArraysEqual(data, data2)
|
||||
def args_maker():
|
||||
x = rng(shape, dtype)
|
||||
x_bcoo = sparse.bcoo_fromdense(x, n_batch=n_batch, n_dense=n_dense)
|
||||
# Unique indices are required for this test when assume_unique == True.
|
||||
self.assertTrue(x_bcoo.unique_indices)
|
||||
return x_bcoo, x
|
||||
|
||||
data3 = jit(bcoo_extract)(indices, M)
|
||||
self.assertArraysEqual(data, data3)
|
||||
dense_op = lambda _, x: x
|
||||
sparse_op = partial(sparse.bcoo_extract, assume_unique=assume_unique)
|
||||
|
||||
self._CheckAgainstDense(dense_op, sparse_op, args_maker)
|
||||
self._CompileAndCheckSparse(sparse_op, args_maker)
|
||||
self._CheckBatchingSparse(dense_op, sparse_op, args_maker, bdims=2 * self._random_bdims(n_batch))
|
||||
|
||||
def test_bcoo_extract_duplicate_indices(self):
|
||||
data = jnp.array([1, 3, 9, 27, 81, 243])
|
||||
@ -826,10 +829,10 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
shape = (6,)
|
||||
mat = sparse.BCOO((data, indices), shape=shape).todense()
|
||||
|
||||
data1 = sparse.bcoo_extract(indices, mat, assume_unique=True)
|
||||
data1 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=True)
|
||||
self.assertArraysEqual(data1, jnp.array([10, 3, 10, 270, 81, 270]))
|
||||
|
||||
data2 = sparse.bcoo_extract(indices, mat, assume_unique=False)
|
||||
data2 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=False)
|
||||
self.assertArraysEqual(data2, jnp.array([10, 3, 0, 270, 81, 0]))
|
||||
|
||||
def test_bcoo_extract_duplicate_indices_n_sparse_0(self):
|
||||
@ -838,10 +841,10 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
shape = (3,)
|
||||
mat = sparse.BCOO((data, indices), shape=shape).todense()
|
||||
|
||||
data1 = sparse.bcoo_extract(indices, mat, assume_unique=True)
|
||||
data1 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=True)
|
||||
self.assertArraysEqual(data1, jnp.array([[1, 1], [5, 5], [9, 9]]))
|
||||
|
||||
data2 = sparse.bcoo_extract(indices, mat, assume_unique=False)
|
||||
data2 = sparse_bcoo._bcoo_extract(indices, mat, assume_unique=False)
|
||||
self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]]))
|
||||
|
||||
def test_bcoo_extract_batching(self):
|
||||
@ -850,18 +853,18 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
mat = jnp.arange(4.).reshape((4, 1))
|
||||
|
||||
# in_axes = (0, None)
|
||||
expected = jnp.vstack([sparse.bcoo_extract(i, mat[0]) for i in indices])
|
||||
actual = vmap(sparse.bcoo_extract, in_axes=(0, None))(indices, mat[0])
|
||||
expected = jnp.vstack([sparse_bcoo._bcoo_extract(i, mat[0]) for i in indices])
|
||||
actual = vmap(sparse_bcoo._bcoo_extract, in_axes=(0, None))(indices, mat[0])
|
||||
self.assertArraysEqual(expected, actual)
|
||||
|
||||
# in_axes = (None, 0)
|
||||
expected = jnp.vstack([sparse.bcoo_extract(indices[0], m) for m in mat])
|
||||
actual = vmap(sparse.bcoo_extract, in_axes=(None, 0))(indices[0], mat)
|
||||
expected = jnp.vstack([sparse_bcoo._bcoo_extract(indices[0], m) for m in mat])
|
||||
actual = vmap(sparse_bcoo._bcoo_extract, in_axes=(None, 0))(indices[0], mat)
|
||||
self.assertArraysEqual(expected, actual)
|
||||
|
||||
# in_axes = (0, 0)
|
||||
expected = jnp.vstack([sparse.bcoo_extract(i, m) for i, m in zip(indices, mat)])
|
||||
actual = vmap(sparse.bcoo_extract, in_axes=0)(indices, mat)
|
||||
expected = jnp.vstack([sparse_bcoo._bcoo_extract(i, m) for i, m in zip(indices, mat)])
|
||||
actual = vmap(sparse_bcoo._bcoo_extract, in_axes=0)(indices, mat)
|
||||
self.assertArraysEqual(expected, actual)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -877,7 +880,7 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
n_dense=n_dense)
|
||||
data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense)
|
||||
|
||||
extract = partial(sparse.bcoo_extract, indices)
|
||||
extract = partial(sparse_bcoo._bcoo_extract, indices)
|
||||
j1 = jax.jacfwd(extract)(M)
|
||||
j2 = jax.jacrev(extract)(M)
|
||||
hess = jax.hessian(extract)(M)
|
||||
@ -890,11 +893,11 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
|
||||
# (n_batch, n_sparse, n_dense) = (1, 0, 0), nse = 2
|
||||
args_maker = lambda: (jnp.zeros((3, 2, 0), dtype='int32'), jnp.arange(3))
|
||||
self._CompileAndCheck(sparse.bcoo_extract, args_maker)
|
||||
self._CompileAndCheck(sparse_bcoo._bcoo_extract, args_maker)
|
||||
|
||||
# (n_batch, n_sparse, n_dense) = (0, 0, 1), nse = 2
|
||||
args_maker = lambda: (jnp.zeros((2, 0), dtype='int32'), jnp.arange(3))
|
||||
self._CompileAndCheck(sparse.bcoo_extract, args_maker)
|
||||
self._CompileAndCheck(sparse_bcoo._bcoo_extract, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense)
|
||||
@ -1252,7 +1255,7 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
|
||||
def dense_fun(lhs, rhs, indices):
|
||||
AB = lax.dot_general(lhs, rhs, dimension_numbers=props.dimension_numbers)
|
||||
return sparse.bcoo_extract(indices, AB)
|
||||
return sparse_bcoo._bcoo_extract(indices, AB)
|
||||
def sparse_fun(lhs, rhs, indices):
|
||||
return sparse.bcoo_dot_general_sampled(
|
||||
lhs, rhs, indices, dimension_numbers=props.dimension_numbers)
|
||||
@ -1295,7 +1298,7 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
|
||||
def dense_fun(lhs, rhs, indices):
|
||||
AB = lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
|
||||
return sparse.bcoo_extract(indices, AB)
|
||||
return sparse_bcoo._bcoo_extract(indices, AB)
|
||||
def sparse_fun(lhs, rhs, indices):
|
||||
return sparse.bcoo_dot_general_sampled(
|
||||
lhs, rhs, indices, dimension_numbers=dimension_numbers)
|
||||
@ -2171,7 +2174,7 @@ class SparseGradTest(sptu.SparseTestCase):
|
||||
val_sp, grad_sp = sparse.value_and_grad(f, argnums=0, has_aux=has_aux)(Xsp, y)
|
||||
self.assertIsInstance(grad_sp, sparse.BCOO)
|
||||
self.assertAllClose(val_de, val_sp)
|
||||
self.assertAllClose(grad_sp.data, sparse.bcoo_extract(grad_sp.indices, grad_de))
|
||||
self.assertAllClose(grad_sp.data, sparse_bcoo._bcoo_extract(grad_sp.indices, grad_de))
|
||||
|
||||
with self.subTest("wrt dense"):
|
||||
self.assertAllClose(jax.value_and_grad(f, argnums=1, has_aux=has_aux)(X, y),
|
||||
@ -2199,7 +2202,7 @@ class SparseGradTest(sptu.SparseTestCase):
|
||||
grad_sp, aux_sp = grad_sp
|
||||
self.assertAllClose(aux_de, aux_sp)
|
||||
self.assertIsInstance(grad_sp, sparse.BCOO)
|
||||
self.assertAllClose(grad_sp.data, sparse.bcoo_extract(grad_sp.indices, grad_de))
|
||||
self.assertAllClose(grad_sp.data, sparse_bcoo._bcoo_extract(grad_sp.indices, grad_de))
|
||||
|
||||
with self.subTest("wrt dense"):
|
||||
self.assertAllClose(jax.grad(f, argnums=1, has_aux=has_aux)(X, y),
|
||||
@ -2233,7 +2236,7 @@ class SparseGradTest(sptu.SparseTestCase):
|
||||
grad_sp, aux_sp = grad_sp
|
||||
self.assertAllClose(aux_de, aux_sp)
|
||||
self.assertIsInstance(grad_sp, sparse.BCOO)
|
||||
self.assertAllClose(grad_sp.data, sparse.bcoo_extract(grad_sp.indices, grad_de))
|
||||
self.assertAllClose(grad_sp.data, sparse_bcoo._bcoo_extract(grad_sp.indices, grad_de))
|
||||
|
||||
with self.subTest("wrt dense"):
|
||||
rtol = 0.01 if jtu.device_under_test() == 'tpu' else None
|
||||
|
Loading…
x
Reference in New Issue
Block a user