mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add lax.conv_general_dilated_local
This commit is contained in:
parent b42e9e3789
commit bc84c9fe8f
@@ -53,6 +53,7 @@ Operators
    conv
    convert_element_type
    conv_general_dilated
    conv_general_dilated_local
    conv_general_dilated_patches
    conv_with_general_padding
    conv_transpose
@@ -113,3 +113,115 @@ def conv_general_dilated_patches(
      preferred_element_type=preferred_element_type
  )
  return out


def conv_general_dilated_local(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    filter_shape: Sequence[int],
    lhs_dilation: Sequence[int] = None,
    rhs_dilation: Sequence[int] = None,
    dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
    precision: lax.PrecisionLike = None
) -> jnp.ndarray:
  """General n-dimensional unshared convolution operator with optional dilation.

  Also known as a locally connected layer, the operation is equivalent to
  convolution with a separate (unshared) `rhs` kernel used at each output
  spatial location. Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights. Unlike in regular
      CNNs, its spatial coordinates (`H`, `W`, ...) correspond to output spatial
      locations, while input spatial locations are fused with the input channel
      locations in the single `I` dimension, in the order of
      `"C" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
      `rhs_spec = dimension_numbers[1]`. For example, if `rhs_spec == "WHIO"`,
      the unfolded kernel shape is
      `"[output W][output H]{I[receptive window W][receptive window H]}O"`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each input spatial dimension of `rhs`.
      RHS dilation is also known as atrous convolution.
    dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
      a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

  Returns:
    An array containing the unshared convolution result.

  In the string case of `dimension_numbers`, each character identifies by
  position:

  - the batch dimensions in `lhs`, `rhs`, and the output with the character
    'N',
  - the feature dimensions in `lhs` and the output with the character 'C',
  - the input and output feature dimensions in rhs with the characters 'I'
    and 'O' respectively, and
  - spatial dimension correspondences between `lhs`, `rhs`, and the output using
    any distinct characters.

  For example, to indicate dimension numbers consistent with the `conv` function
  with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
  another example, to indicate dimension numbers consistent with the TensorFlow
  Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
  latter form of convolution dimension specification, window strides are
  associated with spatial dimension character labels according to the order in
  which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
  is matched with the dimension corresponding to the first character
  appearing in rhs_spec that is not `'I'` or `'O'`.

  If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
  (for a 2D convolution).
  """
  lhs_precision = (precision[0]
                   if (isinstance(precision, tuple) and len(precision) == 2)
                   else precision)

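  # Unfold `lhs` into patches: each output spatial location gets its receptive
  # window flattened into the channel dimension, matching the fused `I` layout
  # of the local `rhs` kernel.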
  patches = conv_general_dilated_patches(
      lhs=lhs,
      filter_shape=filter_shape,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=lhs_precision
  )

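  # Resolve the dimension specs; a dummy (O=1, I=1) kernel shape suffices,
  # since only the positions of the batch/feature/spatial axes are needed.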
  lhs_spec, rhs_spec, out_spec = lax.conv_dimension_numbers(
      lhs.shape, (1, 1) + tuple(filter_shape), dimension_numbers)

  lhs_c_dims, rhs_c_dims = [out_spec[1]], [rhs_spec[1]]

  lhs_b_dims = out_spec[2:]
  rhs_b_dims = rhs_spec[2:]

  rhs_b_dims = [rhs_b_dims[i] for i in sorted(range(len(rhs_b_dims)),
                                              key=lambda k: lhs_b_dims[k])]
  lhs_b_dims = sorted(lhs_b_dims)

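  # Contract the fused channel/window axis of `patches` with the `I` axis of
  # `rhs`, batching over the output spatial axes so every spatial location uses
  # its own (unshared) kernel. dot_general returns (spatial..., N, O), so N and
  # O are then moved back to their positions in the output layout.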
  dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims))
  out = lax.dot_general(patches, rhs, dimension_numbers=dn, precision=precision)
  out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
  return out

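For illustration (not part of the commit), a minimal usage sketch of the new function under the default ('NCHW', 'OIHW', 'NCHW') dimension numbers; the shapes below are assumptions chosen for the example.

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((8, 3, 32, 32))   # batch=8, channels=3, 32x32 inputs
# With 3x3 windows, unit strides and VALID padding the output is 30x30.
# The unshared kernel holds one flattened 3*3*3 = 27-element filter per
# output location: shape (O, C*KH*KW, outH, outW) under the 'OIHW' spec.
rhs = jnp.ones((16, 27, 30, 30))
out = lax.conv_general_dilated_local(
    lhs, rhs,
    window_strides=(1, 1),
    padding='VALID',
    filter_shape=(3, 3))
print(out.shape)                 # (8, 16, 30, 30)
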
@@ -352,6 +352,7 @@ from jax._src.lax.parallel import (
  xeinsum,
)
from jax._src.lax.other import (
  conv_general_dilated_local,
  conv_general_dilated_patches
)
from . import linalg
@@ -753,6 +753,86 @@ class LaxTest(jtu.JaxTestCase):
        [out_spec.index(c) for c in out_spec if c not in ('N', 'C')])
    self.assertAllClose(out, patches)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          "testcase_name":
              f"_dtype={dtype}_precision={precision}_n={n}_{padding}"
              f"_dn={lhs_spec, rhs_spec, out_spec}",
          "dtype": dtype,
          "rng_factory": rng_factory,
          "precision": precision,
          "n": n,
          "padding": padding,
          "lhs_spec": lhs_spec,
          "rhs_spec": rhs_spec,
          "out_spec": out_spec
      }
      for dtype in inexact_dtypes
      for rng_factory in [jtu.rand_small]
      for precision in [None,
                        lax.Precision.DEFAULT,
                        lax.Precision.HIGH,
                        lax.Precision.HIGHEST,
                        (lax.Precision.DEFAULT,
                         lax.Precision.HIGHEST)]
      for n in [1, 2]
      for padding in ['SAME', 'VALID']
      for lhs_spec in [''.join(s)
                       for s in itertools.permutations('NCHWD'[:n + 2])]
      for rhs_spec in [''.join(s)
                       for s in itertools.permutations('OIHWDX'[:n + 2])]
      for out_spec in [''.join(s)
                       for s in itertools.permutations('NCHWDX'[:n + 2])]))
  def testConvGeneralDilatedLocal(self, dtype, rng_factory, precision, n,
                                  padding, lhs_spec, rhs_spec, out_spec):
    """Make sure LCN with tiled CNN kernel matches CNN."""
    lhs_spec_default = 'NCHWDX'[:n + 2]
    rhs_spec_default = 'OIHWDX'[:n + 2]

    rng = rng_factory(self.rng())

    lhs_default = rng((2, 4, 7, 6, 5, 8)[:n + 2], dtype)
    rhs_default = rng((5, 4, 2, 3, 1, 2)[:n + 2], dtype)

    window_strides = (1, 2, 3, 4)[:n]
    rhs_dilation = (2, 1, 3, 2)[:n]

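    # Permute the default-layout operands into the parameterized lhs/rhs
    # dimension orders.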
    lhs_perm = [lhs_spec_default.index(c) for c in lhs_spec]
    lhs = np.transpose(lhs_default, lhs_perm)

    rhs_perm = [rhs_spec_default.index(c) for c in rhs_spec]
    rhs = np.transpose(rhs_default, rhs_perm)

    kwargs = dict(
        lhs=lhs,
        window_strides=window_strides,
        padding=padding,
        rhs_dilation=rhs_dilation,
        dimension_numbers=(lhs_spec, rhs_spec, out_spec),
        precision=precision
    )

    out_conv = lax.conv_general_dilated(rhs=rhs, **kwargs)

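    # Tile the shared CNN kernel across every output spatial location; the
    # locally connected result must then match the ordinary convolution.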
    rhs_local = np.moveaxis(rhs, (rhs_spec.index('O'), rhs_spec.index('I')),
                            (0, 1))
    rhs_local = rhs_local.reshape((rhs_local.shape[0], -1) + (1,) * n)

    rhs_shape = (rhs_local.shape[:2] +
                 tuple(out_conv.shape[out_spec.index(c)]
                       for c in rhs_spec_default[2:]))

    rhs_local = np.broadcast_to(rhs_local, rhs_shape)
    rhs_local = np.transpose(rhs_local, rhs_perm)

    filter_shape = [rhs.shape[i]
                    for i in range(n + 2) if rhs_spec[i] not in ('O', 'I')]
    out_local = lax.conv_general_dilated_local(rhs=rhs_local,
                                               filter_shape=filter_shape,
                                               **kwargs)

    self.assertAllClose(out_conv, out_local)

  # TODO(mattjj): test conv_general_dilated against numpy

  def testConv0DIsDot(self):