[sparse] add support for simple 1D convolutions

This commit is contained in:
Jake VanderPlas 2023-02-01 16:16:14 -08:00
parent 4d56def91f
commit 038798ed25
5 changed files with 78 additions and 5 deletions

View File

@ -130,11 +130,13 @@ Support for :func:`sparsify` includes a large number of the most common primitiv
- generalized (batched) matrix products & einstein summations (:obj:`~jax.lax.dot_general_p`) - generalized (batched) matrix products & einstein summations (:obj:`~jax.lax.dot_general_p`)
- zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`, :obj:`~jax.lax.mul_p`, etc.) - zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`, :obj:`~jax.lax.mul_p`, etc.)
- zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`, :obj:`jax.lax.neg_p`, etc.) - zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`, :obj:`jax.lax.neg_p`, etc.)
- summation reductions (:obj:`lax.reduce_sum_p`) - summation reductions (:obj:`~jax.lax.reduce_sum_p`)
- general indexing operations (:obj:`lax.slice_p`, `lax.dynamic_slice_p`, `lax.gather_p`) - general indexing operations (:obj:`~jax.lax.slice_p`, `lax.dynamic_slice_p`, `lax.gather_p`)
- concatenation and stacking (:obj:`lax.concatenate_p`) - concatenation and stacking (:obj:`~jax.lax.concatenate_p`)
- transposition & reshaping ((:obj:`~jax.lax.transpose_p`, :obj:`lax.reshape_p`, :obj:`lax.squeeze_p`) - transposition & reshaping ((:obj:`~jax.lax.transpose_p`, :obj:`~jax.lax.reshape_p`,
- some higher-order functions (:obj:`lax.cond_p`, :obj:`lax.while_p`, :obj:`lax.scan_p`) :obj:`~jax.lax.squeeze_p`, :obj:`~jax.lax.broadcast_in_dim_p`)
- some higher-order functions (:obj:`~jax.lax.cond_p`, :obj:`~jax.lax.while_p`, :obj:`~jax.lax.scan_p`)
- some simple 1D convolutions (:obj:`~jax.lax.conv_general_dilated_p`)
Nearly any :mod:`jax.numpy` function that lowers to these supported primitives can be used Nearly any :mod:`jax.numpy` function that lowers to these supported primitives can be used
within a sparsify transform to operate on sparse arrays. This set of primitives is enough within a sparsify transform to operate on sparse arrays. This set of primitives is enough
@ -199,6 +201,7 @@ from jax.experimental.sparse.ad import (
from jax.experimental.sparse.bcoo import ( from jax.experimental.sparse.bcoo import (
bcoo_broadcast_in_dim as bcoo_broadcast_in_dim, bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
bcoo_concatenate as bcoo_concatenate, bcoo_concatenate as bcoo_concatenate,
bcoo_conv_general_dilated as bcoo_conv_general_dilated,
bcoo_dot_general as bcoo_dot_general, bcoo_dot_general as bcoo_dot_general,
bcoo_dot_general_p as bcoo_dot_general_p, bcoo_dot_general_p as bcoo_dot_general_p,
bcoo_dot_general_sampled as bcoo_dot_general_sampled, bcoo_dot_general_sampled as bcoo_dot_general_sampled,

View File

@ -2350,6 +2350,55 @@ def bcoo_gather(operand: BCOO, start_indices: Array,
return result.reshape(out_aval.shape).astype(out_aval.dtype) return result.reshape(out_aval.shape).astype(out_aval.dtype)
def bcoo_conv_general_dilated(lhs, rhs, *, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, precision, preferred_element_type):
# So far, we support just simple padded convolutions.
if not (isinstance(lhs, BCOO) and isinstance(rhs, jax.Array)):
raise NotImplementedError("bcoo_conv_general_dilated only implemented for sparse lhs and dense rhs. "
f"got {type(lhs)=} and {type(rhs)=}")
# Validate inputs using lax.conv_general_dilated abstract evaluation.
out_aval = jax.eval_shape(
functools.partial(lax.conv_general_dilated, window_strides=window_strides, padding=padding,
lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count, batch_group_count=batch_group_count,
precision=precision, preferred_element_type=preferred_element_type),
jax.ShapeDtypeStruct(lhs.shape, lhs.dtype), jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
if lhs_dilation != (1,) * (lhs.ndim - 2) or rhs_dilation != (1,) * (rhs.ndim - 2):
raise NotImplementedError("bcoo convolution with non-unit dilation.")
if window_strides != (1,) * (lhs.ndim - 2):
raise NotImplementedError("bcoo convolution with non-unit window_strides.")
if batch_group_count != 1 or feature_group_count != 1:
raise NotImplementedError("bcoo convolution with non-unit group counts.")
if lhs.shape[:2] != rhs.shape[:2] != (1, 1):
raise NotImplementedError("bcoo convolution with leading dimensions other than (1, 1)")
del precision, preferred_element_type # unused
lhs = bcoo_squeeze(lhs, dimensions=(0, 1))
rhs = lax.squeeze(rhs, dimensions=(0, 1))
if lhs.ndim != rhs.ndim != 1:
raise NotImplementedError("only 1-dimensional convoutions are implemented")
if lhs.n_batch != lhs.n_dense != 0:
raise NotImplementedError("bcoo convolution with batch or dense dimensions.")
padding, = padding
assert len(padding) == 2
new_data = (lhs.data[:, None] * rhs[None, :]).astype(out_aval.dtype)
new_data = new_data.ravel()
offset = jnp.arange(len(rhs), dtype=lhs.indices.dtype)[::-1] + 1 + padding[0] - len(rhs)
new_indices = lhs.indices[:, None, :] + offset[None, :, None]
new_indices = new_indices.reshape(lhs.indices.shape[0] * len(rhs), lhs.indices.shape[1])
mask = (new_indices < 0).any(1)
new_indices = jnp.where(mask[:, None], 0, new_indices)
new_data = jnp.where(mask, 0, new_data)
out = BCOO((new_data, new_indices), shape=(lhs.shape[0] + padding[0] + padding[1] - len(rhs) + 1,))
return bcoo_broadcast_in_dim(out, shape=(1, 1, *out.shape), broadcast_dimensions=range(2, 2 + out.ndim))
@tree_util.register_pytree_node_class @tree_util.register_pytree_node_class
class BCOO(JAXSparse): class BCOO(JAXSparse):
"""Experimental batched COO matrix implemented in JAX """Experimental batched COO matrix implemented in JAX

View File

@ -504,6 +504,7 @@ def _standard_sparse_rule(prim, sparse_op):
_BCOO_STANDARD_PRIMITIVES = { _BCOO_STANDARD_PRIMITIVES = {
lax.broadcast_in_dim_p: sparse.bcoo_broadcast_in_dim, lax.broadcast_in_dim_p: sparse.bcoo_broadcast_in_dim,
lax.concatenate_p: lambda *a, **k: sparse.bcoo_concatenate(a, **k), lax.concatenate_p: lambda *a, **k: sparse.bcoo_concatenate(a, **k),
lax.conv_general_dilated_p: sparse.bcoo_conv_general_dilated,
lax.dot_general_p: sparse.bcoo_dot_general, lax.dot_general_p: sparse.bcoo_dot_general,
lax.dynamic_slice_p: lambda *a, **k: sparse.bcoo_dynamic_slice(a[0], a[1:], **k), lax.dynamic_slice_p: lambda *a, **k: sparse.bcoo_dynamic_slice(a[0], a[1:], **k),
lax.reshape_p: sparse.bcoo_reshape, lax.reshape_p: sparse.bcoo_reshape,

View File

@ -761,6 +761,7 @@ jax_test(
}, },
deps = [ deps = [
"//jax:experimental_sparse", "//jax:experimental_sparse",
"//jax:sparse_test_util",
], ],
) )

View File

@ -28,6 +28,7 @@ from jax.experimental.sparse import BCOO, sparsify, todense, SparseTracer
from jax.experimental.sparse.transform import ( from jax.experimental.sparse.transform import (
arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv) arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv)
from jax.experimental.sparse.util import CuSparseEfficiencyWarning from jax.experimental.sparse.util import CuSparseEfficiencyWarning
from jax.experimental.sparse import test_util as sptu
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -334,6 +335,24 @@ class SparsifyTest(jtu.JaxTestCase):
sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs] sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
self.assertArraysEqual(f(arrs), f(sparrs).todense()) self.assertArraysEqual(f(arrs), f(sparrs).todense())
@jtu.sample_product(
lhs_shape=[(5,), (10,), (15,)],
rhs_shape=[(3,), (4,), (5,)],
mode=['same', 'valid', 'full'],
dtype=jtu.dtypes.numeric,
)
def testSparseConvolve(self, lhs_shape, rhs_shape, mode, dtype):
f = self.sparsify(partial(jnp.convolve, mode=mode, precision='highest'))
sprng = sptu.rand_bcoo(self.rng(), n_batch=0, n_dense=0)
rng = jtu.rand_default(self.rng())
lhs_sp = sprng(lhs_shape, dtype)
lhs = lhs_sp.todense()
rhs = rng(rhs_shape, dtype)
tol = {np.float32: 1E-5, np.complex64: 1E-5, np.float64: 1E-14, np.complex128: 1E-14}
self.assertAllClose(f(lhs, rhs), f(lhs_sp, rhs).todense(), atol=tol, rtol=tol)
def testSparseReshapeMethod(self): def testSparseReshapeMethod(self):
# Note: this is more fully tested in sparse_test.py:test_bcoo_reshape # Note: this is more fully tested in sparse_test.py:test_bcoo_reshape
shape = (2, 3, 4) shape = (2, 3, 4)