Simplify and consolidate dot algorithm control in lax.

In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used for dot product accumulation. When using this feature in real use cases, we found that the API was both too conservative (it required the user to pass inputs with the appropriate storage types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
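
As a concrete illustration (a sketch, not part of this change; the exact StableHLO text is backend- and version-dependent), the inserted casts can be seen by lowering a dot that requests the `F16_F16_F16` preset with f32 inputs:

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x, y):
  # Request an algorithm whose storage and accumulation type (f16) differ from
  # the input dtype (f32).
  return lax.dot(x, y, precision="F16_F16_F16")

x = jnp.ones((8, 8), dtype=jnp.float32)
# Lowering (without compiling) should show convert ops casting the f32 inputs
# down to f16 before the dot_general, and a final convert back to f32.
print(jax.jit(f).lower(x, x).as_text())
```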

Another change in behavior is to cast the result of the `DotGeneral` op (whose type is determined by the algorithm's `accumulation_type`), if needed, to match the input types. This means that, regardless of the algorithm choice, the output type will match what a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
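
For example (a sketch following the docstring examples added in this change; these presets are only supported on some backends, so the calls may raise a compile-time error elsewhere):

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((3, 4), dtype=jnp.float16)
rhs = jnp.ones((4, 2), dtype=jnp.float16)

# The algorithm accumulates in f32, but the result is cast back to the input type.
out = lax.dot(lhs, rhs, precision="F16_F16_F32")
assert out.dtype == jnp.float16

# preferred_element_type opts out of the downcast and returns the accumulator type.
out_f32 = lax.dot(lhs, rhs, precision="F16_F16_F32",
                  preferred_element_type=jnp.float32)
assert out_f32.dtype == jnp.float32
```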

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS dtypes don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile-time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
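
For instance, a minimal sketch of that escape hatch (the function name and preset choices are purely illustrative, and preset support remains backend dependent; e.g. `BF16_BF16_F32` is not accepted by the CPU lowering):

```python
import jax
from jax import lax

@jax.custom_vjp
def fast_matmul(x, y):
  # The forward pass uses a low-precision algorithm.
  return lax.dot(x, y, precision="BF16_BF16_F32")

def fast_matmul_fwd(x, y):
  return fast_matmul(x, y), (x, y)

def fast_matmul_bwd(residuals, g):
  x, y = residuals
  # The backward pass uses a different (here, full-precision) algorithm.
  return (lax.dot(g, y.T, precision="F32_F32_F32"),
          lax.dot(x.T, g, precision="F32_F32_F32"))

fast_matmul.defvjp(fast_matmul_fwd, fast_matmul_bwd)
```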

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
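
In other words, `precision` now accepts dot algorithms alongside the classic values (a sketch; whether a given algorithm compiles still depends on the backend):

```python
from jax import lax
import jax.numpy as jnp

x = jnp.ones((4, 4), dtype=jnp.float32)

lax.dot(x, x, precision=lax.DotAlgorithmPreset.F32_F32_F32)  # a preset enum member
lax.dot(x, x, precision="F32_F32_F32")                       # the preset's name as a string
lax.dot(x, x, precision=lax.Precision.HIGHEST)               # lax.Precision still works
```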

One minor downside of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as-is for now seems like the best short-term approach.

PiperOrigin-RevId: 683302687
Dan Foreman-Mackey 2024-10-07 13:20:24 -07:00 committed by jax authors
parent 8473391467
commit 28bbbf894f
12 changed files with 416 additions and 440 deletions


@ -250,6 +250,10 @@ Argument classes
.. autoclass:: ConvDimensionNumbers
.. autoclass:: ConvGeneralDilatedDimensionNumbers
.. autoclass:: DotAlgorithm
.. autoclass:: DotAlgorithmPreset
:members:
:undoc-members:
:member-order: bysource
.. autoclass:: GatherDimensionNumbers
.. autoclass:: GatherScatterMode
.. autoclass:: Precision


@ -709,27 +709,19 @@ _precision_strings['fastest'] = Precision.DEFAULT
_precision_strings[None] = Precision.DEFAULT
PrecisionLike = Union[
str,
Precision,
tuple[str, str],
tuple[Precision, Precision],
None,
]
class DotAlgorithm(NamedTuple):
"""Specify the algorithm used for computing dot products.
When used as input to :func:`~jax.lax.dot_general`, this data structure is
used for controlling the properties of the algorithm used for computing the
dot product. This API controls the precision used for the computation, and
allows users to access hardware-specific accelerations.
When used to specify the ``precision`` input to :func:`~jax.lax.dot`,
:func:`~jax.lax.dot_general`, and other dot product functions, this data
structure is used for controlling the properties of the algorithm used for
computing the dot product. This API controls the precision used for the
computation, and allows users to access hardware-specific accelerations.
Support for these algorithms is platform dependent, and using an unsupported
algorithm will raise a Python exception when the computation is compiled. The
algorithms that are known to be supported on at least some platforms are
listed in the :class:`~jax.lax.DotAlgorithm.Preset` enum, and these are a
listed in the :class:`~jax.lax.DotAlgorithmPreset` enum, and these are a
good starting point for experimenting with this API.
A "dot algorithm" is specified by the following parameters:
@ -764,13 +756,24 @@ class DotAlgorithm(NamedTuple):
... )
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP
array([ 1., 4., 9., 16.], dtype=float32)
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
array([ 1., 4., 9., 16.], dtype=float16)
Or, equivalently, using a preset:
>>> algorithm = DotAlgorithm.Preset.F16_F16_F32
>>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
array([ 1., 4., 9., 16.], dtype=float16)
Presets can also be specified by name:
>>> dot(lhs, rhs, precision="F16_F16_F32") # doctest: +SKIP
array([ 1., 4., 9., 16.], dtype=float16)
The ``preferred_element_type`` parameter can be used to return the output
without downcasting the accumulation type:
>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) # doctest: +SKIP
array([ 1., 4., 9., 16.], dtype=float32)
"""
@ -795,50 +798,149 @@ class DotAlgorithm(NamedTuple):
self.allow_imprecise_accumulation,
)
# mypy doesn't currently support nested classes in a NamedTuple definition.
class Preset(enum.Enum): # type: ignore[misc]
DEFAULT = 0
ANY_F8_ANY_F8_F32 = 1
ANY_F8_ANY_F8_F32_FAST_ACCUM = 2
F16_F16_F16 = 3
F16_F16_F32 = 4
BF16_BF16_BF16 = 5
BF16_BF16_F32 = 6
BF16_BF16_F32_X3 = 7
BF16_BF16_F32_X6 = 8
TF32_TF32_F32 = 9
TF32_TF32_F32_X3 = 10
F32_F32_F32 = 11
F64_F64_F64 = 12
def __repr__(self) -> str:
return f'{self.__class__.__name__}.{self.name}'
class DotAlgorithmPreset(enum.Enum):
"""An enum of known algorithms for computing dot products.
def __str__(self) -> str:
return self.name
This ``Enum`` provides a named set of :class:`~jax.lax.DotAlgorithm` objects
that are known to be supported on at least one platform. See the
:class:`~jax.lax.DotAlgorithm` documentation for more details about the
behavior of these algorithms.
@property
def accumulation_type(self) -> DTypeLike:
match self:
case DotAlgorithm.Preset.DEFAULT:
raise TypeError(
"The default dot algorithm does not have an accumulation type.")
case DotAlgorithm.Preset.F16_F16_F16:
return np.float16
case DotAlgorithm.Preset.BF16_BF16_BF16:
return dtypes.bfloat16
case DotAlgorithm.Preset.F64_F64_F64:
return np.float64
case _:
return np.float32
An algorithm can be selected from this list when calling :func:`~jax.lax.dot`,
:func:`~jax.lax.dot_general`, or most other JAX dot product functions, by
passing either a member of this ``Enum`` or its name as a string using the
``precision`` argument.
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
if self == DotAlgorithm.Preset.DEFAULT:
For example, users can specify the preset using this ``Enum`` directly:
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
array([ 1., 4., 9., 16.], dtype=float16)
or, equivalently, they can be specified by name:
>>> dot(lhs, rhs, precision="F16_F16_F32") # doctest: +SKIP
array([ 1., 4., 9., 16.], dtype=float16)
The names of the presets are typically ``LHS_RHS_ACCUM`` where ``LHS`` and
``RHS`` are the element types of the ``lhs`` and ``rhs`` inputs
respectively, and ``ACCUM`` is the element type of the accumulator. Some
presets have an extra suffix, and the meaning of each of these is
documented below. The supported presets are:
"""
DEFAULT = enum.auto()
"""An algorithm will be selected based on input and output types."""
ANY_F8_ANY_F8_F32 = enum.auto()
"""Accepts any float8 input types and accumulates into float32."""
ANY_F8_ANY_F8_F32_FAST_ACCUM = enum.auto()
"""Like ``ANY_F8_ANY_F8_F32``, but using faster accumulation with the cost
of lower accuracy.
"""
ANY_F8_ANY_F8_ANY = enum.auto()
"""Like ``ANY_F8_ANY_F8_F32``, but the accumulation type is controlled by
``preferred_element_type``.
"""
ANY_F8_ANY_F8_ANY_FAST_ACCUM = enum.auto()
"""Like ``ANY_F8_ANY_F8_F32_FAST_ACCUM``, but the accumulation type is
controlled by ``preferred_element_type``.
"""
F16_F16_F16 = enum.auto()
F16_F16_F32 = enum.auto()
BF16_BF16_BF16 = enum.auto()
BF16_BF16_F32 = enum.auto()
BF16_BF16_F32_X3 = enum.auto()
"""The ``_X3`` suffix indicates that the algorithm uses 3 operations to
emulate higher precision.
"""
BF16_BF16_F32_X6 = enum.auto()
"""Like ``BF16_BF16_F32_X3``, but using 6 operations instead of 3."""
TF32_TF32_F32 = enum.auto()
TF32_TF32_F32_X3 = enum.auto()
"""The ``_X3`` suffix indicates that the algorithm uses 3 operations to
emulate higher precision.
"""
F32_F32_F32 = enum.auto()
F64_F64_F64 = enum.auto()
def __repr__(self) -> str:
return f'{self.__class__.__name__}.{self.name}'
def __str__(self) -> str:
return self.name
@property
def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
match self:
case (
DotAlgorithmPreset.DEFAULT |
DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
return None
case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32:
return np.float16
case (
DotAlgorithmPreset.BF16_BF16_BF16 |
DotAlgorithmPreset.BF16_BF16_F32
):
# These algorithms support either f32 or bf16 input storage types.
# If either of those types are provided as input, we use the provided
# type. If not, we explicitly cast to bfloat16.
return (dtypes.bfloat16, np.float32)
case DotAlgorithmPreset.F64_F64_F64:
return np.float64
case _:
return np.float32
if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM):
@property
def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
return self.lhs_precision_type
@property
def accumulation_type(self) -> DTypeLike | None:
match self:
case (
DotAlgorithmPreset.DEFAULT |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
return None
case DotAlgorithmPreset.F16_F16_F16:
return np.float16
case DotAlgorithmPreset.BF16_BF16_BF16:
return dtypes.bfloat16
case DotAlgorithmPreset.F64_F64_F64:
return np.float64
case _:
return np.float32
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
f64 = ir.F64Type.get()
bf16 = ir.BF16Type.get()
tf32 = ir.FloatTF32Type.get()
match self:
case (
DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz),
@ -853,65 +955,53 @@ class DotAlgorithm(NamedTuple):
acc = ir.F32Type.get()
return hlo.DotAlgorithm.get(
lhs, rhs, acc, 1, 1, 1,
self == DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM)
else:
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
f64 = ir.F64Type.get()
bf16 = ir.BF16Type.get()
tf32 = ir.FloatTF32Type.get()
match self:
case DotAlgorithm.Preset.F16_F16_F16:
return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False)
case DotAlgorithm.Preset.F16_F16_F32:
return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False)
case DotAlgorithm.Preset.BF16_BF16_BF16:
return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False)
case DotAlgorithm.Preset.BF16_BF16_F32:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False)
case DotAlgorithm.Preset.BF16_BF16_F32_X3:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False)
case DotAlgorithm.Preset.BF16_BF16_F32_X6:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
case DotAlgorithm.Preset.TF32_TF32_F32:
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
case DotAlgorithm.Preset.TF32_TF32_F32_X3:
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False)
case DotAlgorithm.Preset.F32_F32_F32:
return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False)
case DotAlgorithm.Preset.F64_F64_F64:
return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False)
case _:
raise NotImplementedError("unreachable")
self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM)
case DotAlgorithmPreset.F16_F16_F16:
return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False)
case DotAlgorithmPreset.F16_F16_F32:
return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False)
case DotAlgorithmPreset.BF16_BF16_BF16:
return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False)
case DotAlgorithmPreset.BF16_BF16_F32:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False)
case DotAlgorithmPreset.BF16_BF16_F32_X3:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False)
case DotAlgorithmPreset.BF16_BF16_F32_X6:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
case DotAlgorithmPreset.TF32_TF32_F32:
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
case DotAlgorithmPreset.TF32_TF32_F32_X3:
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False)
case DotAlgorithmPreset.F32_F32_F32:
return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False)
case DotAlgorithmPreset.F64_F64_F64:
return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False)
case _:
return None
DotAlgorithmLike = Union[
DotAlgorithm,
DotAlgorithm.Preset,
PrecisionLike = Union[
None,
str,
None,
]
_DotAlgorithmLike = Union[
Precision,
tuple[str, str],
tuple[Precision, Precision],
DotAlgorithm,
DotAlgorithm.Preset,
DotAlgorithmPreset,
]
CanonicalPrecision = Union[
None,
tuple[Precision, Precision],
DotAlgorithm,
DotAlgorithmPreset,
]
DotTransposeAlgorithmLike = Union[
DotAlgorithmLike,
tuple[DotAlgorithmLike, DotAlgorithmLike],
]
DotTransposeAlgorithm = tuple[_DotAlgorithmLike, _DotAlgorithmLike]
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
algorithm: DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array:
preferred_element_type: DTypeLike | None = None) -> Array:
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
Wraps XLA's `Dot
<https://www.tensorflow.org/xla/operation_semantics#dot>`_
Wraps XLA's `Dot <https://www.tensorflow.org/xla/operation_semantics#dot>`_
operator.
For more general contraction, see the :func:`jax.lax.dot_general` operator.
@ -919,24 +1009,25 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
Args:
lhs: an array of dimension 1 or 2.
rhs: an array of dimension 1 or 2.
precision: Optional. Either ``None``, which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
algorithm: Optional. Specify the algorithm used for accumulating the dot
product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument
cannot be used with ``precision`` or ``preferred_element_type``.
transpose_algorithm: Optional. This allows specifying the algorithm used when
this operation is transposed, typically as part of reverse-mode automatic
differentiation. This argument can either be a single
:class:`~jax.lax.DotAlgorithm` or a tuple of two
:class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the
algorithm for transposing the LHS and RHS, respectively.
``transpose_algorithm`` must be explicitly specified when transposing a
dot product where a specific ``algorithm`` was used on the forward pass.
precision: Optional. This parameter controls the numerics of the
computation, and it can be one of the following:
- ``None``, which means the default precision for the current backend,
- a :class:`~jax.lax.Precision` enum value or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and
``rhs``, or
- a :class:`~jax.lax.DotAlgorithm` or a
:class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that
must be used to accumulate the dot product.
preferred_element_type: Optional. This parameter controls the data type
output by the dot product. By default, the output element type of this
operation will match the ``lhs`` and ``rhs`` input element types under
the usual type promotion rules. Setting ``preferred_element_type`` to a
specific ``dtype`` will mean that the operation returns that element type.
When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or
:class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides
a hint to the compiler to accumulate the dot product using this data type.
Returns:
An array containing the product.
@ -944,9 +1035,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]):
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
precision=precision,
preferred_element_type=preferred_element_type,
algorithm=algorithm,
transpose_algorithm=transpose_algorithm)
preferred_element_type=preferred_element_type)
else:
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
lhs.shape, rhs.shape))
@ -957,9 +1046,7 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
algorithm: DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array:
preferred_element_type: DTypeLike | None = None) -> Array:
"""General dot product/contraction operator.
Wraps XLA's `DotGeneral
@ -978,29 +1065,31 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
lhs: an array
rhs: an array
dimension_numbers: a tuple of tuples of sequences of ints of the form
``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))``
precision: Optional. Either ``None``, which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
algorithm: Optional. Specify the algorithm used for accumulating the dot
product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument
cannot be used with ``precision`` or ``preferred_element_type``.
transpose_algorithm: Optional. This allows specifying the algorithm used when
this operation is transposed, typically as part of reverse-mode automatic
differentiation. This argument can either be a single
:class:`~jax.lax.DotAlgorithm` or a tuple of two
:class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the
algorithm for transposing the LHS and RHS, respectively.
``transpose_algorithm`` must be explicitly specified when transposing a
dot product where a specific ``algorithm`` was used on the forward pass.
``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
rhs_batch_dims))``
precision: Optional. This parameter controls the numerics of the
computation, and it can be one of the following:
- ``None``, which means the default precision for the current backend,
- a :class:`~jax.lax.Precision` enum value or a tuple of two
:class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and
``rhs``, or
- a :class:`~jax.lax.DotAlgorithm` or a
:class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that
must be used to accumulate the dot product.
preferred_element_type: Optional. This parameter controls the data type
output by the dot product. By default, the output element type of this
operation will match the ``lhs`` and ``rhs`` input element types under
the usual type promotion rules. Setting ``preferred_element_type`` to a
specific ``dtype`` will mean that the operation returns that element type.
When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or
:class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides
a hint to the compiler to accumulate the dot product using this data type.
Returns:
An array whose first dimensions are the (shared) batch dimensions, followed by
the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
An array whose first dimensions are the (shared) batch dimensions, followed
by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
non-contracting/non-batch dimensions.
"""
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
@ -1014,9 +1103,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
return dot_general_p.bind(lhs, rhs,
dimension_numbers=(cdims, bdims),
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type,
algorithm=canonicalize_dot_algorithm(algorithm),
transpose_algorithm=canonicalize_dot_transpose_algorithm(transpose_algorithm))
preferred_element_type=preferred_element_type)
def ragged_dot(
@ -3063,9 +3150,7 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None):
preferred_element_type: DTypeLike | None):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
for d in (lhs_contracting, lhs_batch)):
@ -3141,10 +3226,8 @@ def tuple_delete(tup, idx):
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None):
del dimension_numbers, precision # unused
preferred_element_type: DTypeLike | None):
del dimension_numbers # unused
# We're mostly matching XLA's logic here, namely in shape_inference.cc and
# primitive_util.h's HigherPrecisionType, e.g.
# https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329
@ -3165,23 +3248,9 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
raise TypeError(
f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
result_dtype = lhs.dtype
if transpose_algorithm is not None and algorithm is None:
raise ValueError(
"When the algorithm argument to dot_general is None, the "
"transpose_algorithm argument is unused and must also be None.")
if algorithm is not None and algorithm != DotAlgorithm.Preset.DEFAULT:
if preferred_element_type is not None:
raise ValueError(
"The preferred_element_type and algorithm arguments to dot_general "
"cannot both be specified.")
# This is used to ensure that the output type is equal to the accumulation
# type whenever an algorithm is specified.
preferred_element_type = algorithm.accumulation_type
return _maybe_upcast(result_dtype, preferred_element_type)
has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset))
return _maybe_upcast(result_dtype, preferred_element_type,
check_bit_width=not has_algorithm)
def _bit_width(d):
if dtypes.issubdtype(d, np.inexact): return dtypes.finfo(d).bits
@ -3189,12 +3258,12 @@ def _bit_width(d):
elif d == np.dtype('bool'): return 1
else: assert False, d # should be unreachable, open an issue!
def _maybe_upcast(result_dtype, preferred_element_type):
def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
# replicates the logic in shape_inference.cc's MaybeUpcast
if (preferred_element_type is None or
result_dtype == preferred_element_type):
return result_dtype
if (not dtypes.issubdtype(result_dtype, np.floating) and
if (check_bit_width and not dtypes.issubdtype(result_dtype, np.floating) and
_bit_width(preferred_element_type) < _bit_width(result_dtype)):
raise TypeError("`preferred_element_type` must not be narrower than the "
"original type, got preferred_element_type of "
@ -3204,8 +3273,6 @@ def _maybe_upcast(result_dtype, preferred_element_type):
def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None,
swap_ans=False):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
x_ndim = x.aval.ndim
@ -3218,36 +3285,20 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
dims = ((ans_y, y_kept), (ans_batch, y_batch))
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
if algorithm is not None:
if transpose_algorithm is None or transpose_algorithm[0] is None:
raise ValueError(
"When a dot_general algorithm is specified on the forward pass, "
"transpose_algorithm must be specified for the backward pass.")
lhs_alg, rhs_alg = transpose_algorithm
transpose_algorithm = (algorithm, rhs_alg)
algorithm = lhs_alg
x_bar = transpose(dot_general(g, y, dims, precision=precision,
preferred_element_type=preferred_element_type,
algorithm=algorithm,
transpose_algorithm=transpose_algorithm),
preferred_element_type=preferred_element_type),
tuple(out_axes))
if x_bar.dtype != x.aval.dtype:
x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
return x_bar
def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None):
preferred_element_type: DTypeLike | None):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
transpose_algorithm = None if transpose_algorithm is None else (
transpose_algorithm[1], transpose_algorithm[0])
y_bar = _dot_general_transpose_lhs(
g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
preferred_element_type=preferred_element_type, algorithm=algorithm,
transpose_algorithm=transpose_algorithm,
swap_ans=True)
preferred_element_type=preferred_element_type, swap_ans=True)
if y_bar.dtype != y.aval.dtype:
y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
return y_bar
@ -3263,8 +3314,6 @@ def _dot_batch_rule(
dimension_numbers,
precision,
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None,
**_,
):
@ -3298,8 +3347,6 @@ def _dot_batch_rule(
new_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
algorithm=algorithm,
transpose_algorithm=transpose_algorithm,
)
result_batch_dim = batching.shape_as_bdim(
result_stack_dim,
@ -3415,7 +3462,7 @@ pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
def precision_attr(precision: Precision) -> ir.ArrayAttr:
if precision is None:
if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
full_precision = (Precision.DEFAULT, Precision.DEFAULT)
elif not isinstance(precision, tuple):
full_precision = (precision, precision)
@ -3425,19 +3472,16 @@ def precision_attr(precision: Precision) -> ir.ArrayAttr:
[hlo.PrecisionAttr.get(str(p)) for p in full_precision])
def dot_algorithm_attr(algorithm: _DotAlgorithmLike, lhs_dtype: DTypeLike,
def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
if algorithm is None:
if not isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
return None
return algorithm._convert_to_hlo_attr(lhs_dtype, rhs_dtype)
return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype)
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
precision, preferred_element_type: np.dtype | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None,
platform: str = "default"):
del transpose_algorithm # unused
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@ -3446,63 +3490,87 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
lhs_aval, rhs_aval = ctx.avals_in
lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype
aval_out, = ctx.avals_out
accumulation_aval = aval_out
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
# TODO(b/...): JAX's dot_general primitive accepts the same input dtype
# combinations that are accepted in XLA's shape_inference.cc (the canonical
# reference for the HLO type system), but actually different XLA platforms
# fail on codegen for different accepted cases. To handle those cases, we
# insert ConvertOps on the input, in a platform-dependent way.
if lhs_dtype != rhs_dtype:
if platform == "tpu":
handled = lambda dt: (dtypes.issubdtype(dt, np.floating) or
dtypes.issubdtype(dt, np.integer))
if not (handled(lhs_dtype) and handled(rhs_dtype)):
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
lhs_dtype = rhs_dtype = aval_out.dtype
else: # cpu and gpu
# Do not convert mixed fp8 types to output type.
if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
lhs_dtype = rhs_dtype = aval_out.dtype
dot_dnums = hlo.DotDimensionNumbers.get(
lhs_batching_dimensions=list(lhs_batch),
rhs_batching_dimensions=list(rhs_batch),
lhs_contracting_dimensions=list(lhs_contracting),
rhs_contracting_dimensions=list(rhs_contracting))
if algorithm is not None and precision not in {
None, Precision.DEFAULT, (Precision.DEFAULT, Precision.DEFAULT)}:
raise ValueError(
"The dot_general precision must be None or DEFAULT when an algorithm "
"is specified.")
if jaxlib_version <= (0, 4, 33):
if algorithm is not None:
algorithm_kwarg = {}
if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
# The CPU backend silently ignores the algorithm spec, so we check here to
# make sure that the selected algorithm is supported. We could be a little
# bit more liberal here (any algorithm where the input and output types
# match and all the other parameters have default values should work), but
# it's probably sufficient to just check the presets here.
if platform == "cpu" and precision not in {
DotAlgorithmPreset.DEFAULT, DotAlgorithmPreset.F16_F16_F16,
DotAlgorithmPreset.F32_F32_F32, DotAlgorithmPreset.F64_F64_F64,
}:
raise ValueError(
"The dot_general algorithm parameter is only supported for jaxlib "
"versions larger than 0.4.33.")
algorithm_kwargs = {}
f"The precision '{precision}' is not supported by dot_general on CPU")
# If an explicit algorithm was specified, we always cast the input types to
# the correct types.
def maybe_convert_dtype(operand, operand_aval, target_dtype):
if target_dtype is None:
return operand, operand_aval.dtype
if not isinstance(target_dtype, tuple):
target_dtype = (target_dtype,)
if any(operand_aval.dtype == d for d in target_dtype):
return operand, operand_aval.dtype
aval = core.ShapedArray(operand_aval.shape, target_dtype[0])
return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0]
lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type)
rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type)
accumulation_type = precision.accumulation_type
if accumulation_type is not None:
accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type)
if precision != DotAlgorithmPreset.DEFAULT:
algorithm_kwarg = {
"algorithm": dot_algorithm_attr(precision, lhs_dtype, rhs_dtype)
}
else:
algorithm_kwargs = {"algorithm": dot_algorithm_attr(algorithm, lhs_dtype,
rhs_dtype)}
return [
hlo.dot_general(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,
dot_dnums,
precision_config=precision_attr(precision),
**algorithm_kwargs,
)
]
# TODO(b/...): JAX's dot_general primitive accepts the same input dtype
# combinations that are accepted in XLA's shape_inference.cc (the canonical
# reference for the HLO type system), but actually different XLA platforms
# fail on codegen for different accepted cases. To handle those cases, we
# insert ConvertOps on the input, in a platform-dependent way.
if lhs_dtype != rhs_dtype:
if platform == "tpu":
handled = lambda dt: (dtypes.issubdtype(dt, np.floating) or
dtypes.issubdtype(dt, np.integer))
if not (handled(lhs_dtype) and handled(rhs_dtype)):
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
lhs_dtype = rhs_dtype = aval_out.dtype
else: # cpu and gpu
# Do not convert mixed fp8 types to output type.
if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
lhs_dtype = rhs_dtype = aval_out.dtype
result = hlo.dot_general(
mlir.aval_to_ir_type(accumulation_aval),
lhs,
rhs,
dot_dnums,
precision_config=precision_attr(precision),
**algorithm_kwarg,
)
if accumulation_aval.dtype != aval_out.dtype:
result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
return [result]
mlir.register_lowering(dot_general_p, _dot_general_lower)
@ -3556,8 +3624,7 @@ def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
# defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision, preferred_element_type=preferred_element_type,
algorithm=None, transpose_algorithm=None)
precision=precision, preferred_element_type=preferred_element_type)
def _ragged_dot_jvp_rule(
@ -3680,12 +3747,7 @@ def _ragged_dot_invoke_prim(
new_dimension_numbers,
precision,
preferred_element_type,
algorithm,
transpose_algorithm,
):
assert algorithm is None
assert transpose_algorithm is None
return ragged_dot(
lhs,
rhs,
@ -5779,7 +5841,7 @@ def remaining(original, *removed_lists):
return [i for i in original if i not in removed]
def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None:
def canonicalize_precision(precision: PrecisionLike) -> CanonicalPrecision:
"""Turns an API precision specification into a pair of enumeration values.
The API can take the precision as a string, or int, and either as a single
@ -5789,56 +5851,44 @@ def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precisi
if config.default_matmul_precision.value is None:
return None
try:
return (
Precision(config.default_matmul_precision.value),
Precision(config.default_matmul_precision.value),
)
except TypeError:
return canonicalize_precision(config.default_matmul_precision.value)
except ValueError:
raise ValueError(
"jax_default_matmul_precision flag must be set to None or a value in "
f"{list(_precision_strings)}, but got {config.default_matmul_precision.value}"
"jax_default_matmul_precision flag must be set to None, a value in "
f"{list(_precision_strings)}, or the name of a lax.DotAlgorithmPreset, "
f"but got {config.default_matmul_precision.value}"
) from None
elif isinstance(precision, str) and precision in _precision_strings:
return Precision(precision), Precision(precision)
elif isinstance(precision, str):
if precision in _precision_strings:
return Precision(precision), Precision(precision)
else:
try:
return DotAlgorithmPreset[precision]
except KeyError:
pass
elif isinstance(precision, Precision):
return precision, precision
elif isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
return precision
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(p, Precision) for p in precision)):
return type_cast(tuple[Precision, Precision], precision)
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(s, str) for s in precision)):
s1, s2 = precision
s1, s2 = type_cast(tuple[str, str], precision)
p1 = type_cast(tuple[Precision, Precision], canonicalize_precision(s1))[0]
p2 = type_cast(tuple[Precision, Precision], canonicalize_precision(s2))[0]
return (p1, p2)
else:
raise ValueError(
f"Precision argument must be None, a string in {list(_precision_strings)}, "
"a lax.Precision value or a tuple of two lax.Precision values or "
f"strings; got {precision}.")
raise ValueError(
"Precision argument must be one of:\n"
"- None,\n"
f"- a string in {list(_precision_strings)},\n"
"- a lax.Precision value,\n"
"- a tuple of two lax.Precision values or strings,\n"
"- a lax.DotAlgorithmPreset or the name of one of these presets, or\n"
"- a lax.DotAlgorithm value;\n"
f"but got {precision}.")
def canonicalize_dot_algorithm(algorithm: DotAlgorithmLike) -> _DotAlgorithmLike:
if isinstance(algorithm, str):
algorithm = DotAlgorithm.Preset[algorithm]
if algorithm is None or algorithm == DotAlgorithm.Preset.DEFAULT:
return None
return algorithm
def canonicalize_dot_transpose_algorithm(
algorithm: DotTransposeAlgorithmLike) -> DotTransposeAlgorithm | None:
if algorithm is None:
return None
elif isinstance(algorithm, DotAlgorithm):
return (algorithm, algorithm)
elif isinstance(algorithm, tuple):
if len(algorithm) != 2:
raise ValueError(
"The transpose_algorithm argument must be a single value or a tuple "
f"of two values; got {algorithm}.")
return (canonicalize_dot_algorithm(algorithm[0]),
canonicalize_dot_algorithm(algorithm[1]))
algorithm = canonicalize_dot_algorithm(algorithm)
return (algorithm, algorithm)
def _balanced_eq(x, z, y):
return div(select(_eq_meet(x, z), _ones(z), _zeros(z)),


@ -205,11 +205,13 @@ def conv_general_dilated_local(
lhs_array = lax.asarray(lhs)
c_precision = lax.canonicalize_precision(precision)
lhs_precision = (
c_precision[0]
if (isinstance(c_precision, tuple) and len(c_precision) == 2)
else c_precision
)
if c_precision is None:
lhs_precision = None
elif isinstance(c_precision, tuple) and len(c_precision) == 2:
lhs_precision = c_precision[0]
else:
raise ValueError(
f"Unsupported precision for conv_general_dilated_local: {precision}")
patches = conv_general_dilated_patches(
lhs=lhs_array,


@ -2090,10 +2090,8 @@ def _dot_general_lowering(
dimension_numbers,
precision,
preferred_element_type,
algorithm,
transpose_algorithm,
):
del preferred_element_type, algorithm, transpose_algorithm # Unused.
del preferred_element_type # Unused.
((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
assert batch_dims == ((), ())


@ -364,15 +364,12 @@ tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
precision: tuple[PrecisionType, PrecisionType] | None,
preferred_element_type: DType | None,
algorithm: Any, transpose_algorithm: Any,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
# Unused arguments.
del precision
del preferred_element_type
del algorithm
del transpose_algorithm
lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype(
lhs, _in_avals[0], rhs, _in_avals[1], _out_aval)


@ -2183,12 +2183,9 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
precision: tuple[PrecisionType, PrecisionType] | None,
preferred_element_type: DType | None,
algorithm: Any, transpose_algorithm: Any,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
del algorithm, transpose_algorithm # unused
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
# failures. Since we are going to turn on native serialization soon, wait
# until then to turn on this check.


@ -84,7 +84,7 @@ TODO:
"""
from functools import partial
import math
from typing import Any
from typing import cast, Any
import jax
import numpy as np
@ -228,10 +228,10 @@ def _lstm_cudnn_allow_tf32(precision: lax.PrecisionLike) -> bool:
#
# but we prefer to still invoke it here for consistency
precision = lax.canonicalize_precision(precision)
if precision is None:
if precision is None or not (isinstance(precision, tuple) and len(precision) == 2):
return True
# cuDNN allows only one precision specifier per RNN op
precision, _ = precision
precision, _ = cast(tuple[lax.Precision, lax.Precision], precision)
if precision == lax.Precision.HIGHEST:
return False
elif precision == lax.Precision.HIGH:

View File

@ -609,8 +609,7 @@ mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers,
precision: None = None, preferred_element_type: None = None,
algorithm: None = None, transpose_algorithm: None = None) -> BCOO | Array:
precision: None = None, preferred_element_type: None = None) -> BCOO | Array:
"""A general contraction operation.
Args:
@ -621,8 +620,6 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
(lhs_batch_dims, rhs_batch_dims))`.
precision: unused
preferred_element_type: unused
algorithm: unused
transpose_algorithm: unused
Returns:
An ndarray or BCOO-format sparse array containing the result. If both inputs
@ -630,7 +627,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
the result will be dense, of type ndarray.
"""
# TODO(jakevdp) make use of these?
del precision, algorithm, transpose_algorithm # unused
del precision # unused
if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
dimension_numbers)
@ -1056,9 +1053,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
kwds = {'dimension_numbers': dimension_numbers,
'precision': None,
'preferred_element_type': None,
'algorithm': None,
'transpose_algorithm': None}
'preferred_element_type': None}
A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
return A, B, indices


@ -463,9 +463,7 @@ bcsr_dot_general_p = core.Primitive('bcsr_dot_general')
def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
dimension_numbers: DotDimensionNumbers,
precision: None = None,
preferred_element_type: None = None,
algorithm: None = None,
transpose_algorithm: None = None) -> Array:
preferred_element_type: None = None) -> Array:
"""A general contraction operation.
Args:
@ -476,15 +474,13 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
(lhs_batch_dims, rhs_batch_dims))`.
precision: unused
preferred_element_type: unused
algorithm: unused
transpose_algorithm: unused
Returns:
An ndarray or BCSR-format sparse array containing the result. If both inputs
are sparse, the result will be sparse, of type BCSR. If either input is
dense, the result will be dense, of type ndarray.
"""
del precision, algorithm, transpose_algorithm # unused
del precision # unused
if isinstance(rhs, (np.ndarray, jax.Array)):
if isinstance(lhs, (np.ndarray, jax.Array)):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,


@ -113,5 +113,4 @@ def _dot_general_validated_shape(
rhs = core.ShapedArray(rhs_shape, np.float32)
return _dot_general_shape_rule(
lhs, rhs, dimension_numbers=dimension_numbers,
precision=None, preferred_element_type=None, algorithm=None,
transpose_algorithm=None)
precision=None, preferred_element_type=None)


@ -20,8 +20,7 @@ from jax._src.lax.lax import (
Precision as Precision,
PrecisionLike as PrecisionLike,
DotAlgorithm as DotAlgorithm,
DotAlgorithmLike as DotAlgorithmLike,
DotTransposeAlgorithmLike as DotTransposeAlgorithmLike,
DotAlgorithmPreset as DotAlgorithmPreset,
RandomAlgorithm as RandomAlgorithm,
RoundingMethod as RoundingMethod,
abs as abs,


@ -1061,19 +1061,19 @@ class LaxTest(jtu.JaxTestCase):
accumulation_type=np.float32,
), [np.float16]),
("F16_F16_F32", [np.float16]),
(lax.DotAlgorithm.Preset.DEFAULT, lax_test_util.float_dtypes),
(lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes),
(lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes),
(lax.DotAlgorithm.Preset.F16_F16_F16, [np.float16]),
(lax.DotAlgorithm.Preset.F16_F16_F32, [np.float16]),
(lax.DotAlgorithm.Preset.BF16_BF16_BF16, [dtypes.bfloat16]),
(lax.DotAlgorithm.Preset.BF16_BF16_F32, [dtypes.bfloat16]),
(lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, [np.float32]),
(lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, [np.float32]),
(lax.DotAlgorithm.Preset.TF32_TF32_F32, [np.float32]),
(lax.DotAlgorithm.Preset.TF32_TF32_F32_X3, [np.float32]),
(lax.DotAlgorithm.Preset.F32_F32_F32, [np.float32]),
(lax.DotAlgorithm.Preset.F64_F64_F64, [np.float64]),
(lax.DotAlgorithmPreset.DEFAULT, lax_test_util.float_dtypes),
(lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes),
(lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes),
(lax.DotAlgorithmPreset.F16_F16_F16, [np.float16]),
(lax.DotAlgorithmPreset.F16_F16_F32, [np.float16]),
(lax.DotAlgorithmPreset.BF16_BF16_BF16, [dtypes.bfloat16]),
(lax.DotAlgorithmPreset.BF16_BF16_F32, [dtypes.bfloat16]),
(lax.DotAlgorithmPreset.BF16_BF16_F32_X3, [np.float32]),
(lax.DotAlgorithmPreset.BF16_BF16_F32_X6, [np.float32]),
(lax.DotAlgorithmPreset.TF32_TF32_F32, [np.float32]),
(lax.DotAlgorithmPreset.TF32_TF32_F32_X3, [np.float32]),
(lax.DotAlgorithmPreset.F32_F32_F32, [np.float32]),
(lax.DotAlgorithmPreset.F64_F64_F64, [np.float64]),
] for dtype in test_dtypes
if jtu.dtypes.supported([dtype])
])
@ -1084,26 +1084,35 @@ class LaxTest(jtu.JaxTestCase):
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
if jtu.test_device_matches(["cpu"]):
if algorithm not in {
lax.DotAlgorithmPreset.DEFAULT,
lax.DotAlgorithmPreset.F16_F16_F16,
lax.DotAlgorithmPreset.F32_F32_F32,
lax.DotAlgorithmPreset.F64_F64_F64,
}:
raise SkipTest(
f"The dot algorithm '{algorithm}' is not supported on CPU.")
if jtu.test_device_matches(["gpu"]):
# GPU algorithm support is a little spotty. It is checked in
# xla/service/algorithm_util.cc and the logic is copied here.
if algorithm in {
lax.DotAlgorithm.Preset.F16_F16_F32,
lax.DotAlgorithm.Preset.TF32_TF32_F32,
lax.DotAlgorithm.Preset.BF16_BF16_F32,
lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, # Must have f32 input
lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, # Must have f32 input
lax.DotAlgorithmPreset.F16_F16_F32,
lax.DotAlgorithmPreset.TF32_TF32_F32,
lax.DotAlgorithmPreset.BF16_BF16_F32,
lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
}:
if not jtu.is_cuda_compute_capability_at_least("8.0"):
raise SkipTest(
f"The dot algorithm '{algorithm}' requires CUDA compute "
"capability >= 8.0.")
elif algorithm not in {
lax.DotAlgorithm.Preset.DEFAULT,
lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
lax.DotAlgorithm.Preset.F32_F32_F32,
lax.DotAlgorithm.Preset.F64_F64_F64,
lax.DotAlgorithmPreset.DEFAULT,
lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32,
lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
lax.DotAlgorithmPreset.F32_F32_F32,
lax.DotAlgorithmPreset.F64_F64_F64,
}:
raise SkipTest(
f"The dot algorithm '{algorithm}' is not supported on GPU.")
@ -1111,12 +1120,8 @@ class LaxTest(jtu.JaxTestCase):
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CompileAndCheck(partial(lax.dot, algorithm=algorithm), args_maker)
# Check that accumulation type sets the output type
output = lax.dot(*args_maker(), algorithm=algorithm)
algorithm = lax_internal.canonicalize_dot_algorithm(algorithm)
expected_dtype = dtype if algorithm is None else algorithm.accumulation_type
self.assertEqual(output.dtype, expected_dtype)
self._CompileAndCheck(partial(lax.dot, precision=algorithm), args_maker)
self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype)
def testDotAlgorithmInvalidFloat8Type(self):
if xla_bridge.using_pjrt_c_api():
@ -1125,95 +1130,29 @@ class LaxTest(jtu.JaxTestCase):
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
if jtu.test_device_matches(["cpu"]):
raise SkipTest("Not supported on CPU.")
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn)
with self.assertRaisesRegex(ValueError, "The dot algorithm"):
lax.dot(lhs, rhs, algorithm="ANY_F8_ANY_F8_F32")
lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32")
@parameterized.parameters([
({"precision": lax.Precision.HIGHEST}, "The dot_general precision must be None or DEFAULT"),
({"preferred_element_type": np.float32}, "The preferred_element_type and algorithm arguments"),
])
def testDotAlgorithmInvalidParameters(self, kwargs, pattern):
def testDotAlgorithmCasting(self):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
def fun(lhs, rhs):
return lax.dot(lhs, rhs, precision="F32_F32_F32")
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
with self.assertRaisesRegex(ValueError, pattern):
lax.dot(lhs, rhs, algorithm="F32_F32_F32", **kwargs)
def testDotAlgorithmTransposeRequired(self):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
fun = partial(lax.dot, algorithm="F32_F32_F32")
out = fun(lhs, rhs)
_, vjp_fun = jax.vjp(fun, lhs, rhs)
with self.assertRaisesRegex(
ValueError, "When a dot_general algorithm is specified"):
vjp_fun(out)
@parameterized.parameters([
("F32_F32_F32", "F16_F16_F32"),
("F32_F32_F32", ("F16_F16_F32", "F64_F64_F64")),
])
def testDotAlgorithmTranspose(self, algorithm, transpose_algorithm):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
def fun(x, y):
return lax.dot(x, y, algorithm=algorithm,
transpose_algorithm=transpose_algorithm)
algorithm_ = lax_internal.canonicalize_dot_algorithm(algorithm)
lhs_alg, rhs_alg = lax_internal.canonicalize_dot_transpose_algorithm(
transpose_algorithm)
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
out = fun(lhs, rhs)
def check_transpose_algorithm(f, arg, alg, trans_alg, trans_trans_alg):
fun_trans = jax.linear_transpose(f, arg)
jaxpr = jax.make_jaxpr(fun_trans)(out)
eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr.eqns))
self.assertEqual(eqn.params["algorithm"], alg)
self.assertEqual(eqn.params["transpose_algorithm"], trans_alg)
fun_ = jax.linear_transpose(lambda x: fun_trans(x)[0], out)
jaxpr_ = jax.make_jaxpr(fun_)(arg)
eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr_.eqns))
self.assertEqual(eqn.params["algorithm"], algorithm_)
# Note that transposing the RHS of a dot_general introduce extra
# transposes on the input and output, so we don't actually end up with
# the same `transpose_algorithm` parameter after 2 transposes.
self.assertEqual(eqn.params["transpose_algorithm"], trans_trans_alg)
check_transpose_algorithm(partial(fun, y=rhs), lhs, lhs_alg,
(algorithm_, rhs_alg), (lhs_alg, rhs_alg))
check_transpose_algorithm(partial(fun, lhs), rhs, rhs_alg,
(algorithm_, lhs_alg), (rhs_alg, lhs_alg))
lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
self.assertEqual(fun(lhs, rhs).dtype, np.float16)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)