[sparse] add support for simple 1D convolutions
This commit is contained in:
parent 4d56def91f
commit 038798ed25
@@ -130,11 +130,13 @@ Support for :func:`sparsify` includes a large number of the most common primitiv
 - generalized (batched) matrix products & einstein summations (:obj:`~jax.lax.dot_general_p`)
 - zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`, :obj:`~jax.lax.mul_p`, etc.)
 - zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`, :obj:`jax.lax.neg_p`, etc.)
-- summation reductions (:obj:`lax.reduce_sum_p`)
-- general indexing operations (:obj:`lax.slice_p`, `lax.dynamic_slice_p`, `lax.gather_p`)
-- concatenation and stacking (:obj:`lax.concatenate_p`)
-- transposition & reshaping ((:obj:`~jax.lax.transpose_p`, :obj:`lax.reshape_p`, :obj:`lax.squeeze_p`)
-- some higher-order functions (:obj:`lax.cond_p`, :obj:`lax.while_p`, :obj:`lax.scan_p`)
+- summation reductions (:obj:`~jax.lax.reduce_sum_p`)
+- general indexing operations (:obj:`~jax.lax.slice_p`, `lax.dynamic_slice_p`, `lax.gather_p`)
+- concatenation and stacking (:obj:`~jax.lax.concatenate_p`)
+- transposition & reshaping ((:obj:`~jax.lax.transpose_p`, :obj:`~jax.lax.reshape_p`,
+  :obj:`~jax.lax.squeeze_p`, :obj:`~jax.lax.broadcast_in_dim_p`)
+- some higher-order functions (:obj:`~jax.lax.cond_p`, :obj:`~jax.lax.while_p`, :obj:`~jax.lax.scan_p`)
+- some simple 1D convolutions (:obj:`~jax.lax.conv_general_dilated_p`)
 
 Nearly any :mod:`jax.numpy` function that lowers to these supported primitives can be used
 within a sparsify transform to operate on sparse arrays. This set of primitives is enough
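For context, a minimal sketch of what this documentation change describes, runnable against this commit (the signal, kernel, and mode='valid' choice below are illustrative, not taken from the patch):

import jax.numpy as jnp
from jax.experimental import sparse

# A mostly-zero 1D signal stored as BCOO, and a small dense kernel.
x = jnp.zeros(10).at[jnp.array([2, 7])].set(jnp.array([1.0, -3.0]))
x_sp = sparse.BCOO.fromdense(x)
kernel = jnp.array([1.0, 2.0, 1.0])

# sparsify rewrites the conv_general_dilated primitive that jnp.convolve
# lowers to, using the new sparse-lhs / dense-rhs rule added in this commit.
conv = sparse.sparsify(lambda a, b: jnp.convolve(a, b, mode='valid'))
out = conv(x_sp, kernel)                     # result is a BCOO array
assert jnp.allclose(out.todense(), jnp.convolve(x, kernel, mode='valid'))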
@@ -199,6 +201,7 @@ from jax.experimental.sparse.ad import (
 from jax.experimental.sparse.bcoo import (
     bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
     bcoo_concatenate as bcoo_concatenate,
+    bcoo_conv_general_dilated as bcoo_conv_general_dilated,
     bcoo_dot_general as bcoo_dot_general,
     bcoo_dot_general_p as bcoo_dot_general_p,
     bcoo_dot_general_sampled as bcoo_dot_general_sampled,
@@ -2350,6 +2350,55 @@ def bcoo_gather(operand: BCOO, start_indices: Array,
   return result.reshape(out_aval.shape).astype(out_aval.dtype)
 
+def bcoo_conv_general_dilated(lhs, rhs, *, window_strides, padding, lhs_dilation,
+                              rhs_dilation, dimension_numbers, feature_group_count,
+                              batch_group_count, precision, preferred_element_type):
+  # So far, we support just simple padded convolutions.
+  if not (isinstance(lhs, BCOO) and isinstance(rhs, jax.Array)):
+    raise NotImplementedError("bcoo_conv_general_dilated only implemented for sparse lhs and dense rhs. "
+                              f"got {type(lhs)=} and {type(rhs)=}")
+  # Validate inputs using lax.conv_general_dilated abstract evaluation.
+  out_aval = jax.eval_shape(
+    functools.partial(lax.conv_general_dilated, window_strides=window_strides, padding=padding,
+                      lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
+                      feature_group_count=feature_group_count, batch_group_count=batch_group_count,
+                      precision=precision, preferred_element_type=preferred_element_type),
+    jax.ShapeDtypeStruct(lhs.shape, lhs.dtype), jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
+  if lhs_dilation != (1,) * (lhs.ndim - 2) or rhs_dilation != (1,) * (rhs.ndim - 2):
+    raise NotImplementedError("bcoo convolution with non-unit dilation.")
+  if window_strides != (1,) * (lhs.ndim - 2):
+    raise NotImplementedError("bcoo convolution with non-unit window_strides.")
+  if batch_group_count != 1 or feature_group_count != 1:
+    raise NotImplementedError("bcoo convolution with non-unit group counts.")
+  if lhs.shape[:2] != (1, 1) or rhs.shape[:2] != (1, 1):
+    raise NotImplementedError("bcoo convolution with leading dimensions other than (1, 1)")
+  del precision, preferred_element_type  # unused
+
+  lhs = bcoo_squeeze(lhs, dimensions=(0, 1))
+  rhs = lax.squeeze(rhs, dimensions=(0, 1))
+
+  if lhs.ndim != 1 or rhs.ndim != 1:
+    raise NotImplementedError("only 1-dimensional convolutions are implemented")
+  if lhs.n_batch != 0 or lhs.n_dense != 0:
+    raise NotImplementedError("bcoo convolution with batch or dense dimensions.")
+
+  padding, = padding
+  assert len(padding) == 2
+
+  new_data = (lhs.data[:, None] * rhs[None, :]).astype(out_aval.dtype)
+  new_data = new_data.ravel()
+
+  offset = jnp.arange(len(rhs), dtype=lhs.indices.dtype)[::-1] + 1 + padding[0] - len(rhs)
+  new_indices = lhs.indices[:, None, :] + offset[None, :, None]
+  new_indices = new_indices.reshape(lhs.indices.shape[0] * len(rhs), lhs.indices.shape[1])
+
+  mask = (new_indices < 0).any(1)
+  new_indices = jnp.where(mask[:, None], 0, new_indices)
+  new_data = jnp.where(mask, 0, new_data)
+
+  out = BCOO((new_data, new_indices), shape=(lhs.shape[0] + padding[0] + padding[1] - len(rhs) + 1,))
+  return bcoo_broadcast_in_dim(out, shape=(1, 1, *out.shape), broadcast_dimensions=range(2, 2 + out.ndim))
+
 
 @tree_util.register_pytree_node_class
 class BCOO(JAXSparse):
   """Experimental batched COO matrix implemented in JAX
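A hand-worked sanity check of the index construction above (illustrative only, not part of the patch): each stored lhs value at position n is fanned out into len(rhs) output entries at positions n + padding[0] - j carrying lhs[n] * rhs[j], which matches the cross-correlation that lax.conv_general_dilated computes.

import jax.numpy as jnp

# One stored value at position n = 4, kernel of length 3, padding[0] = 1.
n, value, p0 = 4, 2.0, 1
rhs = jnp.array([1.0, 10.0, 100.0])

offset = jnp.arange(len(rhs))[::-1] + 1 + p0 - len(rhs)   # simplifies to p0 - j for j = 0, 1, 2
print(n + offset)          # output positions hit: [5 4 3]
print(value * rhs)         # values deposited there: [2. 20. 200.]

# Dense cross-check: correlate the zero-padded signal with rhs directly.
lhs_padded = jnp.pad(jnp.zeros(8).at[n].set(value), (p0, p0))
print(jnp.stack([lhs_padded[m:m + len(rhs)] @ rhs for m in range(8)]))
# -> nonzero at positions 3, 4, 5 with values 200., 20., 2.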
@@ -504,6 +504,7 @@ def _standard_sparse_rule(prim, sparse_op):
 _BCOO_STANDARD_PRIMITIVES = {
   lax.broadcast_in_dim_p: sparse.bcoo_broadcast_in_dim,
   lax.concatenate_p: lambda *a, **k: sparse.bcoo_concatenate(a, **k),
+  lax.conv_general_dilated_p: sparse.bcoo_conv_general_dilated,
   lax.dot_general_p: sparse.bcoo_dot_general,
   lax.dynamic_slice_p: lambda *a, **k: sparse.bcoo_dynamic_slice(a[0], a[1:], **k),
   lax.reshape_p: sparse.bcoo_reshape,
@@ -761,6 +761,7 @@ jax_test(
     },
     deps = [
         "//jax:experimental_sparse",
+        "//jax:sparse_test_util",
     ],
 )
 
@@ -28,6 +28,7 @@ from jax.experimental.sparse import BCOO, sparsify, todense, SparseTracer
 from jax.experimental.sparse.transform import (
   arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv)
 from jax.experimental.sparse.util import CuSparseEfficiencyWarning
+from jax.experimental.sparse import test_util as sptu
 
 config.parse_flags_with_absl()
 
@@ -334,6 +335,24 @@ class SparsifyTest(jtu.JaxTestCase):
     sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
     self.assertArraysEqual(f(arrs), f(sparrs).todense())
 
+  @jtu.sample_product(
+    lhs_shape=[(5,), (10,), (15,)],
+    rhs_shape=[(3,), (4,), (5,)],
+    mode=['same', 'valid', 'full'],
+    dtype=jtu.dtypes.numeric,
+  )
+  def testSparseConvolve(self, lhs_shape, rhs_shape, mode, dtype):
+    f = self.sparsify(partial(jnp.convolve, mode=mode, precision='highest'))
+    sprng = sptu.rand_bcoo(self.rng(), n_batch=0, n_dense=0)
+    rng = jtu.rand_default(self.rng())
+
+    lhs_sp = sprng(lhs_shape, dtype)
+    lhs = lhs_sp.todense()
+    rhs = rng(rhs_shape, dtype)
+
+    tol = {np.float32: 1E-5, np.complex64: 1E-5, np.float64: 1E-14, np.complex128: 1E-14}
+    self.assertAllClose(f(lhs, rhs), f(lhs_sp, rhs).todense(), atol=tol, rtol=tol)
+
   def testSparseReshapeMethod(self):
     # Note: this is more fully tested in sparse_test.py:test_bcoo_reshape
     shape = (2, 3, 4)