Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 12:26:07 +00:00
Add support for setting a dot product "algorithm" for lax.dot_general.
The StableHLO spec has a new "algorithm" parameter that allows specifying the algorithm used to execute a matrix multiplication, which can tune the trade-off between performance and computational cost. Historically, the `precision` and `preferred_element_type` parameters have been used in JAX to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases.

This change adds a new `algorithm` parameter to `dot_general` to support the new explicit API. The parameter can be a member of the `DotAlgorithm.Preset` enum to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure, which exposes the full generality of the StableHLO spec. Transposition is supported using the `transpose_algorithm` argument.

PiperOrigin-RevId: 678672686
This commit is contained in: parent eff00cc449, commit bc1e1a0220
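For orientation, a minimal sketch of how the new argument can be used (a hypothetical snippet, assuming a backend that supports the `F16_F16_F32` preset; the shapes are arbitrary):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((3, 4), dtype=jnp.float16)
rhs = jnp.ones((4, 3), dtype=jnp.float16)

# The algorithm can be a DotAlgorithm.Preset member, a preset name string, or
# a fully specified lax.DotAlgorithm (see the diff below).
out = lax.dot_general(lhs, rhs, (((1,), (0,)), ((), ())),
                      algorithm=lax.DotAlgorithm.Preset.F16_F16_F32)
print(out.dtype)  # float32: the result takes the algorithm's accumulation type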
@@ -249,6 +249,7 @@ Argument classes

.. autoclass:: ConvDimensionNumbers
.. autoclass:: ConvGeneralDilatedDimensionNumbers
.. autoclass:: DotAlgorithm
.. autoclass:: GatherDimensionNumbers
.. autoclass:: GatherScatterMode
.. autoclass:: Precision

@@ -22,7 +22,7 @@ from functools import partial
import itertools
import math
import operator
from typing import Any, TypeVar, Union, cast as type_cast, overload
from typing import Any, NamedTuple, TypeVar, Union, cast as type_cast, overload
import warnings

import numpy as np
@@ -709,8 +709,197 @@ PrecisionLike = Union[
    None,
]


class DotAlgorithm(NamedTuple):
  """Specify the algorithm used for computing dot products.

  When used as input to :func:`~jax.lax.dot_general`, this data structure is
  used for controlling the properties of the algorithm used for computing the
  dot product. This API controls the precision used for the computation, and
  allows users to access hardware-specific accelerations.

  Support for these algorithms is platform dependent, and using an unsupported
  algorithm will raise a Python exception when the computation is compiled. The
  algorithms that are known to be supported on at least some platforms are
  listed in the :class:`~jax.lax.DotAlgorithm.Preset` enum, and these are a
  good starting point for experimenting with this API.

  A "dot algorithm" is specified by the following parameters:

  * ``lhs_precision_type`` and ``rhs_precision_type``, the data types that the
    LHS and RHS of the operation are rounded to.
  * ``accumulation_type`` the data type used for accumulation.
  * ``lhs_component_count``, ``rhs_component_count``, and
    ``num_primitive_operations`` apply to algorithms that decompose the LHS
    and/or RHS into multiple components and execute multiple operations on
    those values, usually to emulate a higher precision. For algorithms with no
    decomposition, these values should be set to ``1``.
  * ``allow_imprecise_accumulation`` to specify if accumulation in lower
    precision is permitted for some steps (e.g.
    ``CUBLASLT_MATMUL_DESC_FAST_ACCUM``).

  The `StableHLO spec <https://openxla.org/stablehlo/spec#dot_general>`_ for
  the dot operation doesn't require that the precision types be the same as the
  storage types for the inputs or outputs, but some plaforms may require that
  these types match. Furthermore, the return type of
  :func:`~jax.lax.dot_general` is always defined by the ``accumulation_type``
  parameter of the input algorithm, if specified.

  Examples:

    Accumulate two 16-bit floats using a 32-bit float accumulator:

    >>> algorithm = DotAlgorithm(
    ...     lhs_precision_type=np.float16,
    ...     rhs_precision_type=np.float16,
    ...     accumulation_type=np.float32,
    ... )
    >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
    >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
    >>> dot(lhs, rhs, algorithm=algorithm)  # doctest: +SKIP
    array([ 1., 4., 9., 16.], dtype=float32)

    Or, equivalently, using a preset:

    >>> algorithm = DotAlgorithm.Preset.F16_F16_F32
    >>> dot(lhs, rhs, algorithm=algorithm)  # doctest: +SKIP
    array([ 1., 4., 9., 16.], dtype=float32)
  """

  lhs_precision_type: DTypeLike
  rhs_precision_type: DTypeLike
  accumulation_type: DTypeLike
  lhs_component_count: int = 1
  rhs_component_count: int = 1
  num_primitive_operations: int = 1
  allow_imprecise_accumulation: bool = False

  def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
                           rhs_dtype: DTypeLike) -> hlo.DotAlgorithm:
    del lhs_dtype, rhs_dtype  # unused
    return hlo.DotAlgorithm.get(
        mlir.dtype_to_ir_type(dtypes.dtype(self.lhs_precision_type)),
        mlir.dtype_to_ir_type(dtypes.dtype(self.rhs_precision_type)),
        mlir.dtype_to_ir_type(dtypes.dtype(self.accumulation_type)),
        self.lhs_component_count,
        self.rhs_component_count,
        self.num_primitive_operations,
        self.allow_imprecise_accumulation,
    )

  # mypy doesn't currently support nested classes in a NamedTuple definition.
  class Preset(enum.Enum):  # type: ignore[misc]
    DEFAULT = 0
    ANY_F8_ANY_F8_F32 = 1
    ANY_F8_ANY_F8_F32_FAST_ACCUM = 2
    F16_F16_F16 = 3
    F16_F16_F32 = 4
    BF16_BF16_BF16 = 5
    BF16_BF16_F32 = 6
    BF16_BF16_F32_X3 = 7
    BF16_BF16_F32_X6 = 8
    TF32_TF32_F32 = 9
    TF32_TF32_F32_X3 = 10
    F32_F32_F32 = 11
    F64_F64_F64 = 12

    def __repr__(self) -> str:
      return f'{self.__class__.__name__}.{self.name}'

    def __str__(self) -> str:
      return self.name

    @property
    def accumulation_type(self) -> DTypeLike:
      match self:
        case DotAlgorithm.Preset.DEFAULT:
          raise TypeError(
              "The default dot algorithm does not have an accumulation type.")
        case DotAlgorithm.Preset.F16_F16_F16:
          return np.float16
        case DotAlgorithm.Preset.BF16_BF16_BF16:
          return dtypes.bfloat16
        case DotAlgorithm.Preset.F64_F64_F64:
          return np.float64
        case _:
          return np.float32

    def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
                             rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
      if self == DotAlgorithm.Preset.DEFAULT:
        return None

      if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
                  DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM):
        fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
                      np.dtype(dtypes.float8_e4m3fn),
                      np.dtype(dtypes.float8_e4m3fnuz),
                      np.dtype(dtypes.float8_e5m2),
                      np.dtype(dtypes.float8_e5m2fnuz))
        if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
          raise ValueError(
              f"The dot algorithm '{self}' requires both inputs to have float8 "
              f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead.")
        lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype))
        rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype))
        acc = ir.F32Type.get()
        return hlo.DotAlgorithm.get(
            lhs, rhs, acc, 1, 1, 1,
            self == DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM)

      else:
        f16 = ir.F16Type.get()
        f32 = ir.F32Type.get()
        f64 = ir.F64Type.get()
        bf16 = ir.BF16Type.get()
        tf32 = ir.FloatTF32Type.get()
        match self:
          case DotAlgorithm.Preset.F16_F16_F16:
            return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False)
          case DotAlgorithm.Preset.F16_F16_F32:
            return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False)
          case DotAlgorithm.Preset.BF16_BF16_BF16:
            return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False)
          case DotAlgorithm.Preset.BF16_BF16_F32:
            return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False)
          case DotAlgorithm.Preset.BF16_BF16_F32_X3:
            return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False)
          case DotAlgorithm.Preset.BF16_BF16_F32_X6:
            return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
          case DotAlgorithm.Preset.TF32_TF32_F32:
            return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
          case DotAlgorithm.Preset.TF32_TF32_F32_X3:
            return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False)
          case DotAlgorithm.Preset.F32_F32_F32:
            return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False)
          case DotAlgorithm.Preset.F64_F64_F64:
            return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False)
          case _:
            raise NotImplementedError("unreachable")


DotAlgorithmLike = Union[
    DotAlgorithm,
    DotAlgorithm.Preset,
    str,
    None,
]
_DotAlgorithmLike = Union[
    DotAlgorithm,
    DotAlgorithm.Preset,
    None,
]
DotTransposeAlgorithmLike = Union[
    DotAlgorithmLike,
    tuple[DotAlgorithmLike, DotAlgorithmLike],
]
DotTransposeAlgorithm = tuple[_DotAlgorithmLike, _DotAlgorithmLike]


def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
        preferred_element_type: DTypeLike | None = None) -> Array:
        preferred_element_type: DTypeLike | None = None,
        algorithm: DotAlgorithmLike = None,
        transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array:
  """Vector/vector, matrix/vector, and matrix/matrix multiplication.

  Wraps XLA's `Dot
@@ -729,6 +918,17 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.
    algorithm: Optional. Specify the algorithm used for accumulating the dot
      product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument
      cannot be used with ``precision`` or ``preferred_element_type``.
    transpose_algorithm: Optional. This allows specifying the algorithm used when
      this operation is transposed, typically as part of reverse-mode automatic
      differentiation. This argument can either be a single
      :class:`~jax.lax.DotAlgorithm` or a tuple of two
      :class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the
      algorithm for transposing the LHS and RHS, respectively.
      ``transpose_algorithm`` must be explicitly specified when transposing a
      dot product where a specific ``algorithm`` was used on the forward pass.

  Returns:
    An array containing the product.
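To illustrate the two `transpose_algorithm` forms described above, a hedged sketch (assuming float32 inputs and a platform where the `F32_F32_F32` preset is supported; preset names may also be passed as strings):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 4), dtype=jnp.float32)
y = jnp.ones((4, 3), dtype=jnp.float32)

# A single value is used for both the LHS and RHS transposes...
z1 = lax.dot(x, y, algorithm="F32_F32_F32", transpose_algorithm="F32_F32_F32")

# ...or an explicit (LHS, RHS) pair can be given.
z2 = lax.dot(x, y, algorithm="F32_F32_F32",
             transpose_algorithm=("F32_F32_F32", "F32_F32_F32"))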
@@ -736,7 +936,9 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
  if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]):
    return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
                       precision=precision,
                       preferred_element_type=preferred_element_type)
                       preferred_element_type=preferred_element_type,
                       algorithm=algorithm,
                       transpose_algorithm=transpose_algorithm)
  else:
    raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        lhs.shape, rhs.shape))

@@ -747,7 +949,9 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],

def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                precision: PrecisionLike = None,
                preferred_element_type: DTypeLike | None = None) -> Array:
                preferred_element_type: DTypeLike | None = None,
                algorithm: DotAlgorithmLike = None,
                transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array:
  """General dot product/contraction operator.

  Wraps XLA's `DotGeneral

@@ -774,6 +978,17 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.
    algorithm: Optional. Specify the algorithm used for accumulating the dot
      product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument
      cannot be used with ``precision`` or ``preferred_element_type``.
    transpose_algorithm: Optional. This allows specifying the algorithm used when
      this operation is transposed, typically as part of reverse-mode automatic
      differentiation. This argument can either be a single
      :class:`~jax.lax.DotAlgorithm` or a tuple of two
      :class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the
      algorithm for transposing the LHS and RHS, respectively.
      ``transpose_algorithm`` must be explicitly specified when transposing a
      dot product where a specific ``algorithm`` was used on the forward pass.

  Returns:
    An array whose first dimensions are the (shared) batch dimensions, followed by

@@ -791,7 +1006,9 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
  return dot_general_p.bind(lhs, rhs,
                            dimension_numbers=(cdims, bdims),
                            precision=canonicalize_precision(precision),
                            preferred_element_type=preferred_element_type)
                            preferred_element_type=preferred_element_type,
                            algorithm=canonicalize_dot_algorithm(algorithm),
                            transpose_algorithm=canonicalize_dot_transpose_algorithm(transpose_algorithm))


def ragged_dot(

@@ -2838,7 +3055,9 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):


def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: DTypeLike | None):
                            preferred_element_type: DTypeLike | None,
                            algorithm: _DotAlgorithmLike = None,
                            transpose_algorithm: DotTransposeAlgorithm | None = None):
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
             for d in (lhs_contracting, lhs_batch)):

@@ -2914,7 +3133,10 @@ def tuple_delete(tup, idx):


def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: DTypeLike | None):
                            preferred_element_type: DTypeLike | None,
                            algorithm: _DotAlgorithmLike = None,
                            transpose_algorithm: DotTransposeAlgorithm | None = None):
  del dimension_numbers, precision  # unused
  # We're mostly matching XLA's logic here, namely in shape_inference.cc and
  # primitive_util.h's HigherPrecisionType, e.g.
  # https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329

@@ -2936,6 +3158,21 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
        f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
    result_dtype = lhs.dtype

  if transpose_algorithm is not None and algorithm is None:
    raise ValueError(
        "When the algorithm argument to dot_general is None, the "
        "transpose_algorithm argument is unused and must also be None.")

  if algorithm is not None and algorithm != DotAlgorithm.Preset.DEFAULT:
    if preferred_element_type is not None:
      raise ValueError(
          "The preferred_element_type and algorithm arguments to dot_general "
          "cannot both be specified.")

    # This is used to ensure that the output type is equal to the accumulation
    # type whenever an algorithm is specified.
    preferred_element_type = algorithm.accumulation_type

  return _maybe_upcast(result_dtype, preferred_element_type)

def _bit_width(d):
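A short sketch of the two behaviours this dtype rule encodes (hypothetical shapes; assumes a platform supporting the `F16_F16_F32` preset):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 4), dtype=jnp.float16)
y = jnp.ones((4, 3), dtype=jnp.float16)

# The algorithm's accumulation type determines the result dtype...
assert lax.dot(x, y, algorithm="F16_F16_F32").dtype == jnp.float32

# ...so combining it with preferred_element_type is rejected:
# lax.dot(x, y, algorithm="F16_F16_F32", preferred_element_type=jnp.float32)
# raises ValueError: "The preferred_element_type and algorithm arguments to
# dot_general cannot both be specified."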
@@ -2959,6 +3196,8 @@ def _maybe_upcast(result_dtype, preferred_element_type):

def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                               preferred_element_type: DTypeLike | None,
                               algorithm: _DotAlgorithmLike = None,
                               transpose_algorithm: DotTransposeAlgorithm | None = None,
                               swap_ans=False):
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  x_ndim = x.aval.ndim

@@ -2971,20 +3210,35 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
  dims = ((ans_y, y_kept), (ans_batch, y_batch))
  x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
  out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
  if algorithm is not None:
    if transpose_algorithm is None or transpose_algorithm[0] is None:
      raise ValueError(
          "When a dot_general algorithm is specified on the forward pass, "
          "transpose_algorithm must be specified for the backward pass.")
    lhs_alg, rhs_alg = transpose_algorithm
    transpose_algorithm = (algorithm, rhs_alg)
    algorithm = lhs_alg
  x_bar = transpose(dot_general(g, y, dims, precision=precision,
                                preferred_element_type=preferred_element_type),
                                preferred_element_type=preferred_element_type,
                                algorithm=algorithm,
                                transpose_algorithm=transpose_algorithm),
                    tuple(out_axes))
  if x_bar.dtype != x.aval.dtype:
    x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  return x_bar

def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
                               preferred_element_type: DTypeLike | None):
                               preferred_element_type: DTypeLike | None,
                               algorithm: _DotAlgorithmLike = None,
                               transpose_algorithm: DotTransposeAlgorithm | None = None):
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
  transpose_algorithm = None if transpose_algorithm is None else (
      transpose_algorithm[1], transpose_algorithm[0])
  y_bar = _dot_general_transpose_lhs(
      g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
      preferred_element_type=preferred_element_type,
      preferred_element_type=preferred_element_type, algorithm=algorithm,
      transpose_algorithm=transpose_algorithm,
      swap_ans=True)
  if y_bar.dtype != y.aval.dtype:
    y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
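To make the requirement in the transpose rule above concrete, a sketch of what a caller sees under reverse-mode AD (hypothetical shapes; assumes `F32_F32_F32` support):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 4), dtype=jnp.float32)
y = jnp.ones((4, 3), dtype=jnp.float32)

f = lambda a, b: lax.dot(a, b, algorithm="F32_F32_F32").sum()
# jax.grad(f)(x, y) fails with "When a dot_general algorithm is specified on
# the forward pass, transpose_algorithm must be specified for the backward pass."

g = lambda a, b: lax.dot(a, b, algorithm="F32_F32_F32",
                         transpose_algorithm=("F32_F32_F32", "F32_F32_F32")).sum()
dx, dy = jax.grad(g, argnums=(0, 1))(x, y)  # OK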
@@ -2992,7 +3246,9 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,

def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
                            precision,
                            preferred_element_type: DTypeLike | None):
                            preferred_element_type: DTypeLike | None,
                            algorithm: _DotAlgorithmLike = None,
                            transpose_algorithm: DotTransposeAlgorithm | None = None):
  lhs, rhs = batched_args
  lbd, rbd = batch_dims
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

@@ -3018,7 +3274,9 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
    rhs_shape = np.shape(rhs)
  batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                            precision=precision,
                            preferred_element_type=preferred_element_type)
                            preferred_element_type=preferred_element_type,
                            algorithm=algorithm,
                            transpose_algorithm=transpose_algorithm)
  result_batch_dim = batching.shape_as_bdim(
      result_stack_dim,
      _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))

@@ -3115,9 +3373,19 @@ def precision_attr(precision: Precision) -> ir.ArrayAttr:
      [hlo.PrecisionAttr.get(str(p)) for p in full_precision])


def dot_algorithm_attr(algorithm: _DotAlgorithmLike, lhs_dtype: DTypeLike,
                       rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
  if algorithm is None:
    return None
  return algorithm._convert_to_hlo_attr(lhs_dtype, rhs_dtype)


def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                       precision, preferred_element_type: np.dtype | None,
                       algorithm: _DotAlgorithmLike = None,
                       transpose_algorithm: DotTransposeAlgorithm | None = None,
                       platform: str = "default"):
  del transpose_algorithm  # unused
  def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
    fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                  dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)

@@ -3158,13 +3426,30 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
      rhs_batching_dimensions=list(rhs_batch),
      lhs_contracting_dimensions=list(lhs_contracting),
      rhs_contracting_dimensions=list(rhs_contracting))

  if algorithm is not None and precision not in {
      None, Precision.DEFAULT, (Precision.DEFAULT, Precision.DEFAULT)}:
    raise ValueError(
        "The dot_general precision must be None or DEFAULT when an algorithm "
        "is specified.")
  if jaxlib_version <= (0, 4, 33):
    if algorithm is not None:
      raise ValueError(
          "The dot_general algorithm parameter is only supported for jaxlib "
          "versions larger than 0.4.33.")
    algorithm_kwargs = {}
  else:
    algorithm_kwargs = {"algorithm": dot_algorithm_attr(algorithm, lhs_dtype,
                                                        rhs_dtype)}
  return [
      hlo.dot_general(
          mlir.aval_to_ir_type(aval_out),
          lhs,
          rhs,
          dot_dnums,
          precision_config=precision_attr(precision))
          precision_config=precision_attr(precision),
          **algorithm_kwargs,
      )
  ]

mlir.register_lowering(dot_general_p, _dot_general_lower)

@@ -3189,11 +3474,13 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], []))

def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
  if not dtypes.issubdtype(group_sizes.dtype, np.integer):
    raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
  # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type)
  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
                                 precision=precision, preferred_element_type=preferred_element_type,
                                 algorithm=None, transpose_algorithm=None)


def _ragged_dot_jvp_rule(

@@ -5387,6 +5674,29 @@ def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precisi
      "a lax.Precision value or a tuple of two lax.Precision values or "
      f"strings; got {precision}.")

def canonicalize_dot_algorithm(algorithm: DotAlgorithmLike) -> _DotAlgorithmLike:
  if isinstance(algorithm, str):
    algorithm = DotAlgorithm.Preset[algorithm]
  if algorithm is None or algorithm == DotAlgorithm.Preset.DEFAULT:
    return None
  return algorithm

def canonicalize_dot_transpose_algorithm(
    algorithm: DotTransposeAlgorithmLike) -> DotTransposeAlgorithm | None:
  if algorithm is None:
    return None
  elif isinstance(algorithm, DotAlgorithm):
    return (algorithm, algorithm)
  elif isinstance(algorithm, tuple):
    if len(algorithm) != 2:
      raise ValueError(
          "The transpose_algorithm argument must be a single value or a tuple "
          f"of two values; got {algorithm}.")
    return (canonicalize_dot_algorithm(algorithm[0]),
            canonicalize_dot_algorithm(algorithm[1]))
  algorithm = canonicalize_dot_algorithm(algorithm)
  return (algorithm, algorithm)

def _balanced_eq(x, z, y):
  return div(select(_eq_meet(x, z), _ones(z), _zeros(z)),
             select(_eq_meet(y, z), _twos(z), _ones(z)))
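A sketch of how the canonicalization helpers above behave (they live in the internal module `jax._src.lax.lax`; the tests in this change access them as `lax_internal`):

from jax import lax
from jax._src.lax import lax as lax_internal

# Strings are looked up in DotAlgorithm.Preset; DEFAULT canonicalizes to None.
assert (lax_internal.canonicalize_dot_algorithm("F16_F16_F32")
        is lax.DotAlgorithm.Preset.F16_F16_F32)
assert lax_internal.canonicalize_dot_algorithm(lax.DotAlgorithm.Preset.DEFAULT) is None

# A single transpose algorithm is broadcast to an (LHS, RHS) pair.
assert lax_internal.canonicalize_dot_transpose_algorithm("F32_F32_F32") == (
    lax.DotAlgorithm.Preset.F32_F32_F32, lax.DotAlgorithm.Preset.F32_F32_F32)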
@@ -2094,8 +2094,10 @@ def _dot_general_lowering(
    dimension_numbers,
    precision,
    preferred_element_type,
    algorithm,
    transpose_algorithm,
):
  del preferred_element_type  # Unused.
  del preferred_element_type, algorithm, transpose_algorithm  # Unused.
  ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
  assert batch_dims == ((), ())

@@ -364,12 +364,15 @@ tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
                 precision: tuple[PrecisionType, PrecisionType] | None,
                 preferred_element_type: DType | None,
                 algorithm: Any, transpose_algorithm: Any,
                 _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray):
  """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
  # Unused arguments.
  del precision
  del preferred_element_type
  del algorithm
  del transpose_algorithm

  lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype(
      lhs, _in_avals[0], rhs, _in_avals[1], _out_aval)

@@ -2176,9 +2176,12 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, *, dimension_numbers,
                 precision: tuple[PrecisionType, PrecisionType] | None,
                 preferred_element_type: DType | None,
                 algorithm: Any, transpose_algorithm: Any,
                 _in_avals: Sequence[core.ShapedArray],
                 _out_aval: core.ShapedArray):
  """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
  del algorithm, transpose_algorithm  # unused

  # TODO(b/293247337): we ought to turn on this safety check, but this leads to
  # failures. Since we are going to turn on native serializaton soon, wait
  # until then to turn on this check.

@@ -609,7 +609,8 @@ mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')

def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers,
                     precision: None = None, preferred_element_type: None = None) -> BCOO | Array:
                     precision: None = None, preferred_element_type: None = None,
                     algorithm: None = None, transpose_algorithm: None = None) -> BCOO | Array:
  """A general contraction operation.

  Args:

@@ -620,6 +621,8 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
      (lhs_batch_dims, rhs_batch_dims))`.
    precision: unused
    preferred_element_type: unused
    algorithm: unused
    transpose_algorithm: unused

  Returns:
    An ndarray or BCOO-format sparse array containing the result. If both inputs

@@ -627,7 +630,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
    the result will be dense, of type ndarray.
  """
  # TODO(jakevdp) make use of these?
  del precision  # unused
  del precision, algorithm, transpose_algorithm  # unused
  if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
    shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                         dimension_numbers)

@@ -1053,7 +1056,9 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
    indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
    kwds = {'dimension_numbers': dimension_numbers,
            'precision': None,
            'preferred_element_type': None}
            'preferred_element_type': None,
            'algorithm': None,
            'transpose_algorithm': None}
    A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
  return A, B, indices

@@ -463,7 +463,9 @@ bcsr_dot_general_p = core.Primitive('bcsr_dot_general')
def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
                     dimension_numbers: DotDimensionNumbers,
                     precision: None = None,
                     preferred_element_type: None = None) -> Array:
                     preferred_element_type: None = None,
                     algorithm: None = None,
                     transpose_algorithm: None = None) -> Array:
  """A general contraction operation.

  Args:

@@ -474,13 +476,15 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
      (lhs_batch_dims, rhs_batch_dims))`.
    precision: unused
    preferred_element_type: unused
    algorithm: unused
    transpose_algorithm: unused

  Returns:
    An ndarray or BCSR-format sparse array containing the result. If both inputs
    are sparse, the result will be sparse, of type BCSR. If either input is
    dense, the result will be dense, of type ndarray.
  """
  del precision  # unused
  del precision, algorithm, transpose_algorithm  # unused
  if isinstance(rhs, (np.ndarray, jax.Array)):
    if isinstance(lhs, (np.ndarray, jax.Array)):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,

@@ -113,4 +113,5 @@ def _dot_general_validated_shape(
  rhs = core.ShapedArray(rhs_shape, np.float32)
  return _dot_general_shape_rule(
      lhs, rhs, dimension_numbers=dimension_numbers,
      precision=None, preferred_element_type=None)
      precision=None, preferred_element_type=None, algorithm=None,
      transpose_algorithm=None)

@@ -19,6 +19,9 @@ from jax._src.lax.lax import (
  DotDimensionNumbers as DotDimensionNumbers,
  Precision as Precision,
  PrecisionLike as PrecisionLike,
  DotAlgorithm as DotAlgorithm,
  DotAlgorithmLike as DotAlgorithmLike,
  DotTransposeAlgorithmLike as DotTransposeAlgorithmLike,
  RandomAlgorithm as RandomAlgorithm,
  RoundingMethod as RoundingMethod,
  abs as abs,

@@ -41,10 +41,12 @@ from jax._src import config
from jax._src import dtypes
from jax._src import lax_reference
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.lib import version as jaxlib_version
from jax._src.util import NumpyComplexWarning, safe_zip
from jax._src.tree_util import tree_map

@@ -1041,6 +1043,178 @@ class LaxTest(jtu.JaxTestCase):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker)

  @parameterized.parameters([
      (algorithm, dtype)
      for algorithm, test_dtypes in [
          (lax.DotAlgorithm(
              lhs_precision_type=np.float32,
              rhs_precision_type=np.float32,
              accumulation_type=np.float32,
              lhs_component_count=1,
              rhs_component_count=1,
              num_primitive_operations=1,
              allow_imprecise_accumulation=False,
          ), [np.float32]),
          (lax.DotAlgorithm(
              lhs_precision_type=np.float16,
              rhs_precision_type=np.float16,
              accumulation_type=np.float32,
          ), [np.float16]),
          ("F16_F16_F32", [np.float16]),
          (lax.DotAlgorithm.Preset.DEFAULT, lax_test_util.float_dtypes),
          (lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes),
          (lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes),
          (lax.DotAlgorithm.Preset.F16_F16_F16, [np.float16]),
          (lax.DotAlgorithm.Preset.F16_F16_F32, [np.float16]),
          (lax.DotAlgorithm.Preset.BF16_BF16_BF16, [dtypes.bfloat16]),
          (lax.DotAlgorithm.Preset.BF16_BF16_F32, [dtypes.bfloat16]),
          (lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, [np.float32]),
          (lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, [np.float32]),
          (lax.DotAlgorithm.Preset.TF32_TF32_F32, [np.float32]),
          (lax.DotAlgorithm.Preset.TF32_TF32_F32_X3, [np.float32]),
          (lax.DotAlgorithm.Preset.F32_F32_F32, [np.float32]),
          (lax.DotAlgorithm.Preset.F64_F64_F64, [np.float64]),
      ] for dtype in test_dtypes
      if jtu.dtypes.supported([dtype])
  ])
  def testDotAlgorithm(self, algorithm, dtype):
    if xla_bridge.using_pjrt_c_api():
      raise SkipTest(
          "The dot algorithm attribute is not supported by PJRT C API.")
    if jaxlib_version <= (0, 4, 33):
      raise SkipTest(
          "The dot algorithm attribute is only supported for jaxlib >0.4.33.")
    if jtu.test_device_matches(["gpu"]):
      # GPU algorithm support is a little spotty. It is checked in
      # xla/service/algorithm_util.cc and the logic is copied here.
      if algorithm in {
          lax.DotAlgorithm.Preset.F16_F16_F32,
          lax.DotAlgorithm.Preset.TF32_TF32_F32,
          lax.DotAlgorithm.Preset.BF16_BF16_F32,
          lax.DotAlgorithm.Preset.BF16_BF16_F32_X3,  # Must have f32 input
          lax.DotAlgorithm.Preset.BF16_BF16_F32_X6,  # Must have f32 input
      }:
        if not jtu.is_cuda_compute_capability_at_least("8.0"):
          raise SkipTest(
              f"The dot algorithm '{algorithm}' requires CUDA compute "
              "capability >= 8.0.")
      elif algorithm not in {
          lax.DotAlgorithm.Preset.DEFAULT,
          lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
          lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
          lax.DotAlgorithm.Preset.F32_F32_F32,
          lax.DotAlgorithm.Preset.F64_F64_F64,
      }:
        raise SkipTest(
            f"The dot algorithm '{algorithm}' is not supported on GPU.")
    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    self._CompileAndCheck(partial(lax.dot, algorithm=algorithm), args_maker)
    # Check that accumulation type sets the output type
    output = lax.dot(*args_maker(), algorithm=algorithm)
    algorithm = lax_internal.canonicalize_dot_algorithm(algorithm)
    expected_dtype = dtype if algorithm is None else algorithm.accumulation_type
    self.assertEqual(output.dtype, expected_dtype)

  def testDotAlgorithmInvalidFloat8Type(self):
    if xla_bridge.using_pjrt_c_api():
      raise SkipTest(
          "The dot algorithm attribute is not supported by PJRT C API.")
    if jaxlib_version <= (0, 4, 33):
      raise SkipTest(
          "The dot algorithm attribute is only supported for jaxlib >0.4.33.")
    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn)
    with self.assertRaisesRegex(ValueError, "The dot algorithm"):
      lax.dot(lhs, rhs, algorithm="ANY_F8_ANY_F8_F32")

  @parameterized.parameters([
      ({"precision": lax.Precision.HIGHEST}, "The dot_general precision must be None or DEFAULT"),
      ({"preferred_element_type": np.float32}, "The preferred_element_type and algorithm arguments"),
  ])
  def testDotAlgorithmInvalidParameters(self, kwargs, pattern):
    if xla_bridge.using_pjrt_c_api():
      raise SkipTest(
          "The dot algorithm attribute is not supported by PJRT C API.")
    if jaxlib_version <= (0, 4, 33):
      raise SkipTest(
          "The dot algorithm attribute is only supported for jaxlib >0.4.33.")
    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
    with self.assertRaisesRegex(ValueError, pattern):
      lax.dot(lhs, rhs, algorithm="F32_F32_F32", **kwargs)

  def testDotAlgorithmTransposeRequired(self):
    if xla_bridge.using_pjrt_c_api():
      raise SkipTest(
          "The dot algorithm attribute is not supported by PJRT C API.")
    if jaxlib_version <= (0, 4, 33):
      raise SkipTest(
          "The dot algorithm attribute is only supported for jaxlib >0.4.33.")
    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
    fun = partial(lax.dot, algorithm="F32_F32_F32")
    out = fun(lhs, rhs)
    _, vjp_fun = jax.vjp(fun, lhs, rhs)
    with self.assertRaisesRegex(
        ValueError, "When a dot_general algorithm is specified"):
      vjp_fun(out)

  @parameterized.parameters([
      ("F32_F32_F32", "F16_F16_F32"),
      ("F32_F32_F32", ("F16_F16_F32", "F64_F64_F64")),
  ])
  def testDotAlgorithmTranspose(self, algorithm, transpose_algorithm):
    if xla_bridge.using_pjrt_c_api():
      raise SkipTest(
          "The dot algorithm attribute is not supported by PJRT C API.")
    if jaxlib_version <= (0, 4, 33):
      raise SkipTest(
          "The dot algorithm attribute is only supported for jaxlib >0.4.33.")
    def fun(x, y):
      return lax.dot(x, y, algorithm=algorithm,
                     transpose_algorithm=transpose_algorithm)

    algorithm_ = lax_internal.canonicalize_dot_algorithm(algorithm)
    lhs_alg, rhs_alg = lax_internal.canonicalize_dot_transpose_algorithm(
        transpose_algorithm)

    lhs_shape = (3, 4)
    rhs_shape = (4, 3)
    rng = jtu.rand_default(self.rng())
    lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
    out = fun(lhs, rhs)

    def check_transpose_algorithm(f, arg, alg, trans_alg, trans_trans_alg):
      fun_trans = jax.linear_transpose(f, arg)
      jaxpr = jax.make_jaxpr(fun_trans)(out)
      eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr.eqns))
      self.assertEqual(eqn.params["algorithm"], alg)
      self.assertEqual(eqn.params["transpose_algorithm"], trans_alg)

      fun_ = jax.linear_transpose(lambda x: fun_trans(x)[0], out)
      jaxpr_ = jax.make_jaxpr(fun_)(arg)
      eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr_.eqns))
      self.assertEqual(eqn.params["algorithm"], algorithm_)

      # Note that transposing the RHS of a dot_general introduce extra
      # transposes on the input and output, so we don't actually end up with
      # the same `transpose_algorithm` parameter after 2 transposes.
      self.assertEqual(eqn.params["transpose_algorithm"], trans_trans_alg)

    check_transpose_algorithm(partial(fun, y=rhs), lhs, lhs_alg,
                              (algorithm_, rhs_alg), (lhs_alg, rhs_alg))
    check_transpose_algorithm(partial(fun, lhs), rhs, rhs_alg,
                              (algorithm_, lhs_alg), (rhs_alg, lhs_alg))

  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
       for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]],