mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add documentation for more lax functions, notably concatenate and conv.
This commit is contained in:
parent
d2cabb5fee
commit
f4a73c127e
118
jax/lax.py
118
jax/lax.py
@ -360,11 +360,77 @@ def clamp(min, x, max):
|
||||
return clamp_p.bind(min, x, max)
|
||||
|
||||
def concatenate(operands, dimension):
|
||||
"""Concatenates a sequence of arrays along `dimension`.
|
||||
|
||||
Wraps XLA's `Concatenate
|
||||
<https://www.tensorflow.org/xla/operation_semantics#concatenate>`_
|
||||
operator.
|
||||
|
||||
Args:
|
||||
operands: a sequence of arrays to concatenate. The arrays must have equal
|
||||
shapes, except in the `dimension` axis.
|
||||
dimension: the dimension along which to concatenate the arrays.
|
||||
|
||||
Returns:
|
||||
An array containing the concatenation.
|
||||
"""
|
||||
return concatenate_p.bind(*operands, dimension=dimension,
|
||||
operand_shapes=tuple(o.shape for o in operands))
|
||||
|
||||
def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
|
||||
rhs_dilation=None, dimension_numbers=None):
|
||||
"""General n-dimensional convolution operator, with optional dilation.
|
||||
|
||||
Wraps XLA's `Conv
|
||||
<https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
|
||||
operator.
|
||||
|
||||
Args:
|
||||
lhs: a rank `n+2` dimensional input array.
|
||||
rhs: a rank `n+2` dimensional array of kernel weights.
|
||||
window_strides: a sequence of `n` integers, representing the inter-window
|
||||
strides.
|
||||
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
|
||||
`n` `(low, high)` integer pairs that give the padding to apply before and
|
||||
after each spatial dimension.
|
||||
lhs_dilation: `None`, or a sequence of `n` integers, giving the
|
||||
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
|
||||
is also known as transposed convolution.
|
||||
rhs_dilation: `None`, or a sequence of `n` integers, giving the
|
||||
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
|
||||
is also known as atrous convolution.
|
||||
dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
|
||||
a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
|
||||
of length `n+2`.
|
||||
|
||||
Returns:
|
||||
An array containing the convolution result.
|
||||
|
||||
In the string case of `dimension_numbers`, each character identifies by
|
||||
position:
|
||||
|
||||
- the batch dimensions in `lhs`, `rhs`, and the output with the character
|
||||
'N',
|
||||
- the feature dimensions in `lhs` and the output with the character 'C',
|
||||
- the input and output feature dimensions in rhs with the characters 'I'
|
||||
and 'O' respectively, and
|
||||
- spatial dimension correspondences between lhs, rhs, and the output using
|
||||
any distinct characters.
|
||||
|
||||
For example, to indicate dimension numbers
|
||||
consistent with the Conv operation with two spatial dimensions, one
|
||||
could use `('NCHW', 'OIHW', 'NCHW')`. As another example, to indicate
|
||||
dimension numbers consistent with the TensorFlow Conv2D operation, one
|
||||
could use `('NHWC', 'HWIO', 'NHWC')`. When using the latter form of
|
||||
convolution dimension specification, window strides are associated with
|
||||
spatial dimension character labels according to the order in which the
|
||||
labels appear in the `rhs_spec` string, so that `window_strides[0]` is
|
||||
matched with the dimension corresponding to the first character
|
||||
appearing in rhs_spec that is not `'I'` or `'O'`.
|
||||
|
||||
If `dimension_numbers` is `None`, the default is `(NCHW, OIHW, NCHW)` (for
|
||||
a 2D convolution).
|
||||
"""
|
||||
if type(dimension_numbers) is not ConvDimensionNumbers:
|
||||
dimension_numbers = conv_dimension_numbers(
|
||||
lhs.shape, rhs.shape, dimension_numbers)
|
||||
@ -727,11 +793,43 @@ def psum(x, axis_name):
|
||||
|
||||
|
||||
def conv(lhs, rhs, window_strides, padding):
|
||||
"""Convenience wrapper around `conv_general_dilated`.
|
||||
|
||||
Args:
|
||||
lhs: a rank `n+2` dimensional input array.
|
||||
rhs: a rank `n+2` dimensional array of kernel weights.
|
||||
window_strides: a sequence of `n` integers, representing the inter-window
|
||||
strides.
|
||||
padding: either the string `'SAME'`, the string `'VALID'`.
|
||||
|
||||
Returns:
|
||||
An array containing the convolution result.
|
||||
"""
|
||||
pads = padtype_to_pads(lhs.shape[2:], rhs.shape[2:], window_strides, padding)
|
||||
return conv_general_dilated(lhs, rhs, window_strides, padding)
|
||||
|
||||
def conv_with_general_padding(lhs, rhs, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation):
|
||||
"""Convenience wrapper around `conv_general_dilated`.
|
||||
|
||||
Args:
|
||||
lhs: a rank `n+2` dimensional input array.
|
||||
rhs: a rank `n+2` dimensional array of kernel weights.
|
||||
window_strides: a sequence of `n` integers, representing the inter-window
|
||||
strides.
|
||||
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
|
||||
`n` `(low, high)` integer pairs that give the padding to apply before and
|
||||
after each spatial dimension.
|
||||
lhs_dilation: `None`, or a sequence of `n` integers, giving the
|
||||
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
|
||||
is also known as transposed convolution.
|
||||
rhs_dilation: `None`, or a sequence of `n` integers, giving the
|
||||
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
|
||||
is also known as atrous convolution.
|
||||
|
||||
Returns:
|
||||
An array containing the convolution result.
|
||||
"""
|
||||
return conv_general_dilated(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation=lhs_dilation,
|
||||
rhs_dilation=rhs_dilation)
|
||||
@ -3456,11 +3554,21 @@ def remaining(original, *removed_lists):
|
||||
# [batch dim, feature dim, spatial dims ...]
|
||||
# rhs_spec is a list containing:
|
||||
# [out feature dim, in feature dim, spatial dims ...]
|
||||
ConvDimensionNumbers = collections.namedtuple(
|
||||
"ConvDimensionNumbers", ["lhs_spec", "rhs_spec", "out_spec"])
|
||||
class ConvDimensionNumbers(collections.namedtuple(
|
||||
"ConvDimensionNumbers", ["lhs_spec", "rhs_spec", "out_spec"])):
|
||||
"""Describes batch, spatial, and feature dimensions of a convolution.
|
||||
|
||||
Args:
|
||||
lhs_spec: a tuple of nonnegative integer dimension numbers containing
|
||||
`(batch dimension, feature dimension, spatial dimensions...)`.
|
||||
rhs_spec: a tuple of nonnegative integer dimension numbers containing
|
||||
`(out feature dimension, in feature dimension, spatial dimensions...)`.
|
||||
out_spec: a tuple of nonnegative integer dimension numbers containing
|
||||
`(batch dimension, feature dimension, spatial dimensions...)`.
|
||||
"""
|
||||
|
||||
def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers):
|
||||
"""Convert from user spec of dimension_numbers to ConvDimensionNumbers.
|
||||
"""Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`.
|
||||
|
||||
Args:
|
||||
lhs_shape: tuple of nonnegative integers, shape of the convolution input.
|
||||
@ -3469,8 +3577,8 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers):
|
||||
convolution dimension number specification format in xla_client.py.
|
||||
|
||||
Returns:
|
||||
A ConvDimensionNumbers namedtuple representing dimension_numbers in a
|
||||
canonical form that is handled by internal lax functions.
|
||||
A `ConvDimensionNumbers` object that represents `dimension_numbers` in the
|
||||
canonical form used by lax functions.
|
||||
"""
|
||||
if len(lhs_shape) != len(rhs_shape):
|
||||
msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}."
|
||||
|
Loading…
x
Reference in New Issue
Block a user