Add support for setting a dot product "algorithm" for lax.dot_general.

The StableHLO spec has a new "algorithm" parameter that allows specifying the algorithm used to execute a matrix multiplication, tuning the trade-off between performance and numerical accuracy. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to support this new explicit API.

This parameter can be a member of the `DotAlgorithm.Preset` enum (or its string name) to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure, which exposes the full generality of the StableHLO spec.
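
For example, here is a minimal sketch of both spellings (illustrative only; the shapes and values are made up, and it assumes a backend that supports the requested algorithm and a jaxlib newer than 0.4.33):

    import numpy as np
    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((4, 8), dtype=np.float16)
    rhs = jnp.ones((8, 4), dtype=np.float16)

    # A preset algorithm, as an enum member or, equivalently, by name:
    out = lax.dot(lhs, rhs, algorithm=lax.DotAlgorithm.Preset.F16_F16_F32)
    out = lax.dot(lhs, rhs, algorithm="F16_F16_F32")

    # The fully general form, via the DotAlgorithm data structure:
    out = lax.dot(lhs, rhs, algorithm=lax.DotAlgorithm(
        lhs_precision_type=np.float16,
        rhs_precision_type=np.float16,
        accumulation_type=np.float32,
    ))
    assert out.dtype == np.float32  # the output type follows accumulation_type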

Transposition is supported using the `transpose_algorithm` argument.
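
When an explicit algorithm was used on the forward pass, reverse-mode differentiation requires `transpose_algorithm`, which can be a single value applied to both operands or a (lhs, rhs) tuple. A sketch, under the same assumptions as above:

    import jax
    import numpy as np
    import jax.numpy as jnp
    from jax import lax

    def f(x, y):
      # Without transpose_algorithm, jax.grad of this function would raise a
      # ValueError when the dot is transposed on the backward pass.
      return lax.dot(x, y, algorithm="F32_F32_F32",
                     transpose_algorithm="F32_F32_F32").sum()

    x = jnp.ones((4, 8), dtype=np.float32)
    y = jnp.ones((8, 4), dtype=np.float32)
    dx, dy = jax.grad(f, argnums=(0, 1))(x, y)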

PiperOrigin-RevId: 678672686
Author: Dan Foreman-Mackey, 2024-09-25 06:16:22 -07:00 (committed by jax authors)
Commit: bc1e1a0220, parent: eff00cc449
10 changed files with 528 additions and 22 deletions


@@ -249,6 +249,7 @@ Argument classes
.. autoclass:: ConvDimensionNumbers
.. autoclass:: ConvGeneralDilatedDimensionNumbers
.. autoclass:: DotAlgorithm
.. autoclass:: GatherDimensionNumbers
.. autoclass:: GatherScatterMode
.. autoclass:: Precision


@@ -22,7 +22,7 @@ from functools import partial
import itertools
import math
import operator
from typing import Any, TypeVar, Union, cast as type_cast, overload
from typing import Any, NamedTuple, TypeVar, Union, cast as type_cast, overload
import warnings
import numpy as np
@@ -709,8 +709,197 @@ PrecisionLike = Union[
None,
]
class DotAlgorithm(NamedTuple):
"""Specify the algorithm used for computing dot products.
When used as input to :func:`~jax.lax.dot_general`, this data structure is
used for controlling the properties of the algorithm used for computing the
dot product. This API controls the precision used for the computation, and
allows users to access hardware-specific accelerations.
Support for these algorithms is platform dependent, and using an unsupported
algorithm will raise a Python exception when the computation is compiled. The
algorithms that are known to be supported on at least some platforms are
listed in the :class:`~jax.lax.DotAlgorithm.Preset` enum, and these are a
good starting point for experimenting with this API.
A "dot algorithm" is specified by the following parameters:
* ``lhs_precision_type`` and ``rhs_precision_type``, the data types that the
LHS and RHS of the operation are rounded to.
* ``accumulation_type``, the data type used for accumulation.
* ``lhs_component_count``, ``rhs_component_count``, and
``num_primitive_operations`` apply to algorithms that decompose the LHS
and/or RHS into multiple components and execute multiple operations on
those values, usually to emulate a higher precision. For algorithms with no
decomposition, these values should be set to ``1``.
* ``allow_imprecise_accumulation`` to specify if accumulation in lower
precision is permitted for some steps (e.g.
``CUBLASLT_MATMUL_DESC_FAST_ACCUM``).
The `StableHLO spec <https://openxla.org/stablehlo/spec#dot_general>`_ for
the dot operation doesn't require that the precision types be the same as the
storage types for the inputs or outputs, but some platforms may require that
these types match. Furthermore, the return type of
:func:`~jax.lax.dot_general` is always defined by the ``accumulation_type``
parameter of the input algorithm, if specified.
Examples:
Accumulate two 16-bit floats using a 32-bit float accumulator:
>>> algorithm = DotAlgorithm(
... lhs_precision_type=np.float16,
... rhs_precision_type=np.float16,
... accumulation_type=np.float32,
... )
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP
Array(30., dtype=float32)
Or, equivalently, using a preset:
>>> algorithm = DotAlgorithm.Preset.F16_F16_F32
>>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP
Array(30., dtype=float32)
"""
lhs_precision_type: DTypeLike
rhs_precision_type: DTypeLike
accumulation_type: DTypeLike
lhs_component_count: int = 1
rhs_component_count: int = 1
num_primitive_operations: int = 1
allow_imprecise_accumulation: bool = False
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm:
del lhs_dtype, rhs_dtype # unused
return hlo.DotAlgorithm.get(
mlir.dtype_to_ir_type(dtypes.dtype(self.lhs_precision_type)),
mlir.dtype_to_ir_type(dtypes.dtype(self.rhs_precision_type)),
mlir.dtype_to_ir_type(dtypes.dtype(self.accumulation_type)),
self.lhs_component_count,
self.rhs_component_count,
self.num_primitive_operations,
self.allow_imprecise_accumulation,
)
# mypy doesn't currently support nested classes in a NamedTuple definition.
class Preset(enum.Enum): # type: ignore[misc]
DEFAULT = 0
ANY_F8_ANY_F8_F32 = 1
ANY_F8_ANY_F8_F32_FAST_ACCUM = 2
F16_F16_F16 = 3
F16_F16_F32 = 4
BF16_BF16_BF16 = 5
BF16_BF16_F32 = 6
BF16_BF16_F32_X3 = 7
BF16_BF16_F32_X6 = 8
TF32_TF32_F32 = 9
TF32_TF32_F32_X3 = 10
F32_F32_F32 = 11
F64_F64_F64 = 12
def __repr__(self) -> str:
return f'{self.__class__.__name__}.{self.name}'
def __str__(self) -> str:
return self.name
@property
def accumulation_type(self) -> DTypeLike:
match self:
case DotAlgorithm.Preset.DEFAULT:
raise TypeError(
"The default dot algorithm does not have an accumulation type.")
case DotAlgorithm.Preset.F16_F16_F16:
return np.float16
case DotAlgorithm.Preset.BF16_BF16_BF16:
return dtypes.bfloat16
case DotAlgorithm.Preset.F64_F64_F64:
return np.float64
case _:
return np.float32
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
if self == DotAlgorithm.Preset.DEFAULT:
return None
if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM):
fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz),
np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz))
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead.")
lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype))
rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype))
acc = ir.F32Type.get()
return hlo.DotAlgorithm.get(
lhs, rhs, acc, 1, 1, 1,
self == DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM)
else:
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
f64 = ir.F64Type.get()
bf16 = ir.BF16Type.get()
tf32 = ir.FloatTF32Type.get()
match self:
case DotAlgorithm.Preset.F16_F16_F16:
return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False)
case DotAlgorithm.Preset.F16_F16_F32:
return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False)
case DotAlgorithm.Preset.BF16_BF16_BF16:
return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False)
case DotAlgorithm.Preset.BF16_BF16_F32:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False)
case DotAlgorithm.Preset.BF16_BF16_F32_X3:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False)
case DotAlgorithm.Preset.BF16_BF16_F32_X6:
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
case DotAlgorithm.Preset.TF32_TF32_F32:
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
case DotAlgorithm.Preset.TF32_TF32_F32_X3:
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False)
case DotAlgorithm.Preset.F32_F32_F32:
return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False)
case DotAlgorithm.Preset.F64_F64_F64:
return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False)
case _:
raise NotImplementedError("unreachable")
DotAlgorithmLike = Union[
DotAlgorithm,
DotAlgorithm.Preset,
str,
None,
]
_DotAlgorithmLike = Union[
DotAlgorithm,
DotAlgorithm.Preset,
None,
]
DotTransposeAlgorithmLike = Union[
DotAlgorithmLike,
tuple[DotAlgorithmLike, DotAlgorithmLike],
]
DotTransposeAlgorithm = tuple[_DotAlgorithmLike, _DotAlgorithmLike]
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
preferred_element_type: DTypeLike | None = None,
algorithm: DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array:
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
Wraps XLA's `Dot
@@ -729,6 +918,17 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
algorithm: Optional. Specify the algorithm used for accumulating the dot
product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument
cannot be used with ``precision`` or ``preferred_element_type``.
transpose_algorithm: Optional. This allows specifying the algorithm used when
this operation is transposed, typically as part of reverse-mode automatic
differentiation. This argument can either be a single
:class:`~jax.lax.DotAlgorithm` or a tuple of two
:class:`~jax.lax.DotAlgorithm` objects, in which case the two elements define the
algorithm for transposing the LHS and RHS, respectively.
``transpose_algorithm`` must be explicitly specified when transposing a
dot product where a specific ``algorithm`` was used on the forward pass.
Returns:
An array containing the product.
@@ -736,7 +936,9 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]):
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
precision=precision,
preferred_element_type=preferred_element_type)
preferred_element_type=preferred_element_type,
algorithm=algorithm,
transpose_algorithm=transpose_algorithm)
else:
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
lhs.shape, rhs.shape))
@@ -747,7 +949,9 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
preferred_element_type: DTypeLike | None = None,
algorithm: DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array:
"""General dot product/contraction operator.
Wraps XLA's `DotGeneral
@@ -774,6 +978,17 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
algorithm: Optional. Specify the algorithm used for accumulating the dot
product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument
cannot be used with ``precision`` or ``preferred_element_type``.
transpose_algorithm: Optional. This allows specifying the algorithm used when
this operation is transposed, typically as part of reverse-mode automatic
differentiation. This argument can either be a single
:class:`~jax.lax.DotAlgorithm` or a tuple of two
:class:`~jax.lax.DotAlgorithm` objects, in which case the two elements define the
algorithm for transposing the LHS and RHS, respectively.
``transpose_algorithm`` must be explicitly specified when transposing a
dot product where a specific ``algorithm`` was used on the forward pass.
Returns:
An array whose first dimensions are the (shared) batch dimensions, followed by
@@ -791,7 +1006,9 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
return dot_general_p.bind(lhs, rhs,
dimension_numbers=(cdims, bdims),
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type)
preferred_element_type=preferred_element_type,
algorithm=canonicalize_dot_algorithm(algorithm),
transpose_algorithm=canonicalize_dot_transpose_algorithm(transpose_algorithm))
def ragged_dot(
@@ -2838,7 +3055,9 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None):
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
for d in (lhs_contracting, lhs_batch)):
@@ -2914,7 +3133,10 @@ def tuple_delete(tup, idx):
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None):
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None):
del dimension_numbers, precision # unused
# We're mostly matching XLA's logic here, namely in shape_inference.cc and
# primitive_util.h's HigherPrecisionType, e.g.
# https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329
@@ -2936,6 +3158,21 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
result_dtype = lhs.dtype
if transpose_algorithm is not None and algorithm is None:
raise ValueError(
"When the algorithm argument to dot_general is None, the "
"transpose_algorithm argument is unused and must also be None.")
if algorithm is not None and algorithm != DotAlgorithm.Preset.DEFAULT:
if preferred_element_type is not None:
raise ValueError(
"The preferred_element_type and algorithm arguments to dot_general "
"cannot both be specified.")
# This is used to ensure that the output type is equal to the accumulation
# type whenever an algorithm is specified.
preferred_element_type = algorithm.accumulation_type
return _maybe_upcast(result_dtype, preferred_element_type)
def _bit_width(d):
@@ -2959,6 +3196,8 @@ def _maybe_upcast(result_dtype, preferred_element_type):
def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None,
swap_ans=False):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
x_ndim = x.aval.ndim
@@ -2971,20 +3210,35 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
dims = ((ans_y, y_kept), (ans_batch, y_batch))
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
if algorithm is not None:
if transpose_algorithm is None or transpose_algorithm[0] is None:
raise ValueError(
"When a dot_general algorithm is specified on the forward pass, "
"transpose_algorithm must be specified for the backward pass.")
lhs_alg, rhs_alg = transpose_algorithm
transpose_algorithm = (algorithm, rhs_alg)
algorithm = lhs_alg
x_bar = transpose(dot_general(g, y, dims, precision=precision,
preferred_element_type=preferred_element_type),
preferred_element_type=preferred_element_type,
algorithm=algorithm,
transpose_algorithm=transpose_algorithm),
tuple(out_axes))
if x_bar.dtype != x.aval.dtype:
x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
return x_bar
def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None):
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
transpose_algorithm = None if transpose_algorithm is None else (
transpose_algorithm[1], transpose_algorithm[0])
y_bar = _dot_general_transpose_lhs(
g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
preferred_element_type=preferred_element_type,
preferred_element_type=preferred_element_type, algorithm=algorithm,
transpose_algorithm=transpose_algorithm,
swap_ans=True)
if y_bar.dtype != y.aval.dtype:
y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
@@ -2992,7 +3246,9 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
precision,
preferred_element_type: DTypeLike | None):
preferred_element_type: DTypeLike | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None):
lhs, rhs = batched_args
lbd, rbd = batch_dims
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
@@ -3018,7 +3274,9 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
rhs_shape = np.shape(rhs)
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type)
preferred_element_type=preferred_element_type,
algorithm=algorithm,
transpose_algorithm=transpose_algorithm)
result_batch_dim = batching.shape_as_bdim(
result_stack_dim,
_dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
@@ -3115,9 +3373,19 @@ def precision_attr(precision: Precision) -> ir.ArrayAttr:
[hlo.PrecisionAttr.get(str(p)) for p in full_precision])
def dot_algorithm_attr(algorithm: _DotAlgorithmLike, lhs_dtype: DTypeLike,
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
if algorithm is None:
return None
return algorithm._convert_to_hlo_attr(lhs_dtype, rhs_dtype)
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
precision, preferred_element_type: np.dtype | None,
algorithm: _DotAlgorithmLike = None,
transpose_algorithm: DotTransposeAlgorithm | None = None,
platform: str = "default"):
del transpose_algorithm # unused
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -3158,13 +3426,30 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
rhs_batching_dimensions=list(rhs_batch),
lhs_contracting_dimensions=list(lhs_contracting),
rhs_contracting_dimensions=list(rhs_contracting))
if algorithm is not None and precision not in {
None, Precision.DEFAULT, (Precision.DEFAULT, Precision.DEFAULT)}:
raise ValueError(
"The dot_general precision must be None or DEFAULT when an algorithm "
"is specified.")
if jaxlib_version <= (0, 4, 33):
if algorithm is not None:
raise ValueError(
"The dot_general algorithm parameter is only supported for jaxlib "
"versions larger than 0.4.33.")
algorithm_kwargs = {}
else:
algorithm_kwargs = {"algorithm": dot_algorithm_attr(algorithm, lhs_dtype,
rhs_dtype)}
return [
hlo.dot_general(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,
dot_dnums,
precision_config=precision_attr(precision))
precision_config=precision_attr(precision),
**algorithm_kwargs,
)
]
mlir.register_lowering(dot_general_p, _dot_general_lower)
@@ -3189,11 +3474,13 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], []))
def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
if not dtypes.issubdtype(group_sizes.dtype, np.integer):
raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
# defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type)
return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
precision=precision, preferred_element_type=preferred_element_type,
algorithm=None, transpose_algorithm=None)
def _ragged_dot_jvp_rule(
@@ -5387,6 +5674,29 @@ def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precisi
"a lax.Precision value or a tuple of two lax.Precision values or "
f"strings; got {precision}.")
def canonicalize_dot_algorithm(algorithm: DotAlgorithmLike) -> _DotAlgorithmLike:
if isinstance(algorithm, str):
algorithm = DotAlgorithm.Preset[algorithm]
if algorithm is None or algorithm == DotAlgorithm.Preset.DEFAULT:
return None
return algorithm
def canonicalize_dot_transpose_algorithm(
algorithm: DotTransposeAlgorithmLike) -> DotTransposeAlgorithm | None:
if algorithm is None:
return None
elif isinstance(algorithm, DotAlgorithm):
return (algorithm, algorithm)
elif isinstance(algorithm, tuple):
if len(algorithm) != 2:
raise ValueError(
"The transpose_algorithm argument must be a single value or a tuple "
f"of two values; got {algorithm}.")
return (canonicalize_dot_algorithm(algorithm[0]),
canonicalize_dot_algorithm(algorithm[1]))
algorithm = canonicalize_dot_algorithm(algorithm)
return (algorithm, algorithm)
def _balanced_eq(x, z, y):
return div(select(_eq_meet(x, z), _ones(z), _zeros(z)),
select(_eq_meet(y, z), _twos(z), _ones(z)))
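
As a quick illustration of the canonicalization rules above (a sketch exercising the internal helpers the same way the tests below do; not part of the change itself):

    from jax import lax
    from jax._src.lax import lax as lax_internal

    # Strings resolve to presets; DEFAULT and None both canonicalize to None:
    assert (lax_internal.canonicalize_dot_algorithm("F16_F16_F32")
            is lax.DotAlgorithm.Preset.F16_F16_F32)
    assert lax_internal.canonicalize_dot_algorithm("DEFAULT") is None
    assert lax_internal.canonicalize_dot_algorithm(None) is None

    # A single transpose algorithm is broadcast to both operands:
    assert lax_internal.canonicalize_dot_transpose_algorithm("TF32_TF32_F32") == (
        lax.DotAlgorithm.Preset.TF32_TF32_F32, lax.DotAlgorithm.Preset.TF32_TF32_F32)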


@@ -2094,8 +2094,10 @@ def _dot_general_lowering(
dimension_numbers,
precision,
preferred_element_type,
algorithm,
transpose_algorithm,
):
del preferred_element_type # Unused.
del preferred_element_type, algorithm, transpose_algorithm # Unused.
((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
assert batch_dims == ((), ())


@@ -364,12 +364,15 @@ tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
precision: tuple[PrecisionType, PrecisionType] | None,
preferred_element_type: DType | None,
algorithm: Any, transpose_algorithm: Any,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
# Unused arguments.
del precision
del preferred_element_type
del algorithm
del transpose_algorithm
lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype(
lhs, _in_avals[0], rhs, _in_avals[1], _out_aval)


@@ -2176,9 +2176,12 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
precision: tuple[PrecisionType, PrecisionType] | None,
preferred_element_type: DType | None,
algorithm: Any, transpose_algorithm: Any,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
del algorithm, transpose_algorithm # unused
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
# failures. Since we are going to turn on native serialization soon, wait
# until then to turn on this check.


@@ -609,7 +609,8 @@ mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers,
precision: None = None, preferred_element_type: None = None) -> BCOO | Array:
precision: None = None, preferred_element_type: None = None,
algorithm: None = None, transpose_algorithm: None = None) -> BCOO | Array:
"""A general contraction operation.
Args:
@@ -620,6 +621,8 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
(lhs_batch_dims, rhs_batch_dims))`.
precision: unused
preferred_element_type: unused
algorithm: unused
transpose_algorithm: unused
Returns:
An ndarray or BCOO-format sparse array containing the result. If both inputs
@@ -627,7 +630,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
the result will be dense, of type ndarray.
"""
# TODO(jakevdp) make use of these?
del precision # unused
del precision, algorithm, transpose_algorithm # unused
if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
dimension_numbers)
@@ -1053,7 +1056,9 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
kwds = {'dimension_numbers': dimension_numbers,
'precision': None,
'preferred_element_type': None}
'preferred_element_type': None,
'algorithm': None,
'transpose_algorithm': None}
A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
return A, B, indices


@@ -463,7 +463,9 @@ bcsr_dot_general_p = core.Primitive('bcsr_dot_general')
def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
dimension_numbers: DotDimensionNumbers,
precision: None = None,
preferred_element_type: None = None) -> Array:
preferred_element_type: None = None,
algorithm: None = None,
transpose_algorithm: None = None) -> Array:
"""A general contraction operation.
Args:
@@ -474,13 +476,15 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
(lhs_batch_dims, rhs_batch_dims))`.
precision: unused
preferred_element_type: unused
algorithm: unused
transpose_algorithm: unused
Returns:
An ndarray or BCSR-format sparse array containing the result. If both inputs
are sparse, the result will be sparse, of type BCSR. If either input is
dense, the result will be dense, of type ndarray.
"""
del precision # unused
del precision, algorithm, transpose_algorithm # unused
if isinstance(rhs, (np.ndarray, jax.Array)):
if isinstance(lhs, (np.ndarray, jax.Array)):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,


@@ -113,4 +113,5 @@ def _dot_general_validated_shape(
rhs = core.ShapedArray(rhs_shape, np.float32)
return _dot_general_shape_rule(
lhs, rhs, dimension_numbers=dimension_numbers,
precision=None, preferred_element_type=None)
precision=None, preferred_element_type=None, algorithm=None,
transpose_algorithm=None)


@@ -19,6 +19,9 @@ from jax._src.lax.lax import (
DotDimensionNumbers as DotDimensionNumbers,
Precision as Precision,
PrecisionLike as PrecisionLike,
DotAlgorithm as DotAlgorithm,
DotAlgorithmLike as DotAlgorithmLike,
DotTransposeAlgorithmLike as DotTransposeAlgorithmLike,
RandomAlgorithm as RandomAlgorithm,
RoundingMethod as RoundingMethod,
abs as abs,


@@ -41,10 +41,12 @@ from jax._src import config
from jax._src import dtypes
from jax._src import lax_reference
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.lib import version as jaxlib_version
from jax._src.util import NumpyComplexWarning, safe_zip
from jax._src.tree_util import tree_map
@@ -1041,6 +1043,178 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker)
@parameterized.parameters([
(algorithm, dtype)
for algorithm, test_dtypes in [
(lax.DotAlgorithm(
lhs_precision_type=np.float32,
rhs_precision_type=np.float32,
accumulation_type=np.float32,
lhs_component_count=1,
rhs_component_count=1,
num_primitive_operations=1,
allow_imprecise_accumulation=False,
), [np.float32]),
(lax.DotAlgorithm(
lhs_precision_type=np.float16,
rhs_precision_type=np.float16,
accumulation_type=np.float32,
), [np.float16]),
("F16_F16_F32", [np.float16]),
(lax.DotAlgorithm.Preset.DEFAULT, lax_test_util.float_dtypes),
(lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes),
(lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes),
(lax.DotAlgorithm.Preset.F16_F16_F16, [np.float16]),
(lax.DotAlgorithm.Preset.F16_F16_F32, [np.float16]),
(lax.DotAlgorithm.Preset.BF16_BF16_BF16, [dtypes.bfloat16]),
(lax.DotAlgorithm.Preset.BF16_BF16_F32, [dtypes.bfloat16]),
(lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, [np.float32]),
(lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, [np.float32]),
(lax.DotAlgorithm.Preset.TF32_TF32_F32, [np.float32]),
(lax.DotAlgorithm.Preset.TF32_TF32_F32_X3, [np.float32]),
(lax.DotAlgorithm.Preset.F32_F32_F32, [np.float32]),
(lax.DotAlgorithm.Preset.F64_F64_F64, [np.float64]),
] for dtype in test_dtypes
if jtu.dtypes.supported([dtype])
])
def testDotAlgorithm(self, algorithm, dtype):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
if jtu.test_device_matches(["gpu"]):
# GPU algorithm support is a little spotty. It is checked in
# xla/service/algorithm_util.cc and the logic is copied here.
if algorithm in {
lax.DotAlgorithm.Preset.F16_F16_F32,
lax.DotAlgorithm.Preset.TF32_TF32_F32,
lax.DotAlgorithm.Preset.BF16_BF16_F32,
lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, # Must have f32 input
lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, # Must have f32 input
}:
if not jtu.is_cuda_compute_capability_at_least("8.0"):
raise SkipTest(
f"The dot algorithm '{algorithm}' requires CUDA compute "
"capability >= 8.0.")
elif algorithm not in {
lax.DotAlgorithm.Preset.DEFAULT,
lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
lax.DotAlgorithm.Preset.F32_F32_F32,
lax.DotAlgorithm.Preset.F64_F64_F64,
}:
raise SkipTest(
f"The dot algorithm '{algorithm}' is not supported on GPU.")
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
self._CompileAndCheck(partial(lax.dot, algorithm=algorithm), args_maker)
# Check that accumulation type sets the output type
output = lax.dot(*args_maker(), algorithm=algorithm)
algorithm = lax_internal.canonicalize_dot_algorithm(algorithm)
expected_dtype = dtype if algorithm is None else algorithm.accumulation_type
self.assertEqual(output.dtype, expected_dtype)
def testDotAlgorithmInvalidFloat8Type(self):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn)
with self.assertRaisesRegex(ValueError, "The dot algorithm"):
lax.dot(lhs, rhs, algorithm="ANY_F8_ANY_F8_F32")
@parameterized.parameters([
({"precision": lax.Precision.HIGHEST}, "The dot_general precision must be None or DEFAULT"),
({"preferred_element_type": np.float32}, "The preferred_element_type and algorithm arguments"),
])
def testDotAlgorithmInvalidParameters(self, kwargs, pattern):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
with self.assertRaisesRegex(ValueError, pattern):
lax.dot(lhs, rhs, algorithm="F32_F32_F32", **kwargs)
def testDotAlgorithmTransposeRequired(self):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
fun = partial(lax.dot, algorithm="F32_F32_F32")
out = fun(lhs, rhs)
_, vjp_fun = jax.vjp(fun, lhs, rhs)
with self.assertRaisesRegex(
ValueError, "When a dot_general algorithm is specified"):
vjp_fun(out)
@parameterized.parameters([
("F32_F32_F32", "F16_F16_F32"),
("F32_F32_F32", ("F16_F16_F32", "F64_F64_F64")),
])
def testDotAlgorithmTranspose(self, algorithm, transpose_algorithm):
if xla_bridge.using_pjrt_c_api():
raise SkipTest(
"The dot algorithm attribute is not supported by PJRT C API.")
if jaxlib_version <= (0, 4, 33):
raise SkipTest(
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
def fun(x, y):
return lax.dot(x, y, algorithm=algorithm,
transpose_algorithm=transpose_algorithm)
algorithm_ = lax_internal.canonicalize_dot_algorithm(algorithm)
lhs_alg, rhs_alg = lax_internal.canonicalize_dot_transpose_algorithm(
transpose_algorithm)
lhs_shape = (3, 4)
rhs_shape = (4, 3)
rng = jtu.rand_default(self.rng())
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
out = fun(lhs, rhs)
def check_transpose_algorithm(f, arg, alg, trans_alg, trans_trans_alg):
fun_trans = jax.linear_transpose(f, arg)
jaxpr = jax.make_jaxpr(fun_trans)(out)
eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr.eqns))
self.assertEqual(eqn.params["algorithm"], alg)
self.assertEqual(eqn.params["transpose_algorithm"], trans_alg)
fun_ = jax.linear_transpose(lambda x: fun_trans(x)[0], out)
jaxpr_ = jax.make_jaxpr(fun_)(arg)
eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr_.eqns))
self.assertEqual(eqn.params["algorithm"], algorithm_)
# Note that transposing the RHS of a dot_general introduces extra
# transposes on the input and output, so we don't actually end up with
# the same `transpose_algorithm` parameter after 2 transposes.
self.assertEqual(eqn.params["transpose_algorithm"], trans_trans_alg)
check_transpose_algorithm(partial(fun, y=rhs), lhs, lhs_alg,
(algorithm_, rhs_alg), (lhs_alg, rhs_alg))
check_transpose_algorithm(partial(fun, lhs), rhs, rhs_alg,
(algorithm_, lhs_alg), (rhs_alg, lhs_alg))
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]],