Add documentation for more lax functions, notably concatenate and conv.

This commit is contained in:
Peter Hawkins 2019-02-19 21:28:01 -05:00
parent d2cabb5fee
commit f4a73c127e

View File

@ -360,11 +360,77 @@ def clamp(min, x, max):
return clamp_p.bind(min, x, max)
def concatenate(operands, dimension):
"""Concatenates a sequence of arrays along `dimension`.
Wraps XLA's `Concatenate
<https://www.tensorflow.org/xla/operation_semantics#concatenate>`_
operator.
Args:
operands: a sequence of arrays to concatenate. The arrays must have equal
shapes, except in the `dimension` axis.
dimension: the dimension along which to concatenate the arrays.
Returns:
An array containing the concatenation.
"""
return concatenate_p.bind(*operands, dimension=dimension,
operand_shapes=tuple(o.shape for o in operands))
def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
rhs_dilation=None, dimension_numbers=None):
"""General n-dimensional convolution operator, with optional dilation.
Wraps XLA's `Conv
<https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
operator.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
`n` `(low, high)` integer pairs that give the padding to apply before and
after each spatial dimension.
lhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
is also known as transposed convolution.
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
is also known as atrous convolution.
dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
of length `n+2`.
Returns:
An array containing the convolution result.
In the string case of `dimension_numbers`, each character identifies by
position:
- the batch dimensions in `lhs`, `rhs`, and the output with the character
'N',
- the feature dimensions in `lhs` and the output with the character 'C',
- the input and output feature dimensions in rhs with the characters 'I'
and 'O' respectively, and
- spatial dimension correspondences between lhs, rhs, and the output using
any distinct characters.
For example, to indicate dimension numbers
consistent with the Conv operation with two spatial dimensions, one
could use `('NCHW', 'OIHW', 'NCHW')`. As another example, to indicate
dimension numbers consistent with the TensorFlow Conv2D operation, one
could use `('NHWC', 'HWIO', 'NHWC')`. When using the latter form of
convolution dimension specification, window strides are associated with
spatial dimension character labels according to the order in which the
labels appear in the `rhs_spec` string, so that `window_strides[0]` is
matched with the dimension corresponding to the first character
appearing in rhs_spec that is not `'I'` or `'O'`.
If `dimension_numbers` is `None`, the default is `(NCHW, OIHW, NCHW)` (for
a 2D convolution).
"""
if type(dimension_numbers) is not ConvDimensionNumbers:
dimension_numbers = conv_dimension_numbers(
lhs.shape, rhs.shape, dimension_numbers)
@ -727,11 +793,43 @@ def psum(x, axis_name):
def conv(lhs, rhs, window_strides, padding):
"""Convenience wrapper around `conv_general_dilated`.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`.
Returns:
An array containing the convolution result.
"""
pads = padtype_to_pads(lhs.shape[2:], rhs.shape[2:], window_strides, padding)
return conv_general_dilated(lhs, rhs, window_strides, padding)
def conv_with_general_padding(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation):
"""Convenience wrapper around `conv_general_dilated`.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
`n` `(low, high)` integer pairs that give the padding to apply before and
after each spatial dimension.
lhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
is also known as transposed convolution.
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
is also known as atrous convolution.
Returns:
An array containing the convolution result.
"""
return conv_general_dilated(
lhs, rhs, window_strides, padding, lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation)
@ -3456,11 +3554,21 @@ def remaining(original, *removed_lists):
# [batch dim, feature dim, spatial dims ...]
# rhs_spec is a list containing:
# [out feature dim, in feature dim, spatial dims ...]
ConvDimensionNumbers = collections.namedtuple(
"ConvDimensionNumbers", ["lhs_spec", "rhs_spec", "out_spec"])
class ConvDimensionNumbers(collections.namedtuple(
"ConvDimensionNumbers", ["lhs_spec", "rhs_spec", "out_spec"])):
"""Describes batch, spatial, and feature dimensions of a convolution.
Args:
lhs_spec: a tuple of nonnegative integer dimension numbers containing
`(batch dimension, feature dimension, spatial dimensions...)`.
rhs_spec: a tuple of nonnegative integer dimension numbers containing
`(out feature dimension, in feature dimension, spatial dimensions...)`.
out_spec: a tuple of nonnegative integer dimension numbers containing
`(batch dimension, feature dimension, spatial dimensions...)`.
"""
def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers):
"""Convert from user spec of dimension_numbers to ConvDimensionNumbers.
"""Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`.
Args:
lhs_shape: tuple of nonnegative integers, shape of the convolution input.
@ -3469,8 +3577,8 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers):
convolution dimension number specification format in xla_client.py.
Returns:
A ConvDimensionNumbers namedtuple representing dimension_numbers in a
canonical form that is handled by internal lax functions.
A `ConvDimensionNumbers` object that represents `dimension_numbers` in the
canonical form used by lax functions.
"""
if len(lhs_shape) != len(rhs_shape):
msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}."