Merge pull request #8584 from jakevdp:fix-sum-duplicates

PiperOrigin-RevId: 410900403
This commit is contained in:
jax authors 2021-11-18 14:31:31 -08:00
commit d42255486b
2 changed files with 33 additions and 21 deletions

View File

@ -24,7 +24,6 @@ from jax import core
from jax import lax
from jax import tree_util
from jax import vmap
from jax.errors import NonConcreteBooleanIndexError
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
@ -35,6 +34,7 @@ from jax._src.api_util import flatten_axes
from jax._src.lax.lax import (
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
from jax._src.numpy.lax_numpy import _unique
from . import ops
Dtype = Any
@ -72,29 +72,41 @@ def _bcoo_nse(mat, n_batch=0, n_dense=0):
return mask.max()
def _bcoo_sum_duplicates(data, indices, shape, nse=None):
props = _validate_bcoo(data, indices, shape)
f = functools.partial(_bcoo_sum_duplicates_unbatched,
shape=shape[props.n_batch:props.n_batch + props.n_sparse], nse=nse)
for _ in range(props.n_batch):
f = broadcasting_vmap(f)
return f(data, indices)
def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse):
assert indices.shape == (data.shape[0], len(shape))
if indices.shape[1] == 0:
return data, indices
try:
indices_unique, inv_idx = jnp.unique(indices, axis=0, return_inverse=True,
size=nse, fill_value=jnp.array(shape))
except NonConcreteBooleanIndexError:
if nse is None and isinstance(jnp.array(0), core.Tracer):
raise ValueError("When used with JIT, vmap, or another transform, sum_duplicates() "
"requires passing a non-None value for the nse argument.")
data_shape = [indices_unique.shape[0], *data.shape[1:]]
data_unique = jnp.zeros(data_shape, data.dtype).at[inv_idx].add(data)
oob_mask = jnp.all(indices_unique == jnp.array(shape), 1)
data_unique = jnp.where(oob_mask, 0, data_unique)
props = _validate_bcoo(data, indices, shape)
f = functools.partial(_bcoo_sum_duplicates_unbatched, shape=shape[props.n_batch:], nse=nse)
for _ in range(props.n_batch):
f = broadcasting_vmap(f)
data_unique, indices_unique, nse_out = f(data, indices)
if nse is None:
nse = jnp.max(nse_out)
data_unique = lax.slice_in_dim(data_unique, 0, nse, axis=props.n_batch)
indices_unique = lax.slice_in_dim(indices_unique, 0, nse, axis=props.n_batch)
return data_unique, indices_unique
def _bcoo_sum_duplicates_unbatched(data, indices, *, shape, nse):
props = _validate_bcoo(data, indices, shape)
if not props.n_sparse:
nse = 1 if nse is None else nse
data_unique = jnp.zeros_like(data, shape=(nse, *data.shape[1:])).at[0].set(data.sum(0))
indices_unique = jnp.zeros_like(indices, shape=(nse, 0))
return data_unique, indices_unique, nse
if nse is None:
indices_unique, inv_idx, nse = _unique(
indices, axis=0, return_inverse=True, return_true_size=True,
size=props.nse, fill_value=jnp.array(shape[:props.n_sparse]))
else:
indices_unique, inv_idx = jnp.unique(
indices, axis=0, return_inverse=True, size=nse,
fill_value=jnp.array(shape[:props.n_sparse]))
data_shape = [indices_unique.shape[0], *data.shape[1:]]
data_unique = jnp.zeros(data_shape, data.dtype).at[inv_idx].add(data)
oob_mask = jnp.all(indices_unique == jnp.array(shape[:props.n_sparse]), 1)
data_unique = jnp.where(oob_mask[(...,) + props.n_dense * (None,)], 0, data_unique)
return data_unique, indices_unique, nse
def _unbatch_bcoo(data, indices, shape):
n_batch = _validate_bcoo(data, indices, shape).n_batch
if n_batch == 0:

View File

@ -1159,7 +1159,7 @@ class BCOOTest(jtu.JaxTestCase):
def test_bcoo_sum_duplicates(self, shape, dtype, n_batch, n_dense, nse):
rng = self.rng()
rng_sparse = rand_sparse(self.rng())
M = sparse.BCOO.fromdense(rng_sparse(shape, dtype))
M = sparse.BCOO.fromdense(rng_sparse(shape, dtype), n_batch=n_batch, n_dense=n_dense)
for i, s in enumerate(shape[n_batch:len(shape) - n_dense]):
M.indices = M.indices.at[..., i, :].set(rng.randint(0, s, size=M.indices.shape[-1]))
dedupe = partial(M.sum_duplicates, nse=nse)