mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[sparse] add support for simple 1D convolutions
This commit is contained in:
parent
4d56def91f
commit
038798ed25
@ -130,11 +130,13 @@ Support for :func:`sparsify` includes a large number of the most common primitiv
|
|||||||
- generalized (batched) matrix products & einstein summations (:obj:`~jax.lax.dot_general_p`)
|
- generalized (batched) matrix products & einstein summations (:obj:`~jax.lax.dot_general_p`)
|
||||||
- zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`, :obj:`~jax.lax.mul_p`, etc.)
|
- zero-preserving elementwise binary operations (e.g. :obj:`~jax.lax.add_p`, :obj:`~jax.lax.mul_p`, etc.)
|
||||||
- zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`, :obj:`jax.lax.neg_p`, etc.)
|
- zero-preserving elementwise unary operations (e.g. :obj:`~jax.lax.abs_p`, :obj:`jax.lax.neg_p`, etc.)
|
||||||
- summation reductions (:obj:`lax.reduce_sum_p`)
|
- summation reductions (:obj:`~jax.lax.reduce_sum_p`)
|
||||||
- general indexing operations (:obj:`lax.slice_p`, `lax.dynamic_slice_p`, `lax.gather_p`)
|
- general indexing operations (:obj:`~jax.lax.slice_p`, `lax.dynamic_slice_p`, `lax.gather_p`)
|
||||||
- concatenation and stacking (:obj:`lax.concatenate_p`)
|
- concatenation and stacking (:obj:`~jax.lax.concatenate_p`)
|
||||||
- transposition & reshaping ((:obj:`~jax.lax.transpose_p`, :obj:`lax.reshape_p`, :obj:`lax.squeeze_p`)
|
- transposition & reshaping ((:obj:`~jax.lax.transpose_p`, :obj:`~jax.lax.reshape_p`,
|
||||||
- some higher-order functions (:obj:`lax.cond_p`, :obj:`lax.while_p`, :obj:`lax.scan_p`)
|
:obj:`~jax.lax.squeeze_p`, :obj:`~jax.lax.broadcast_in_dim_p`)
|
||||||
|
- some higher-order functions (:obj:`~jax.lax.cond_p`, :obj:`~jax.lax.while_p`, :obj:`~jax.lax.scan_p`)
|
||||||
|
- some simple 1D convolutions (:obj:`~jax.lax.conv_general_dilated_p`)
|
||||||
|
|
||||||
Nearly any :mod:`jax.numpy` function that lowers to these supported primitives can be used
|
Nearly any :mod:`jax.numpy` function that lowers to these supported primitives can be used
|
||||||
within a sparsify transform to operate on sparse arrays. This set of primitives is enough
|
within a sparsify transform to operate on sparse arrays. This set of primitives is enough
|
||||||
@ -199,6 +201,7 @@ from jax.experimental.sparse.ad import (
|
|||||||
from jax.experimental.sparse.bcoo import (
|
from jax.experimental.sparse.bcoo import (
|
||||||
bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
|
bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
|
||||||
bcoo_concatenate as bcoo_concatenate,
|
bcoo_concatenate as bcoo_concatenate,
|
||||||
|
bcoo_conv_general_dilated as bcoo_conv_general_dilated,
|
||||||
bcoo_dot_general as bcoo_dot_general,
|
bcoo_dot_general as bcoo_dot_general,
|
||||||
bcoo_dot_general_p as bcoo_dot_general_p,
|
bcoo_dot_general_p as bcoo_dot_general_p,
|
||||||
bcoo_dot_general_sampled as bcoo_dot_general_sampled,
|
bcoo_dot_general_sampled as bcoo_dot_general_sampled,
|
||||||
|
@ -2350,6 +2350,55 @@ def bcoo_gather(operand: BCOO, start_indices: Array,
|
|||||||
|
|
||||||
return result.reshape(out_aval.shape).astype(out_aval.dtype)
|
return result.reshape(out_aval.shape).astype(out_aval.dtype)
|
||||||
|
|
||||||
|
def bcoo_conv_general_dilated(lhs, rhs, *, window_strides, padding, lhs_dilation,
|
||||||
|
rhs_dilation, dimension_numbers, feature_group_count,
|
||||||
|
batch_group_count, precision, preferred_element_type):
|
||||||
|
# So far, we support just simple padded convolutions.
|
||||||
|
if not (isinstance(lhs, BCOO) and isinstance(rhs, jax.Array)):
|
||||||
|
raise NotImplementedError("bcoo_conv_general_dilated only implemented for sparse lhs and dense rhs. "
|
||||||
|
f"got {type(lhs)=} and {type(rhs)=}")
|
||||||
|
# Validate inputs using lax.conv_general_dilated abstract evaluation.
|
||||||
|
out_aval = jax.eval_shape(
|
||||||
|
functools.partial(lax.conv_general_dilated, window_strides=window_strides, padding=padding,
|
||||||
|
lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
|
||||||
|
feature_group_count=feature_group_count, batch_group_count=batch_group_count,
|
||||||
|
precision=precision, preferred_element_type=preferred_element_type),
|
||||||
|
jax.ShapeDtypeStruct(lhs.shape, lhs.dtype), jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
|
||||||
|
if lhs_dilation != (1,) * (lhs.ndim - 2) or rhs_dilation != (1,) * (rhs.ndim - 2):
|
||||||
|
raise NotImplementedError("bcoo convolution with non-unit dilation.")
|
||||||
|
if window_strides != (1,) * (lhs.ndim - 2):
|
||||||
|
raise NotImplementedError("bcoo convolution with non-unit window_strides.")
|
||||||
|
if batch_group_count != 1 or feature_group_count != 1:
|
||||||
|
raise NotImplementedError("bcoo convolution with non-unit group counts.")
|
||||||
|
if lhs.shape[:2] != rhs.shape[:2] != (1, 1):
|
||||||
|
raise NotImplementedError("bcoo convolution with leading dimensions other than (1, 1)")
|
||||||
|
del precision, preferred_element_type # unused
|
||||||
|
|
||||||
|
lhs = bcoo_squeeze(lhs, dimensions=(0, 1))
|
||||||
|
rhs = lax.squeeze(rhs, dimensions=(0, 1))
|
||||||
|
|
||||||
|
if lhs.ndim != rhs.ndim != 1:
|
||||||
|
raise NotImplementedError("only 1-dimensional convoutions are implemented")
|
||||||
|
if lhs.n_batch != lhs.n_dense != 0:
|
||||||
|
raise NotImplementedError("bcoo convolution with batch or dense dimensions.")
|
||||||
|
|
||||||
|
padding, = padding
|
||||||
|
assert len(padding) == 2
|
||||||
|
|
||||||
|
new_data = (lhs.data[:, None] * rhs[None, :]).astype(out_aval.dtype)
|
||||||
|
new_data = new_data.ravel()
|
||||||
|
|
||||||
|
offset = jnp.arange(len(rhs), dtype=lhs.indices.dtype)[::-1] + 1 + padding[0] - len(rhs)
|
||||||
|
new_indices = lhs.indices[:, None, :] + offset[None, :, None]
|
||||||
|
new_indices = new_indices.reshape(lhs.indices.shape[0] * len(rhs), lhs.indices.shape[1])
|
||||||
|
|
||||||
|
mask = (new_indices < 0).any(1)
|
||||||
|
new_indices = jnp.where(mask[:, None], 0, new_indices)
|
||||||
|
new_data = jnp.where(mask, 0, new_data)
|
||||||
|
|
||||||
|
out = BCOO((new_data, new_indices), shape=(lhs.shape[0] + padding[0] + padding[1] - len(rhs) + 1,))
|
||||||
|
return bcoo_broadcast_in_dim(out, shape=(1, 1, *out.shape), broadcast_dimensions=range(2, 2 + out.ndim))
|
||||||
|
|
||||||
@tree_util.register_pytree_node_class
|
@tree_util.register_pytree_node_class
|
||||||
class BCOO(JAXSparse):
|
class BCOO(JAXSparse):
|
||||||
"""Experimental batched COO matrix implemented in JAX
|
"""Experimental batched COO matrix implemented in JAX
|
||||||
|
@ -504,6 +504,7 @@ def _standard_sparse_rule(prim, sparse_op):
|
|||||||
_BCOO_STANDARD_PRIMITIVES = {
|
_BCOO_STANDARD_PRIMITIVES = {
|
||||||
lax.broadcast_in_dim_p: sparse.bcoo_broadcast_in_dim,
|
lax.broadcast_in_dim_p: sparse.bcoo_broadcast_in_dim,
|
||||||
lax.concatenate_p: lambda *a, **k: sparse.bcoo_concatenate(a, **k),
|
lax.concatenate_p: lambda *a, **k: sparse.bcoo_concatenate(a, **k),
|
||||||
|
lax.conv_general_dilated_p: sparse.bcoo_conv_general_dilated,
|
||||||
lax.dot_general_p: sparse.bcoo_dot_general,
|
lax.dot_general_p: sparse.bcoo_dot_general,
|
||||||
lax.dynamic_slice_p: lambda *a, **k: sparse.bcoo_dynamic_slice(a[0], a[1:], **k),
|
lax.dynamic_slice_p: lambda *a, **k: sparse.bcoo_dynamic_slice(a[0], a[1:], **k),
|
||||||
lax.reshape_p: sparse.bcoo_reshape,
|
lax.reshape_p: sparse.bcoo_reshape,
|
||||||
|
@ -761,6 +761,7 @@ jax_test(
|
|||||||
},
|
},
|
||||||
deps = [
|
deps = [
|
||||||
"//jax:experimental_sparse",
|
"//jax:experimental_sparse",
|
||||||
|
"//jax:sparse_test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ from jax.experimental.sparse import BCOO, sparsify, todense, SparseTracer
|
|||||||
from jax.experimental.sparse.transform import (
|
from jax.experimental.sparse.transform import (
|
||||||
arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv)
|
arrays_to_spvalues, spvalues_to_arrays, sparsify_raw, SparsifyValue, SparsifyEnv)
|
||||||
from jax.experimental.sparse.util import CuSparseEfficiencyWarning
|
from jax.experimental.sparse.util import CuSparseEfficiencyWarning
|
||||||
|
from jax.experimental.sparse import test_util as sptu
|
||||||
|
|
||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
|
|
||||||
@ -334,6 +335,24 @@ class SparsifyTest(jtu.JaxTestCase):
|
|||||||
sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
|
sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
|
||||||
self.assertArraysEqual(f(arrs), f(sparrs).todense())
|
self.assertArraysEqual(f(arrs), f(sparrs).todense())
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
lhs_shape=[(5,), (10,), (15,)],
|
||||||
|
rhs_shape=[(3,), (4,), (5,)],
|
||||||
|
mode=['same', 'valid', 'full'],
|
||||||
|
dtype=jtu.dtypes.numeric,
|
||||||
|
)
|
||||||
|
def testSparseConvolve(self, lhs_shape, rhs_shape, mode, dtype):
|
||||||
|
f = self.sparsify(partial(jnp.convolve, mode=mode, precision='highest'))
|
||||||
|
sprng = sptu.rand_bcoo(self.rng(), n_batch=0, n_dense=0)
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
|
||||||
|
lhs_sp = sprng(lhs_shape, dtype)
|
||||||
|
lhs = lhs_sp.todense()
|
||||||
|
rhs = rng(rhs_shape, dtype)
|
||||||
|
|
||||||
|
tol = {np.float32: 1E-5, np.complex64: 1E-5, np.float64: 1E-14, np.complex128: 1E-14}
|
||||||
|
self.assertAllClose(f(lhs, rhs), f(lhs_sp, rhs).todense(), atol=tol, rtol=tol)
|
||||||
|
|
||||||
def testSparseReshapeMethod(self):
|
def testSparseReshapeMethod(self):
|
||||||
# Note: this is more fully tested in sparse_test.py:test_bcoo_reshape
|
# Note: this is more fully tested in sparse_test.py:test_bcoo_reshape
|
||||||
shape = (2, 3, 4)
|
shape = (2, 3, 4)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user