Add lax.conv_general_dilated_local

This commit is contained in:
Roman Novak 2021-05-13 12:20:31 -07:00
parent b42e9e3789
commit bc84c9fe8f
4 changed files with 194 additions and 0 deletions

View File

@ -53,6 +53,7 @@ Operators
conv
convert_element_type
conv_general_dilated
conv_general_dilated_local
conv_general_dilated_patches
conv_with_general_padding
conv_transpose

View File

@ -113,3 +113,115 @@ def conv_general_dilated_patches(
preferred_element_type=preferred_element_type
)
return out
def conv_general_dilated_local(
lhs: jnp.ndarray,
rhs: jnp.ndarray,
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
filter_shape: Sequence[int],
lhs_dilation: Sequence[int] = None,
rhs_dilation: Sequence[int] = None,
dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
precision: lax.PrecisionLike = None
) -> jnp.ndarray:
"""General n-dimensional unshared convolution operator with optional dilation.
Also known as locally connected layer, the operation is equivalent to
convolution with a separate (unshared) `rhs` kernel used at each output
spatial location. Docstring below adapted from `jax.lax.conv_general_dilated`.
See Also:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights. Unlike in regular
CNNs, its spatial coordinates (`H`, `W`, ...) correspond to output spatial
locations, while input spatial locations are fused with the input channel
locations in the single `I` dimension, in the order of
`"C" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
`rhs_spec = dimension_numbers[1]`. For example, if `rhs_spec == "WHIO",
the unfolded kernel shape is
`"[output W][output H]{I[receptive window W][receptive window H]}O"`.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
`n` `(low, high)` integer pairs that give the padding to apply before and
after each spatial dimension.
filter_shape: a sequence of `n` integers, representing the receptive window
spatial shape in the order as specified in
`rhs_spec = dimension_numbers[1]`.
lhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
is also known as transposed convolution.
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each input spatial dimension of `rhs`.
RHS dilation is also known as atrous convolution.
dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
of length `n+2`.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
Returns:
An array containing the unshared convolution result.
In the string case of `dimension_numbers`, each character identifies by
position:
- the batch dimensions in `lhs`, `rhs`, and the output with the character
'N',
- the feature dimensions in `lhs` and the output with the character 'C',
- the input and output feature dimensions in rhs with the characters 'I'
and 'O' respectively, and
- spatial dimension correspondences between `lhs`, `rhs`, and the output using
any distinct characters.
For example, to indicate dimension numbers consistent with the `conv` function
with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
another example, to indicate dimension numbers consistent with the TensorFlow
Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
latter form of convolution dimension specification, window strides are
associated with spatial dimension character labels according to the order in
which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
is matched with the dimension corresponding to the first character
appearing in rhs_spec that is not `'I'` or `'O'`.
If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
(for a 2D convolution).
"""
lhs_precision = (precision[0]
if (isinstance(precision, tuple) and len(precision) == 2)
else precision)
patches = conv_general_dilated_patches(
lhs=lhs,
filter_shape=filter_shape,
window_strides=window_strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
dimension_numbers=dimension_numbers,
precision=lhs_precision
)
lhs_spec, rhs_spec, out_spec = lax.conv_dimension_numbers(
lhs.shape, (1, 1) + tuple(filter_shape), dimension_numbers)
lhs_c_dims, rhs_c_dims = [out_spec[1]], [rhs_spec[1]]
lhs_b_dims = out_spec[2:]
rhs_b_dims = rhs_spec[2:]
rhs_b_dims = [rhs_b_dims[i] for i in sorted(range(len(rhs_b_dims)),
key=lambda k: lhs_b_dims[k])]
lhs_b_dims = sorted(lhs_b_dims)
dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims))
out = lax.dot_general(patches, rhs, dimension_numbers=dn, precision=precision)
out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
return out

View File

@ -352,6 +352,7 @@ from jax._src.lax.parallel import (
xeinsum,
)
from jax._src.lax.other import (
conv_general_dilated_local,
conv_general_dilated_patches
)
from . import linalg

View File

@ -753,6 +753,86 @@ class LaxTest(jtu.JaxTestCase):
[out_spec.index(c) for c in out_spec if c not in ('N', 'C')])
self.assertAllClose(out, patches)
@parameterized.named_parameters(jtu.cases_from_list(
{
"testcase_name":
f"_dtype={dtype}_precision={precision}_n={n}_{padding}"
f"_dn={lhs_spec, rhs_spec, out_spec}]",
"dtype": dtype,
"rng_factory": rng_factory,
"precision": precision,
"n": n,
"padding": padding,
"lhs_spec": lhs_spec,
"rhs_spec": rhs_spec,
"out_spec": out_spec
}
for dtype in inexact_dtypes
for rng_factory in [jtu.rand_small]
for precision in [None,
lax.Precision.DEFAULT,
lax.Precision.HIGH,
lax.Precision.HIGHEST,
(lax.Precision.DEFAULT,
lax.Precision.HIGHEST)]
for n in [1, 2]
for padding in ['SAME', 'VALID']
for lhs_spec in [''.join(s)
for s in itertools.permutations('NCHWD'[:n + 2])]
for rhs_spec in [''.join(s)
for s in itertools.permutations('OIHWDX'[:n + 2])]
for out_spec in [''.join(s)
for s in itertools.permutations('NCHWDX'[:n + 2])]))
def testConvGeneralDilatedLocal(self, dtype, rng_factory, precision, n,
padding, lhs_spec, rhs_spec, out_spec):
"""Make sure LCN with tiled CNN kernel matches CNN."""
lhs_spec_default = 'NCHWDX'[:n + 2]
rhs_spec_default = 'OIHWDX'[:n + 2]
rng = rng_factory(self.rng())
lhs_default = rng((2, 4, 7, 6, 5, 8)[:n + 2], dtype)
rhs_default = rng((5, 4, 2, 3, 1, 2)[:n + 2], dtype)
window_strides = (1, 2, 3, 4)[:n]
rhs_dilation = (2, 1, 3, 2)[:n]
lhs_perm = [lhs_spec_default.index(c) for c in lhs_spec]
lhs = np.transpose(lhs_default, lhs_perm)
rhs_perm = [rhs_spec_default.index(c) for c in rhs_spec]
rhs = np.transpose(rhs_default, rhs_perm)
kwargs = dict(
lhs=lhs,
window_strides=window_strides,
padding=padding,
rhs_dilation=rhs_dilation,
dimension_numbers=(lhs_spec, rhs_spec, out_spec),
precision=precision
)
out_conv = lax.conv_general_dilated(rhs=rhs, **kwargs)
rhs_local = np.moveaxis(rhs, (rhs_spec.index('O'), rhs_spec.index('I')),
(0, 1))
rhs_local = rhs_local.reshape((rhs_local.shape[0], -1) + (1,) * n)
rhs_shape = (rhs_local.shape[:2] +
tuple(out_conv.shape[out_spec.index(c)]
for c in rhs_spec_default[2:]))
rhs_local = np.broadcast_to(rhs_local, rhs_shape)
rhs_local = np.transpose(rhs_local, rhs_perm)
filter_shape = [rhs.shape[i]
for i in range(n + 2) if rhs_spec[i] not in ('O', 'I')]
out_local = lax.conv_general_dilated_local(rhs=rhs_local,
filter_shape=filter_shape,
**kwargs)
self.assertAllClose(out_conv, out_local)
# TODO(mattjj): test conv_general_dilated against numpy
def testConv0DIsDot(self):