From 038798ed258bf0058288a0d981071ce144e085ac Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 1 Feb 2023 16:16:14 -0800 Subject: [PATCH] [sparse] add support for simple 1D convolutions --- jax/experimental/sparse/__init__.py | 13 +++++--- jax/experimental/sparse/bcoo.py | 49 ++++++++++++++++++++++++++++ jax/experimental/sparse/transform.py | 1 + tests/BUILD | 1 + tests/sparsify_test.py | 19 +++++++++++ 5 files changed, 78 insertions(+), 5 deletions(-) diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py index b24b4b949..2c291ae80 100644 --- a/jax/experimental/sparse/__init__.py +++ b/jax/experimental/sparse/__init__.py @@ -130,11 +130,13 @@ Support for :func:`sparsify` includes a large number of the most common primitiv - generalized (batched) matrix products & einstein summations (:obj:`~jax.lax.dot_general_p`) - zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`, :obj:`~jax.lax.mul_p`, etc.) - zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`, :obj:`jax.lax.neg_p`, etc.) -- summation reductions (:obj:`lax.reduce_sum_p`) -- general indexing operations (:obj:`lax.slice_p`, `lax.dynamic_slice_p`, `lax.gather_p`) -- concatenation and stacking (:obj:`lax.concatenate_p`) -- transposition & reshaping ((:obj:`~jax.lax.transpose_p`, :obj:`lax.reshape_p`, :obj:`lax.squeeze_p`) -- some higher-order functions (:obj:`lax.cond_p`, :obj:`lax.while_p`, :obj:`lax.scan_p`) +- summation reductions (:obj:`~jax.lax.reduce_sum_p`) +- general indexing operations (:obj:`~jax.lax.slice_p`, `lax.dynamic_slice_p`, `lax.gather_p`) +- concatenation and stacking (:obj:`~jax.lax.concatenate_p`) +- transposition & reshaping ((:obj:`~jax.lax.transpose_p`, :obj:`~jax.lax.reshape_p`, + :obj:`~jax.lax.squeeze_p`, :obj:`~jax.lax.broadcast_in_dim_p`) +- some higher-order functions (:obj:`~jax.lax.cond_p`, :obj:`~jax.lax.while_p`, :obj:`~jax.lax.scan_p`) +- some simple 1D convolutions (:obj:`~jax.lax.conv_general_dilated_p`) Nearly any :mod:`jax.numpy` function that lowers to these supported primitives can be used within a sparsify transform to operate on sparse arrays. This set of primitives is enough @@ -199,6 +201,7 @@ from jax.experimental.sparse.ad import ( from jax.experimental.sparse.bcoo import ( bcoo_broadcast_in_dim as bcoo_broadcast_in_dim, bcoo_concatenate as bcoo_concatenate, + bcoo_conv_general_dilated as bcoo_conv_general_dilated, bcoo_dot_general as bcoo_dot_general, bcoo_dot_general_p as bcoo_dot_general_p, bcoo_dot_general_sampled as bcoo_dot_general_sampled, diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 2263f498a..903f4d917 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -2350,6 +2350,55 @@ def bcoo_gather(operand: BCOO, start_indices: Array, return result.reshape(out_aval.shape).astype(out_aval.dtype) +def bcoo_conv_general_dilated(lhs, rhs, *, window_strides, padding, lhs_dilation, + rhs_dilation, dimension_numbers, feature_group_count, + batch_group_count, precision, preferred_element_type): + # So far, we support just simple padded convolutions. + if not (isinstance(lhs, BCOO) and isinstance(rhs, jax.Array)): + raise NotImplementedError("bcoo_conv_general_dilated only implemented for sparse lhs and dense rhs. " + f"got {type(lhs)=} and {type(rhs)=}") + # Validate inputs using lax.conv_general_dilated abstract evaluation. + out_aval = jax.eval_shape( + functools.partial(lax.conv_general_dilated, window_strides=window_strides, padding=padding, + lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, batch_group_count=batch_group_count, + precision=precision, preferred_element_type=preferred_element_type), + jax.ShapeDtypeStruct(lhs.shape, lhs.dtype), jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)) + if lhs_dilation != (1,) * (lhs.ndim - 2) or rhs_dilation != (1,) * (rhs.ndim - 2): + raise NotImplementedError("bcoo convolution with non-unit dilation.") + if window_strides != (1,) * (lhs.ndim - 2): + raise NotImplementedError("bcoo convolution with non-unit window_strides.") + if batch_group_count != 1 or feature_group_count != 1: + raise NotImplementedError("bcoo convolution with non-unit group counts.") + if lhs.shape[:2] != rhs.shape[:2] != (1, 1): + raise NotImplementedError("bcoo convolution with leading dimensions other than (1, 1)") + del precision, preferred_element_type # unused + + lhs = bcoo_squeeze(lhs, dimensions=(0, 1)) + rhs = lax.squeeze(rhs, dimensions=(0, 1)) + + if lhs.ndim != rhs.ndim != 1: + raise NotImplementedError("only 1-dimensional convoutions are implemented") + if lhs.n_batch != lhs.n_dense != 0: + raise NotImplementedError("bcoo convolution with batch or dense dimensions.") + + padding, = padding + assert len(padding) == 2 + + new_data = (lhs.data[:, None] * rhs[None, :]).astype(out_aval.dtype) + new_data = new_data.ravel() + + offset = jnp.arange(len(rhs), dtype=lhs.indices.dtype)[::-1] + 1 + padding[0] - len(rhs) + new_indices = lhs.indices[:, None, :] + offset[None, :, None] + new_indices = new_indices.reshape(lhs.indices.shape[0] * len(rhs), lhs.indices.shape[1]) + + mask = (new_indices < 0).any(1) + new_indices = jnp.where(mask[:, None], 0, new_indices) + new_data = jnp.where(mask, 0, new_data) + + out = BCOO((new_data, new_indices), shape=(lhs.shape[0] + padding[0] + padding[1] - len(rhs) + 1,)) + return bcoo_broadcast_in_dim(out, shape=(1, 1, *out.shape), broadcast_dimensions=range(2, 2 + out.ndim)) + @tree_util.register_pytree_node_class class BCOO(JAXSparse): """Experimental batched COO matrix implemented in JAX diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 75155eb69..b2b572dff 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -504,6 +504,7 @@ def _standard_sparse_rule(prim, sparse_op): _BCOO_STANDARD_PRIMITIVES = { lax.broadcast_in_dim_p: sparse.bcoo_broadcast_in_dim, lax.concatenate_p: lambda *a, **k: sparse.bcoo_concatenate(a, **k), + lax.conv_general_dilated_p: sparse.bcoo_conv_general_dilated, lax.dot_general_p: sparse.bcoo_dot_general, lax.dynamic_slice_p: lambda *a, **k: sparse.bcoo_dynamic_slice(a[0], a[1:], **k), lax.reshape_p: sparse.bcoo_reshape, diff --git a/tests/BUILD b/tests/BUILD index 478b6d89d..fed0fd025 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -761,6 +761,7 @@ jax_test( }, deps = [ "//jax:experimental_sparse", + "//jax:sparse_test_util", ], ) diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 8668f2c79..9bd30147d 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -28,6 +28,7 @@ from jax.experimental.sparse import BCOO, sparsify, todense, SparseTracer from jax.experimental.sparse.transform import ( arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv) from jax.experimental.sparse.util import CuSparseEfficiencyWarning +from jax.experimental.sparse import test_util as sptu config.parse_flags_with_absl() @@ -334,6 +335,24 @@ class SparsifyTest(jtu.JaxTestCase): sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs] self.assertArraysEqual(f(arrs), f(sparrs).todense()) + @jtu.sample_product( + lhs_shape=[(5,), (10,), (15,)], + rhs_shape=[(3,), (4,), (5,)], + mode=['same', 'valid', 'full'], + dtype=jtu.dtypes.numeric, + ) + def testSparseConvolve(self, lhs_shape, rhs_shape, mode, dtype): + f = self.sparsify(partial(jnp.convolve, mode=mode, precision='highest')) + sprng = sptu.rand_bcoo(self.rng(), n_batch=0, n_dense=0) + rng = jtu.rand_default(self.rng()) + + lhs_sp = sprng(lhs_shape, dtype) + lhs = lhs_sp.todense() + rhs = rng(rhs_shape, dtype) + + tol = {np.float32: 1E-5, np.complex64: 1E-5, np.float64: 1E-14, np.complex128: 1E-14} + self.assertAllClose(f(lhs, rhs), f(lhs_sp, rhs).todense(), atol=tol, rtol=tol) + def testSparseReshapeMethod(self): # Note: this is more fully tested in sparse_test.py:test_bcoo_reshape shape = (2, 3, 4)