rocm_jax/jax/_src/lax/other.py
Peter Hawkins 6f7be3cf04 Define lax.Precision directly in Python, rather than inheriting from a C++ type in jaxlib.
Historically, we defined Precision to be an enum exported from jaxlib using pybind11, since that was the type the old XLA ComputationBuilder classes expected as input. But we build IR using StableHLO MLIR builders these days, and there's no reason for the JAX-level Precision type to match the XLA-internal one.

In a future change I plan to change the definition of Precision in jaxlib to be defined using nanobind instead of pybind11. Nanobind defines its enum classes to be final by default, which precludes this inheritance, and that's probably a good design decision by nanobind. But as discussed above, there's no good reason to inherit in the first place.

PiperOrigin-RevId: 612575404
2024-03-04 14:01:31 -08:00

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections.abc import Sequence
import math
from typing import Any, Union, cast as type_cast

import jax
from jax._src.numpy import lax_numpy as jnp
from jax._src.lax import lax
from jax._src.lax import convolution

DType = Any


def conv_general_dilated_patches(
    lhs: jax.typing.ArrayLike,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: str | Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int] | None = None,
    rhs_dilation: Sequence[int] | None = None,
    dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers | None = None,
    precision: lax.Precision | None = None,
    preferred_element_type: DType | None = None,
) -> jax.Array:
"""Extract patches subject to the receptive field of `conv_general_dilated`.
Runs the input through a convolution with given parameters. The kernel of the
convolution is constructed such that the output channel dimension `"C"`
contains flattened image patches, so instead a single `"C"` dimension
represents, for example, three dimensions `"chw"` collapsed. The order of
these dimensions is `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`,
where `rhs_spec == dimension_numbers[1]`, and the size of this `"C"`
dimension is therefore the size of each patch, i.e.
`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
`lhs_spec == dimension_numbers[0]`.
Docstring below adapted from `jax.lax.conv_general_dilated`.
See Also:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution
Args:
lhs: a rank `n+2` dimensional input array.
filter_shape: a sequence of `n` integers, representing the receptive window
spatial shape in the order as specified in
`rhs_spec = dimension_numbers[1]`.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
`n` `(low, high)` integer pairs that give the padding to apply before and
after each spatial dimension.
lhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
is also known as transposed convolution.
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
is also known as atrous convolution.
dimension_numbers: either `None`, or a 3-tuple
`(lhs_spec, rhs_spec, out_spec)`, where each element is a string
of length `n+2`. `None` defaults to `("NCHWD..., OIHWD..., NCHWD...")`.
precision: Optional. Either ``None``, which means the default precision for
the backend, or a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
A rank `n+2` array containing the flattened image patches in the output
channel (`"C"`) dimension. For example if
`dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
the size of each patch
(`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
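
  Examples:
    A minimal sketch of typical usage (shapes assume the default
    `"NCHW"`/`"OIHW"` layout; the names `x` and `patches` are illustrative
    only):

    >>> import jax.numpy as jnp
    >>> from jax.lax import conv_general_dilated_patches
    >>> x = jnp.ones((1, 3, 8, 8))  # N=1, C=3, H=W=8
    >>> patches = conv_general_dilated_patches(
    ...     x, filter_shape=(2, 2), window_strides=(1, 1), padding='VALID')
    >>> patches.shape  # "C" holds 3 * 2 * 2 flattened patch entries
    (1, 12, 7, 7)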
"""
  lhs_array = jnp.asarray(lhs)
  filter_shape = tuple(filter_shape)
  dimension_numbers = convolution.conv_dimension_numbers(
      lhs_array.shape, (1, 1) + filter_shape, dimension_numbers)
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  spatial_size = math.prod(filter_shape)
  n_channels = lhs_array.shape[lhs_spec[1]]
  # Move separate `lhs` spatial locations into separate `rhs` channels.
  rhs = jnp.eye(spatial_size, dtype=lhs_array.dtype).reshape(filter_shape * 2)
  rhs = rhs.reshape((spatial_size, 1) + filter_shape)
  rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
  rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))
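  # With `feature_group_count == n_channels` the convolution below is
  # depthwise: each input channel is convolved with its own block of
  # `spatial_size` one-hot filters, so every output channel carries one entry
  # of a flattened patch.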
  out = convolution.conv_general_dilated(
      lhs=lhs_array,
      rhs=rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=None if precision is None else (precision,
                                                lax.Precision.DEFAULT),
      feature_group_count=n_channels,
      preferred_element_type=preferred_element_type
  )
  return out


def conv_general_dilated_local(
    lhs: jax.typing.ArrayLike,
    rhs: jax.typing.ArrayLike,
    window_strides: Sequence[int],
    padding: str | Sequence[tuple[int, int]],
    filter_shape: Sequence[int],
    lhs_dilation: Sequence[int] | None = None,
    rhs_dilation: Sequence[int] | None = None,
    dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers | None = None,
    precision: lax.PrecisionLike = None
) -> jax.Array:
"""General n-dimensional unshared convolution operator with optional dilation.
Also known as locally connected layer, the operation is equivalent to
convolution with a separate (unshared) `rhs` kernel used at each output
spatial location. Docstring below adapted from `jax.lax.conv_general_dilated`.
See Also:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights. Unlike in regular
CNNs, its spatial coordinates (`H`, `W`, ...) correspond to output spatial
locations, while input spatial locations are fused with the input channel
locations in the single `I` dimension, in the order of
`"C" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
`rhs_spec = dimension_numbers[1]`. For example, if `rhs_spec == "WHIO",
the unfolded kernel shape is
`"[output W][output H]{I[receptive window W][receptive window H]}O"`.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
`n` `(low, high)` integer pairs that give the padding to apply before and
after each spatial dimension.
filter_shape: a sequence of `n` integers, representing the receptive window
spatial shape in the order as specified in
`rhs_spec = dimension_numbers[1]`.
lhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
is also known as transposed convolution.
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each input spatial dimension of `rhs`.
RHS dilation is also known as atrous convolution.
dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
of length `n+2`.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
Returns:
An array containing the unshared convolution result.
In the string case of `dimension_numbers`, each character identifies by
position:
- the batch dimensions in `lhs`, `rhs`, and the output with the character
'N',
- the feature dimensions in `lhs` and the output with the character 'C',
- the input and output feature dimensions in rhs with the characters 'I'
and 'O' respectively, and
- spatial dimension correspondences between `lhs`, `rhs`, and the output using
any distinct characters.
For example, to indicate dimension numbers consistent with the `conv` function
with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
another example, to indicate dimension numbers consistent with the TensorFlow
Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
latter form of convolution dimension specification, window strides are
associated with spatial dimension character labels according to the order in
which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
is matched with the dimension corresponding to the first character
appearing in rhs_spec that is not `'I'` or `'O'`.
If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
(for a 2D convolution).
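
  Examples:
    A minimal sketch of typical usage (shapes assume the default
    `('NCHW', 'OIHW', 'NCHW')` layout; the names `x`, `kernel`, and `out` are
    illustrative only):

    >>> import jax.numpy as jnp
    >>> from jax.lax import conv_general_dilated_local
    >>> x = jnp.ones((1, 3, 8, 8))        # N=1, C=3, H=W=8
    >>> kernel = jnp.ones((4, 12, 7, 7))  # O=4, I=3*2*2, one kernel per output pixel
    >>> out = conv_general_dilated_local(
    ...     x, kernel, window_strides=(1, 1), padding='VALID',
    ...     filter_shape=(2, 2))
    >>> out.shape
    (1, 4, 7, 7)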
"""
  lhs_array = jnp.asarray(lhs)
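  # `canonicalize_precision` may return None, a single Precision, or an
  # (lhs, rhs) pair; patch extraction only involves `lhs`, so use the first
  # entry of a pair.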
  c_precision = lax.canonicalize_precision(precision)
  lhs_precision = type_cast(
      Union[lax.Precision, None],
      (c_precision[0]
       if (isinstance(c_precision, tuple) and len(c_precision) == 2)
       else c_precision))
  patches = conv_general_dilated_patches(
      lhs=lhs_array,
      filter_shape=filter_shape,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=lhs_precision
  )
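  # Contract the flattened patch ("C") dimension of `patches` against the `I`
  # dimension of `rhs`, treating all spatial dimensions as batch dimensions so
  # that a distinct kernel is applied at each output location.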
  lhs_spec, rhs_spec, out_spec = convolution.conv_dimension_numbers(
      lhs_array.shape, (1, 1) + tuple(filter_shape), dimension_numbers)
  lhs_c_dims, rhs_c_dims = [out_spec[1]], [rhs_spec[1]]
  lhs_b_dims = out_spec[2:]
  rhs_b_dims = rhs_spec[2:]
  rhs_b_dims = [rhs_b_dims[i] for i in sorted(range(len(rhs_b_dims)),
                                              key=lambda k: lhs_b_dims[k])]
  lhs_b_dims = sorted(lhs_b_dims)
  dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims))
  out = lax.dot_general(patches, rhs, dimension_numbers=dn, precision=precision)
  out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
  return out