mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8584 from jakevdp:fix-sum-duplicates
PiperOrigin-RevId: 410900403
This commit is contained in:
commit
d42255486b
@ -24,7 +24,6 @@ from jax import core
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax import vmap
|
||||
from jax.errors import NonConcreteBooleanIndexError
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
@ -35,6 +34,7 @@ from jax._src.api_util import flatten_axes
|
||||
from jax._src.lax.lax import (
|
||||
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
|
||||
DotDimensionNumbers)
|
||||
from jax._src.numpy.lax_numpy import _unique
|
||||
from . import ops
|
||||
|
||||
Dtype = Any
|
||||
@ -72,29 +72,41 @@ def _bcoo_nse(mat, n_batch=0, n_dense=0):
|
||||
return mask.max()
|
||||
|
||||
def _bcoo_sum_duplicates(data, indices, shape, nse=None):
|
||||
props = _validate_bcoo(data, indices, shape)
|
||||
f = functools.partial(_bcoo_sum_duplicates_unbatched,
|
||||
shape=shape[props.n_batch:props.n_batch + props.n_sparse], nse=nse)
|
||||
for _ in range(props.n_batch):
|
||||
f = broadcasting_vmap(f)
|
||||
return f(data, indices)
|
||||
|
||||
def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse):
|
||||
assert indices.shape == (data.shape[0], len(shape))
|
||||
if indices.shape[1] == 0:
|
||||
return data, indices
|
||||
try:
|
||||
indices_unique, inv_idx = jnp.unique(indices, axis=0, return_inverse=True,
|
||||
size=nse, fill_value=jnp.array(shape))
|
||||
except NonConcreteBooleanIndexError:
|
||||
if nse is None and isinstance(jnp.array(0), core.Tracer):
|
||||
raise ValueError("When used with JIT, vmap, or another transform, sum_duplicates() "
|
||||
"requires passing a non-None value for the nse argument.")
|
||||
data_shape = [indices_unique.shape[0], *data.shape[1:]]
|
||||
data_unique = jnp.zeros(data_shape, data.dtype).at[inv_idx].add(data)
|
||||
oob_mask = jnp.all(indices_unique == jnp.array(shape), 1)
|
||||
data_unique = jnp.where(oob_mask, 0, data_unique)
|
||||
props = _validate_bcoo(data, indices, shape)
|
||||
f = functools.partial(_bcoo_sum_duplicates_unbatched, shape=shape[props.n_batch:], nse=nse)
|
||||
for _ in range(props.n_batch):
|
||||
f = broadcasting_vmap(f)
|
||||
data_unique, indices_unique, nse_out = f(data, indices)
|
||||
if nse is None:
|
||||
nse = jnp.max(nse_out)
|
||||
data_unique = lax.slice_in_dim(data_unique, 0, nse, axis=props.n_batch)
|
||||
indices_unique = lax.slice_in_dim(indices_unique, 0, nse, axis=props.n_batch)
|
||||
return data_unique, indices_unique
|
||||
|
||||
def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse):
|
||||
props = _validate_bcoo(data, indices, shape)
|
||||
if not props.n_sparse:
|
||||
nse = 1 if nse is None else nse
|
||||
data_unique = jnp.zeros_like(data, shape=(nse, *data.shape[1:])).at[0].set(data.sum(0))
|
||||
indices_unique = jnp.zeros_like(indices, shape=(nse, 0))
|
||||
return data_unique, indices_unique, nse
|
||||
if nse is None:
|
||||
indices_unique, inv_idx, nse = _unique(
|
||||
indices, axis=0, return_inverse=True, return_true_size=True,
|
||||
size=props.nse, fill_value=jnp.array(shape[:props.n_sparse]))
|
||||
else:
|
||||
indices_unique, inv_idx = jnp.unique(
|
||||
indices, axis=0, return_inverse=True, size=nse,
|
||||
fill_value=jnp.array(shape[:props.n_sparse]))
|
||||
data_shape = [indices_unique.shape[0], *data.shape[1:]]
|
||||
data_unique = jnp.zeros(data_shape, data.dtype).at[inv_idx].add(data)
|
||||
oob_mask = jnp.all(indices_unique == jnp.array(shape[:props.n_sparse]), 1)
|
||||
data_unique = jnp.where(oob_mask[(...,) + props.n_dense * (None,)], 0, data_unique)
|
||||
return data_unique, indices_unique, nse
|
||||
|
||||
def _unbatch_bcoo(data, indices, shape):
|
||||
n_batch = _validate_bcoo(data, indices, shape).n_batch
|
||||
if n_batch == 0:
|
||||
|
@ -1159,7 +1159,7 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse):
|
||||
rng = self.rng()
|
||||
rng_sparse = rand_sparse(self.rng())
|
||||
M = sparse.BCOO.fromdense(rng_sparse(shape, dtype))
|
||||
M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
|
||||
for i, s in enumerate(shape[n_batch:len(shape) - n_dense]):
|
||||
M.indices = M.indices.at[..., i, :].set(rng.randint(0, s, size=M.indices.shape[-1]))
|
||||
dedupe = partial(M.sum_duplicates, nse=nse)
|
||||
|
Loading…
x
Reference in New Issue
Block a user