Merge pull request #13862 from jakevdp:bcoo-extract
PiperOrigin-RevId: 501826918
commit 9c418e9399
@@ -381,29 +381,37 @@ mlir.register_lowering(bcoo_fromdense_p, mlir.lower_fun(
 
 bcoo_extract_p = core.Primitive('bcoo_extract')
 
-def bcoo_extract(indices: Array, mat: Array) -> Array:
+def bcoo_extract(indices: Array, mat: Array, *, assume_unique=True) -> Array:
   """Extract BCOO data values from a dense matrix at given BCOO indices.
 
   Args:
     indices: An ndarray; see BCOO indices.
     mat: A dense matrix.
+    assume_unique: bool, default=True
+      If True, then indices will be assumed unique and a value will be extracted
+      from mat for each index. Otherwise, extra work will be done to de-duplicate
+      indices to zero-out duplicate extracted values.
 
   Returns:
     An ndarray; see BCOO data.
   """
-  return bcoo_extract_p.bind(indices, mat)
+  return bcoo_extract_p.bind(indices, mat, assume_unique=assume_unique)
 
 @bcoo_extract_p.def_impl
-def _bcoo_extract_impl(indices, mat):
+def _bcoo_extract_impl(indices, mat, *, assume_unique):
   mat = jnp.asarray(mat)
-  n_batch, n_sparse, _, nse = _validate_bcoo_indices(indices, mat.shape)
+  props = _validate_bcoo_indices(indices, mat.shape)
+  if not assume_unique:
+    indices, sort_ind = _unique_indices(indices, shape=mat.shape, return_index=True)
+    original_props = props
+    props = _validate_bcoo_indices(indices, mat.shape)
 
   ind_slices = tuple(np.zeros(s, int) if i_s == 1 else np.arange(s)
-                     for s, i_s in zip(mat.shape[:n_batch], indices.shape[:n_batch]))
+                     for s, i_s in zip(mat.shape[:props.n_batch], indices.shape[:props.n_batch]))
   grid = tuple(np.meshgrid(*ind_slices, indexing='ij', sparse=True))
-  sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(n_sparse))
+  sparse_ind = tuple(indices[grid + (slice(None), i)] for i in range(props.n_sparse))
 
-  batch_slices = tuple(np.arange(s) for s in mat.shape[:n_batch])
+  batch_slices = tuple(np.arange(s) for s in mat.shape[:props.n_batch])
   grid = np.meshgrid(*batch_slices, np.arange(1), indexing='ij', sparse=True)
   batch_ind = tuple(grid)[:-1]
 
@@ -411,29 +419,44 @@ def _bcoo_extract_impl(indices, mat):
     result = mat[None]
   else:
     result = mat.at[batch_ind + sparse_ind].get(mode='fill', fill_value=0)
-    if n_sparse == 0 and nse != 1:
-      result = lax.broadcast_in_dim(
-          result, _tuple_replace(result.shape, n_batch, nse), range(result.ndim))
+    if props.n_sparse == 0 and props.nse != 1:
+      if assume_unique:
+        result = lax.broadcast_in_dim(
+            result, _tuple_replace(result.shape, props.n_batch, props.nse), range(result.ndim))
+      else:
+        out_shape = _tuple_replace(result.shape, props.n_batch, original_props.nse)
+        ind = props.n_batch * (slice(None),) + (slice(1),)
+        result = jnp.zeros_like(result, shape=out_shape).at[ind].set(result)
+  if not assume_unique:
+    unbatched_out_shape = (original_props.nse, *result.shape[props.n_batch + 1:])
+    def f(r, i):
+      return jnp.zeros_like(r, shape=unbatched_out_shape).at[i].add(r)
+    for _ in range(props.n_batch):
+      f = vmap(f)
+    result = f(result, sort_ind)
   return result
 
 @bcoo_extract_p.def_abstract_eval
-def _bcoo_extract_abstract_eval(indices, mat):
+def _bcoo_extract_abstract_eval(indices, mat, *, assume_unique):
+  _ = bool(assume_unique)
   n_batch, _, n_dense, nse = _validate_bcoo_indices(indices, mat.shape)
   out_shape = mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]
   return core.ShapedArray(out_shape, mat.dtype)
 
-def _bcoo_extract_jvp(mat_dot, indices, mat):
+def _bcoo_extract_jvp(mat_dot, indices, mat, *, assume_unique):
   assert mat_dot.shape == mat.shape
-  return bcoo_extract(indices, mat_dot)
+  return bcoo_extract(indices, mat_dot, assume_unique=assume_unique)
 
-def _bcoo_extract_transpose(ct, indices, mat):
+def _bcoo_extract_transpose(ct, indices, mat, *, assume_unique):
+  if not assume_unique:
+    raise NotImplementedError("transpose of bcoo_extract with assume_unique=False")
   assert ad.is_undefined_primal(mat)
   if ad.is_undefined_primal(indices):
     raise ValueError("Cannot transpose with respect to sparse indices")
   assert ct.dtype == mat.aval.dtype
   return indices, _bcoo_todense(ct, indices, spinfo=SparseInfo(mat.aval.shape))
 
-def _bcoo_extract_batching_rule(batched_args, batch_dims):
+def _bcoo_extract_batching_rule(batched_args, batch_dims, *, assume_unique):
   indices, mat = batched_args
   assert any(b is not None for b in batch_dims)
   if batch_dims[0] is None:
@@ -452,7 +475,7 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims):
   n_batch = indices.ndim - 2
   if bdim >= n_batch:
     raise ValueError(f"{batch_dims=} out of range for indices with {n_batch=}")
-  return bcoo_extract(indices, mat), bdim
+  return bcoo_extract(indices, mat, assume_unique=assume_unique), bdim
 
 ad.defjvp(bcoo_extract_p, None, _bcoo_extract_jvp)
 ad.primitive_transposes[bcoo_extract_p] = _bcoo_extract_transpose
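For reference, a minimal usage sketch of the new assume_unique flag (illustrative only, not part of the patch; the values mirror the test_bcoo_extract_duplicate_indices case added below):

import jax.numpy as jnp
from jax.experimental import sparse

# Index rows 0 and 3 each appear twice, so todense() sums their data values.
data = jnp.array([1, 3, 9, 27, 81, 243])
indices = jnp.array([[0], [5], [0], [3], [2], [3]])
mat = sparse.BCOO((data, indices), shape=(6,)).todense()  # [10, 0, 81, 270, 0, 3]

# Default: a value is extracted for every index row, so duplicates repeat the summed entry.
sparse.bcoo_extract(indices, mat, assume_unique=True)   # [10, 3, 10, 270, 81, 270]

# With de-duplication, extracted values at repeated index rows are zeroed out.
sparse.bcoo_extract(indices, mat, assume_unique=False)  # [10, 3, 0, 270, 81, 0]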
@@ -1068,7 +1091,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
   B_shape = B.aval.shape if hasattr(B, 'aval') else B.shape
   mat_shape = _dot_general_validated_shape(A_shape, B_shape, dimension_numbers)
   mat = ad.UndefinedPrimal(core.ShapedArray(mat_shape, ct.dtype))
-  indices, ct = _bcoo_extract_transpose(ct, indices, mat)
+  indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
   kwds = {'dimension_numbers': dimension_numbers,
           'precision': None,
           'preferred_element_type': None}
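The transpose rule above is only implemented for assume_unique=True (the default), which is what _bcoo_dot_general_sampled_transpose relies on here. A small sketch of a gradient flowing through the default path, assuming the public sparse.bcoo_extract API (illustrative only):

import jax
import jax.numpy as jnp
from jax.experimental import sparse

indices = jnp.array([[0], [2], [4]])  # three unique 1-D indices into a length-5 vector
f = lambda m: sparse.bcoo_extract(indices, m).sum()

# The cotangent is scattered back to a dense array by _bcoo_todense,
# giving ones at the extracted positions: [1., 0., 1., 0., 1.]
print(jax.grad(f)(jnp.arange(5.0)))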
@@ -1397,9 +1420,8 @@ def _bcoo_sum_duplicates(data: Array, indices: Array, *, spinfo: SparseInfo, nse
 @bcoo_sum_duplicates_p.def_impl
 def _bcoo_sum_duplicates_impl(data, indices, *, spinfo, nse):
   props = _validate_bcoo(data, indices, spinfo.shape)
-  f = nfold_vmap(functools.partial(_bcoo_sum_duplicates_unbatched, shape=spinfo.shape[props.n_batch:]),
-                 N=props.n_batch, broadcasted=False)
-  indices_out, mapping, nse_batched = f(indices)
+  indices_out, mapping, nse_batched = _unique_indices(
+      indices, shape=spinfo.shape, return_inverse=True, return_true_size=True)
   if nse is None:
     nse = 1 if props.n_sparse == 0 else nse_batched.max()
   indices_out = _adjust_indices_nse(indices_out, nse=nse, shape=spinfo.shape)
@@ -1425,22 +1447,40 @@ def _adjust_indices_nse(indices, *, nse, shape):
     indices = lax.concatenate([indices, fill], dimension=indices.ndim - 2)
   return indices
 
-def _bcoo_sum_duplicates_unbatched(indices, *, shape):
+def _unique_indices(indices, *, shape, return_inverse=False,
+                    return_index=False, return_true_size=False):
   props = _validate_bcoo_indices(indices, shape)
+  f = partial(_unique_indices_unbatched, shape=shape[props.n_batch:],
+              return_inverse=return_inverse, return_index=return_index,
+              return_true_size=return_true_size)
+  f = nfold_vmap(f, props.n_batch, broadcasted=False)
+  return f(indices)
+
+def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
+                              return_index=False, return_true_size=False):
+  props = _validate_bcoo_indices(indices, shape)
   if props.n_sparse == 0:
     nse = 1
-    mapping = jnp.zeros(nse, dtype='int32')
-    indices_out = jnp.zeros_like(indices, shape=(nse, props.n_sparse))
-    return indices_out, mapping, nse
+    indices_out = jnp.zeros_like(indices, shape=(nse, 0))
+    out = (indices_out,)
+    if return_index:
+      out = (*out, jnp.zeros(nse, dtype='int32'))
+    if return_inverse:
+      out = (*out, jnp.zeros(nse, dtype='int32'))
+    if return_true_size:
+      out = (*out, nse)
+    return out[0] if len(out) == 1 else out
   fill_value = jnp.expand_dims(jnp.array(shape[:props.n_sparse], dtype=indices.dtype), (0,))
   out_of_bounds = (indices >= fill_value).any(-1, keepdims=True)
   indices = jnp.where(out_of_bounds, fill_value, indices)
   # TODO: check if `indices_sorted` is True.
-  indices_unique, inv_idx, nse = _unique(
-      indices, axis=0, return_inverse=True, return_true_size=True,
-      size=props.nse, fill_value=fill_value)
-  nse = nse - (indices == fill_value).any().astype(nse.dtype)
-  return indices_unique, inv_idx, nse
+  out = _unique(indices, axis=0, return_inverse=return_inverse, return_index=return_index,
+                return_true_size=return_true_size, size=props.nse, fill_value=fill_value)
+  if return_true_size:
+    nse = out[-1]
+    nse = nse - (indices == fill_value).any().astype(nse.dtype)
+    out = (*out[:-1], nse)
+  return out
 
 @bcoo_sum_duplicates_p.def_abstract_eval
 def _bcoo_sum_duplicates_abstract_eval(data, indices, *, spinfo, nse):
@@ -1472,8 +1512,8 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse):
 
   data, indices = primals
   data_dot, _ = tangents
-  f = nfold_vmap(functools.partial(_bcoo_sum_duplicates_unbatched, shape=spinfo.shape[props.n_batch:]), props.n_batch)
-  indices_out, mapping, nse_batched = f(indices)
+  indices_out, mapping, nse_batched = _unique_indices(
+      indices, shape=spinfo.shape, return_inverse=True, return_true_size=True)
   if nse is None:
     nse = jnp.sum(nse_batched)
   try:
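The two hunks above switch bcoo_sum_duplicates over to the shared _unique_indices helper. A minimal sketch of the user-facing behavior, assuming the public sparse.bcoo_sum_duplicates API; the commented outputs are derived by hand from the duplicate-index data used in the tests below and are illustrative:

import jax.numpy as jnp
from jax.experimental import sparse

data = jnp.array([1, 3, 9, 27, 81, 243])
indices = jnp.array([[0], [5], [0], [3], [2], [3]])
mat = sparse.BCOO((data, indices), shape=(6,))

summed = sparse.bcoo_sum_duplicates(mat)
# Duplicate index rows are combined and sorted, roughly:
#   summed.indices -> [[0], [2], [3], [5]]
#   summed.data    -> [10, 81, 270, 3]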
@@ -804,18 +804,46 @@ class BCOOTest(sptu.SparseTestCase):
       for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
       for layout in iter_sparse_layouts(shape)],
     dtype=jtu.dtypes.floating + jtu.dtypes.complex,
+    assume_unique=[True, False]
   )
-  def test_bcoo_extract(self, shape, dtype, n_batch, n_dense):
+  def test_bcoo_extract(self, shape, dtype, n_batch, n_dense, assume_unique):
     rng = rand_sparse(self.rng())
     M = rng(shape, dtype)
     nse = sparse.util._count_stored_elements(M, n_batch=n_batch,
                                              n_dense=n_dense)
     data, indices = sparse_bcoo._bcoo_fromdense(M, nse=nse)
-    data2 = sparse.bcoo_extract(indices, M)
+    bcoo_extract = partial(sparse.bcoo_extract, assume_unique=assume_unique)
+
+    data2 = bcoo_extract(indices, M)
     self.assertArraysEqual(data, data2)
-    data3 = jit(sparse.bcoo_extract)(indices, M)
+
+    data3 = jit(bcoo_extract)(indices, M)
     self.assertArraysEqual(data, data3)
+
+  def test_bcoo_extract_duplicate_indices(self):
+    data = jnp.array([1, 3, 9, 27, 81, 243])
+    indices = jnp.array([[0], [5], [0], [3], [2], [3]])
+    shape = (6,)
+    mat = sparse.BCOO((data, indices), shape=shape).todense()
+
+    data1 = sparse.bcoo_extract(indices, mat, assume_unique=True)
+    self.assertArraysEqual(data1, jnp.array([10, 3, 10, 270, 81, 270]))
+
+    data2 = sparse.bcoo_extract(indices, mat, assume_unique=False)
+    self.assertArraysEqual(data2, jnp.array([10, 3, 0, 270, 81, 0]))
+
+  def test_bcoo_extract_duplicate_indices_n_sparse_0(self):
+    data = jnp.arange(6).reshape(3, 2)
+    indices = jnp.empty((3, 2, 0), dtype=int)
+    shape = (3,)
+    mat = sparse.BCOO((data, indices), shape=shape).todense()
+
+    data1 = sparse.bcoo_extract(indices, mat, assume_unique=True)
+    self.assertArraysEqual(data1, jnp.array([[1, 1], [5, 5], [9, 9]]))
+
+    data2 = sparse.bcoo_extract(indices, mat, assume_unique=False)
+    self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]]))
 
   def test_bcoo_extract_batching(self):
     # https://github.com/google/jax/issues/9431
     indices = jnp.zeros((4, 1, 1), dtype=int)
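The expected values in test_bcoo_extract_duplicate_indices_n_sparse_0 follow from the broadcast path in _bcoo_extract_impl: with empty index rows (n_sparse = 0), both stored elements per batch alias the same dense entry. A tiny illustrative sketch:

import jax.numpy as jnp
from jax.experimental import sparse

data = jnp.arange(6).reshape(3, 2)         # two stored elements per batch
indices = jnp.empty((3, 2, 0), dtype=int)  # n_batch=1, n_sparse=0
mat = sparse.BCOO((data, indices), shape=(3,)).todense()  # [1, 5, 9]

# assume_unique=True broadcasts each dense entry across the nse dimension;
# assume_unique=False keeps only the first slot and zeros the duplicate.
sparse.bcoo_extract(indices, mat, assume_unique=True)   # [[1, 1], [5, 5], [9, 9]]
sparse.bcoo_extract(indices, mat, assume_unique=False)  # [[1, 0], [5, 0], [9, 0]]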