mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases. The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested. Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected. To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.) With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`. Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this. One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach. PiperOrigin-RevId: 683302687
This commit is contained in:
parent
8473391467
commit
28bbbf894f
@ -250,6 +250,10 @@ Argument classes
|
||||
.. autoclass:: ConvDimensionNumbers
|
||||
.. autoclass:: ConvGeneralDilatedDimensionNumbers
|
||||
.. autoclass:: DotAlgorithm
|
||||
.. autoclass:: DotAlgorithmPreset
|
||||
:members:
|
||||
:undoc-members:
|
||||
:member-order: bysource
|
||||
.. autoclass:: GatherDimensionNumbers
|
||||
.. autoclass:: GatherScatterMode
|
||||
.. autoclass:: Precision
|
||||
|
@ -709,27 +709,19 @@ _precision_strings['fastest'] = Precision.DEFAULT
|
||||
_precision_strings[None] = Precision.DEFAULT
|
||||
|
||||
|
||||
PrecisionLike = Union[
|
||||
str,
|
||||
Precision,
|
||||
tuple[str, str],
|
||||
tuple[Precision, Precision],
|
||||
None,
|
||||
]
|
||||
|
||||
|
||||
class DotAlgorithm(NamedTuple):
|
||||
"""Specify the algorithm used for computing dot products.
|
||||
|
||||
When used as input to :func:`~jax.lax.dot_general`, this data structure is
|
||||
used for controlling the properties of the algorithm used for computing the
|
||||
dot product. This API controls the precision used for the computation, and
|
||||
allows users to access hardware-specific accelerations.
|
||||
When used to specify the ``precision`` input to :func:`~jax.lax.dot`,
|
||||
:func:`~jax.lax.dot_general`, and other dot product functions, this data
|
||||
structure is used for controlling the properties of the algorithm used for
|
||||
computing the dot product. This API controls the precision used for the
|
||||
computation, and allows users to access hardware-specific accelerations.
|
||||
|
||||
Support for these algorithms is platform dependent, and using an unsupported
|
||||
algorithm will raise a Python exception when the computation is compiled. The
|
||||
algorithms that are known to be supported on at least some platforms are
|
||||
listed in the :class:`~jax.lax.DotAlgorithm.Preset` enum, and these are a
|
||||
listed in the :class:`~jax.lax.DotAlgorithmPreset` enum, and these are a
|
||||
good starting point for experimenting with this API.
|
||||
|
||||
A "dot algorithm" is specified by the following parameters:
|
||||
@ -764,13 +756,24 @@ class DotAlgorithm(NamedTuple):
|
||||
... )
|
||||
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
||||
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
||||
>>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP
|
||||
array([ 1., 4., 9., 16.], dtype=float32)
|
||||
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
|
||||
array([ 1., 4., 9., 16.], dtype=float16)
|
||||
|
||||
Or, equivalently, using a preset:
|
||||
|
||||
>>> algorithm = DotAlgorithm.Preset.F16_F16_F32
|
||||
>>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP
|
||||
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
|
||||
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
|
||||
array([ 1., 4., 9., 16.], dtype=float16)
|
||||
|
||||
Presets can also be specified by name:
|
||||
|
||||
>>> dot(lhs, rhs, precision="F16_F16_F32") # doctest: +SKIP
|
||||
array([ 1., 4., 9., 16.], dtype=float16)
|
||||
|
||||
The ``preferred_element_type`` parameter can be used to return the output
|
||||
without downcasting the accumulation type:
|
||||
|
||||
>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) # doctest: +SKIP
|
||||
array([ 1., 4., 9., 16.], dtype=float32)
|
||||
"""
|
||||
|
||||
@ -795,50 +798,149 @@ class DotAlgorithm(NamedTuple):
|
||||
self.allow_imprecise_accumulation,
|
||||
)
|
||||
|
||||
# mypy doesn't currently support nested classes in a NamedTuple definition.
|
||||
class Preset(enum.Enum): # type: ignore[misc]
|
||||
DEFAULT = 0
|
||||
ANY_F8_ANY_F8_F32 = 1
|
||||
ANY_F8_ANY_F8_F32_FAST_ACCUM = 2
|
||||
F16_F16_F16 = 3
|
||||
F16_F16_F32 = 4
|
||||
BF16_BF16_BF16 = 5
|
||||
BF16_BF16_F32 = 6
|
||||
BF16_BF16_F32_X3 = 7
|
||||
BF16_BF16_F32_X6 = 8
|
||||
TF32_TF32_F32 = 9
|
||||
TF32_TF32_F32_X3 = 10
|
||||
F32_F32_F32 = 11
|
||||
F64_F64_F64 = 12
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}.{self.name}'
|
||||
class DotAlgorithmPreset(enum.Enum):
|
||||
"""An enum of known algorithms for computing dot products.
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
This ``Enum`` provides a named set of :class:`~jax.lax.DotAlgorithm` objects
|
||||
that are known to be supported on at least platform. See the
|
||||
:class:`~jax.lax.DotAlgorithm` documentation for more details about the
|
||||
behavior of these algorithms.
|
||||
|
||||
@property
|
||||
def accumulation_type(self) -> DTypeLike:
|
||||
match self:
|
||||
case DotAlgorithm.Preset.DEFAULT:
|
||||
raise TypeError(
|
||||
"The default dot algorithm does not have an accumulation type.")
|
||||
case DotAlgorithm.Preset.F16_F16_F16:
|
||||
return np.float16
|
||||
case DotAlgorithm.Preset.BF16_BF16_BF16:
|
||||
return dtypes.bfloat16
|
||||
case DotAlgorithm.Preset.F64_F64_F64:
|
||||
return np.float64
|
||||
case _:
|
||||
return np.float32
|
||||
An algorithm can be selected from this list when calling :func:`~jax.lax.dot`,
|
||||
:func:`~jax.lax.dot_general`, or most other JAX dot product functions, by
|
||||
passing either a member of this ``Enum`` or it's name as a string using the
|
||||
``precision`` argument.
|
||||
|
||||
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
|
||||
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
|
||||
if self == DotAlgorithm.Preset.DEFAULT:
|
||||
For example, users can specify the preset using this ``Enum`` directly:
|
||||
|
||||
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
||||
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
||||
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
|
||||
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
|
||||
array([ 1., 4., 9., 16.], dtype=float16)
|
||||
|
||||
or, equivalently, they can be specified by name:
|
||||
|
||||
>>> dot(lhs, rhs, precision="F16_F16_F32") # doctest: +SKIP
|
||||
array([ 1., 4., 9., 16.], dtype=float16)
|
||||
|
||||
The names of the presets are typically ``LHS_RHS_ACCUM`` where ``LHS`` and
|
||||
``RHS`` are the element types of the ``lhs`` and ``rhs`` inputs
|
||||
respectively, and ``ACCUM`` is the element type of the accumulator. Some
|
||||
presets have an extra suffix, and the meaning of each of these is
|
||||
documented below. The supported presets are:
|
||||
"""
|
||||
DEFAULT = enum.auto()
|
||||
"""An algorithm will be selected based on input and output types."""
|
||||
|
||||
ANY_F8_ANY_F8_F32 = enum.auto()
|
||||
"""Accepts any float8 input types and accumulates into float32."""
|
||||
|
||||
ANY_F8_ANY_F8_F32_FAST_ACCUM = enum.auto()
|
||||
"""Like ``ANY_F8_ANY_F8_F32``, but using faster accumulation with the cost
|
||||
of lower accuracy.
|
||||
"""
|
||||
|
||||
ANY_F8_ANY_F8_ANY = enum.auto()
|
||||
"""Like ``ANY_F8_ANY_F8_F32``, but the accumulation type is controlled by
|
||||
``preferred_element_type``.
|
||||
"""
|
||||
|
||||
ANY_F8_ANY_F8_ANY_FAST_ACCUM = enum.auto()
|
||||
"""Like ``ANY_F8_ANY_F8_F32_FAST_ACCUM``, but the accumulation type is
|
||||
controlled by ``preferred_element_type``.
|
||||
"""
|
||||
|
||||
F16_F16_F16 = enum.auto()
|
||||
F16_F16_F32 = enum.auto()
|
||||
BF16_BF16_BF16 = enum.auto()
|
||||
BF16_BF16_F32 = enum.auto()
|
||||
BF16_BF16_F32_X3 = enum.auto()
|
||||
"""The ``_X3`` suffix indicates that the algorithm uses 3 operations to
|
||||
emulate higher precision.
|
||||
"""
|
||||
|
||||
BF16_BF16_F32_X6 = enum.auto()
|
||||
"""Like ``BF16_BF16_F32_X3``, but using 6 operations instead of 3."""
|
||||
|
||||
TF32_TF32_F32 = enum.auto()
|
||||
TF32_TF32_F32_X3 = enum.auto()
|
||||
"""The ``_X3`` suffix indicates that the algorithm uses 3 operations to
|
||||
emulate higher precision.
|
||||
"""
|
||||
|
||||
F32_F32_F32 = enum.auto()
|
||||
F64_F64_F64 = enum.auto()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}.{self.name}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
|
||||
match self:
|
||||
case (
|
||||
DotAlgorithmPreset.DEFAULT |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
|
||||
):
|
||||
return None
|
||||
case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32:
|
||||
return np.float16
|
||||
case (
|
||||
DotAlgorithmPreset.BF16_BF16_BF16 |
|
||||
DotAlgorithmPreset.BF16_BF16_F32
|
||||
):
|
||||
# These algorithms support either f32 or bf32 input storage types.
|
||||
# If either of those types are provided as input, we use the provided
|
||||
# type. If not, we explicitly cast to bfloat16.
|
||||
return (dtypes.bfloat16, np.float32)
|
||||
case DotAlgorithmPreset.F64_F64_F64:
|
||||
return np.float64
|
||||
case _:
|
||||
return np.float32
|
||||
|
||||
if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
|
||||
DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM):
|
||||
@property
|
||||
def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
|
||||
return self.lhs_precision_type
|
||||
|
||||
@property
|
||||
def accumulation_type(self) -> DTypeLike | None:
|
||||
match self:
|
||||
case (
|
||||
DotAlgorithmPreset.DEFAULT |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
|
||||
):
|
||||
return None
|
||||
case DotAlgorithmPreset.F16_F16_F16:
|
||||
return np.float16
|
||||
case DotAlgorithmPreset.BF16_BF16_BF16:
|
||||
return dtypes.bfloat16
|
||||
case DotAlgorithmPreset.F64_F64_F64:
|
||||
return np.float64
|
||||
case _:
|
||||
return np.float32
|
||||
|
||||
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
|
||||
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
|
||||
f16 = ir.F16Type.get()
|
||||
f32 = ir.F32Type.get()
|
||||
f64 = ir.F64Type.get()
|
||||
bf16 = ir.BF16Type.get()
|
||||
tf32 = ir.FloatTF32Type.get()
|
||||
match self:
|
||||
case (
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
|
||||
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
|
||||
):
|
||||
fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
|
||||
np.dtype(dtypes.float8_e4m3fn),
|
||||
np.dtype(dtypes.float8_e4m3fnuz),
|
||||
@ -853,65 +955,53 @@ class DotAlgorithm(NamedTuple):
|
||||
acc = ir.F32Type.get()
|
||||
return hlo.DotAlgorithm.get(
|
||||
lhs, rhs, acc, 1, 1, 1,
|
||||
self == DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM)
|
||||
|
||||
else:
|
||||
f16 = ir.F16Type.get()
|
||||
f32 = ir.F32Type.get()
|
||||
f64 = ir.F64Type.get()
|
||||
bf16 = ir.BF16Type.get()
|
||||
tf32 = ir.FloatTF32Type.get()
|
||||
match self:
|
||||
case DotAlgorithm.Preset.F16_F16_F16:
|
||||
return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False)
|
||||
case DotAlgorithm.Preset.F16_F16_F32:
|
||||
return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False)
|
||||
case DotAlgorithm.Preset.BF16_BF16_BF16:
|
||||
return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False)
|
||||
case DotAlgorithm.Preset.BF16_BF16_F32:
|
||||
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False)
|
||||
case DotAlgorithm.Preset.BF16_BF16_F32_X3:
|
||||
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False)
|
||||
case DotAlgorithm.Preset.BF16_BF16_F32_X6:
|
||||
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
|
||||
case DotAlgorithm.Preset.TF32_TF32_F32:
|
||||
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
|
||||
case DotAlgorithm.Preset.TF32_TF32_F32_X3:
|
||||
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False)
|
||||
case DotAlgorithm.Preset.F32_F32_F32:
|
||||
return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False)
|
||||
case DotAlgorithm.Preset.F64_F64_F64:
|
||||
return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False)
|
||||
case _:
|
||||
raise NotImplementedError("unreachable")
|
||||
self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM)
|
||||
case DotAlgorithmPreset.F16_F16_F16:
|
||||
return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False)
|
||||
case DotAlgorithmPreset.F16_F16_F32:
|
||||
return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False)
|
||||
case DotAlgorithmPreset.BF16_BF16_BF16:
|
||||
return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False)
|
||||
case DotAlgorithmPreset.BF16_BF16_F32:
|
||||
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False)
|
||||
case DotAlgorithmPreset.BF16_BF16_F32_X3:
|
||||
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False)
|
||||
case DotAlgorithmPreset.BF16_BF16_F32_X6:
|
||||
return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
|
||||
case DotAlgorithmPreset.TF32_TF32_F32:
|
||||
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
|
||||
case DotAlgorithmPreset.TF32_TF32_F32_X3:
|
||||
return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False)
|
||||
case DotAlgorithmPreset.F32_F32_F32:
|
||||
return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False)
|
||||
case DotAlgorithmPreset.F64_F64_F64:
|
||||
return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False)
|
||||
case _:
|
||||
return None
|
||||
|
||||
|
||||
DotAlgorithmLike = Union[
|
||||
DotAlgorithm,
|
||||
DotAlgorithm.Preset,
|
||||
PrecisionLike = Union[
|
||||
None,
|
||||
str,
|
||||
None,
|
||||
]
|
||||
_DotAlgorithmLike = Union[
|
||||
Precision,
|
||||
tuple[str, str],
|
||||
tuple[Precision, Precision],
|
||||
DotAlgorithm,
|
||||
DotAlgorithm.Preset,
|
||||
DotAlgorithmPreset,
|
||||
]
|
||||
CanonicalPrecision = Union[
|
||||
None,
|
||||
tuple[Precision, Precision],
|
||||
DotAlgorithm,
|
||||
DotAlgorithmPreset,
|
||||
]
|
||||
DotTransposeAlgorithmLike = Union[
|
||||
DotAlgorithmLike,
|
||||
tuple[DotAlgorithmLike, DotAlgorithmLike],
|
||||
]
|
||||
DotTransposeAlgorithm = tuple[_DotAlgorithmLike, _DotAlgorithmLike]
|
||||
|
||||
|
||||
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
algorithm: DotAlgorithmLike = None,
|
||||
transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array:
|
||||
preferred_element_type: DTypeLike | None = None) -> Array:
|
||||
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
|
||||
|
||||
Wraps XLA's `Dot
|
||||
<https://www.tensorflow.org/xla/operation_semantics#dot>`_
|
||||
Wraps XLA's `Dot <https://www.tensorflow.org/xla/operation_semantics#dot>`_
|
||||
operator.
|
||||
|
||||
For more general contraction, see the :func:`jax.lax.dot_general` operator.
|
||||
@ -919,24 +1009,25 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
|
||||
Args:
|
||||
lhs: an array of dimension 1 or 2.
|
||||
rhs: an array of dimension 1 or 2.
|
||||
precision: Optional. Either ``None``, which means the default precision for
|
||||
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
|
||||
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
|
||||
preferred_element_type: Optional. Either ``None``, which means the default
|
||||
accumulation type for the input types, or a datatype, indicating to
|
||||
accumulate results to and return a result with that datatype.
|
||||
algorithm: Optional. Specify the algorithm used for accumulating the dot
|
||||
product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument
|
||||
cannot be used with ``precision`` or ``preferred_element_type``.
|
||||
transpose_algorithm: Optional. This allows specifying the algorithm used when
|
||||
this operation is transposed, typically as part of reverse-mode automatic
|
||||
differentiation. This argument can either be a single
|
||||
:class:`~jax.lax.DotAlgorithm` or a tuple of two
|
||||
:class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the
|
||||
algorithm for transposing the LHS and RHS, respectively.
|
||||
``transpose_algorithm`` must be explicitly specified when transposing a
|
||||
dot product where a specific ``algorithm`` was used on the forward pass.
|
||||
precision: Optional. This parameter controls the numerics of the
|
||||
computation, and it can be one of the following:
|
||||
|
||||
- ``None``, which means the default precision for the current backend,
|
||||
- a :class:`~jax.lax.Precision` enum value or a tuple of two
|
||||
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and
|
||||
``rhs``, or
|
||||
- a :class:`~jax.lax.DotAlgorithm` or a
|
||||
:class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that
|
||||
must be used to accumulate the dot product.
|
||||
|
||||
preferred_element_type: Optional. This parameter controls the data type
|
||||
output by the dot product. By default, the output element type of this
|
||||
operation will match the ``lhs`` and ``rhs`` input element types under
|
||||
the usual type promotion rules. Setting ``preferred_element_type`` to a
|
||||
specific ``dtype`` will mean that the operation returns that element type.
|
||||
When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or
|
||||
:class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides
|
||||
a hint to the compiler to accumulate the dot product using this data type.
|
||||
|
||||
Returns:
|
||||
An array containing the product.
|
||||
@ -944,9 +1035,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
|
||||
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]):
|
||||
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
algorithm=algorithm,
|
||||
transpose_algorithm=transpose_algorithm)
|
||||
preferred_element_type=preferred_element_type)
|
||||
else:
|
||||
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
|
||||
lhs.shape, rhs.shape))
|
||||
@ -957,9 +1046,7 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
|
||||
|
||||
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
algorithm: DotAlgorithmLike = None,
|
||||
transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array:
|
||||
preferred_element_type: DTypeLike | None = None) -> Array:
|
||||
"""General dot product/contraction operator.
|
||||
|
||||
Wraps XLA's `DotGeneral
|
||||
@ -978,29 +1065,31 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
|
||||
lhs: an array
|
||||
rhs: an array
|
||||
dimension_numbers: a tuple of tuples of sequences of ints of the form
|
||||
``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))``
|
||||
precision: Optional. Either ``None``, which means the default precision for
|
||||
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
|
||||
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and ``rhs``.
|
||||
preferred_element_type: Optional. Either ``None``, which means the default
|
||||
accumulation type for the input types, or a datatype, indicating to
|
||||
accumulate results to and return a result with that datatype.
|
||||
algorithm: Optional. Specify the algorithm used for accumulating the dot
|
||||
product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument
|
||||
cannot be used with ``precision`` or ``preferred_element_type``.
|
||||
transpose_algorithm: Optional. This allows specifying the algorithm used when
|
||||
this operation is transposed, typically as part of reverse-mode automatic
|
||||
differentiation. This argument can either be a single
|
||||
:class:`~jax.lax.DotAlgorithm` or a tuple of two
|
||||
:class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the
|
||||
algorithm for transposing the LHS and RHS, respectively.
|
||||
``transpose_algorithm`` must be explicitly specified when transposing a
|
||||
dot product where a specific ``algorithm`` was used on the forward pass.
|
||||
``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
|
||||
rhs_batch_dims))``
|
||||
precision: Optional. This parameter controls the numerics of the
|
||||
computation, and it can be one of the following:
|
||||
|
||||
- ``None``, which means the default precision for the current backend,
|
||||
- a :class:`~jax.lax.Precision` enum value or a tuple of two
|
||||
:class:`~jax.lax.Precision` enums indicating precision of ``lhs``` and
|
||||
``rhs``, or
|
||||
- a :class:`~jax.lax.DotAlgorithm` or a
|
||||
:class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that
|
||||
must be used to accumulate the dot product.
|
||||
|
||||
preferred_element_type: Optional. This parameter controls the data type
|
||||
output by the dot product. By default, the output element type of this
|
||||
operation will match the ``lhs`` and ``rhs`` input element types under
|
||||
the usual type promotion rules. Setting ``preferred_element_type`` to a
|
||||
specific ``dtype`` will mean that the operation returns that element type.
|
||||
When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or
|
||||
:class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides
|
||||
a hint to the compiler to accumulate the dot product using this data type.
|
||||
|
||||
Returns:
|
||||
An array whose first dimensions are the (shared) batch dimensions, followed by
|
||||
the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
|
||||
An array whose first dimensions are the (shared) batch dimensions, followed
|
||||
by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
|
||||
non-contracting/non-batch dimensions.
|
||||
"""
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
@ -1014,9 +1103,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
|
||||
return dot_general_p.bind(lhs, rhs,
|
||||
dimension_numbers=(cdims, bdims),
|
||||
precision=canonicalize_precision(precision),
|
||||
preferred_element_type=preferred_element_type,
|
||||
algorithm=canonicalize_dot_algorithm(algorithm),
|
||||
transpose_algorithm=canonicalize_dot_transpose_algorithm(transpose_algorithm))
|
||||
preferred_element_type=preferred_element_type)
|
||||
|
||||
|
||||
def ragged_dot(
|
||||
@ -3063,9 +3150,7 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
|
||||
|
||||
|
||||
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
algorithm: _DotAlgorithmLike = None,
|
||||
transpose_algorithm: DotTransposeAlgorithm | None = None):
|
||||
preferred_element_type: DTypeLike | None):
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
|
||||
for d in (lhs_contracting, lhs_batch)):
|
||||
@ -3141,10 +3226,8 @@ def tuple_delete(tup, idx):
|
||||
|
||||
|
||||
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
algorithm: _DotAlgorithmLike = None,
|
||||
transpose_algorithm: DotTransposeAlgorithm | None = None):
|
||||
del dimension_numbers, precision # unused
|
||||
preferred_element_type: DTypeLike | None):
|
||||
del dimension_numbers # unused
|
||||
# We're mostly matching XLA's logic here, namely in shape_inference.cc and
|
||||
# primitive_util.h's HigherPrecisionType, e.g.
|
||||
# https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329
|
||||
@ -3165,23 +3248,9 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
raise TypeError(
|
||||
f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
|
||||
result_dtype = lhs.dtype
|
||||
|
||||
if transpose_algorithm is not None and algorithm is None:
|
||||
raise ValueError(
|
||||
"When the algorithm argument to dot_general is None, the "
|
||||
"transpose_algorithm argument is unused and must also be None.")
|
||||
|
||||
if algorithm is not None and algorithm != DotAlgorithm.Preset.DEFAULT:
|
||||
if preferred_element_type is not None:
|
||||
raise ValueError(
|
||||
"The preferred_element_type and algorithm arguments to dot_general "
|
||||
"cannot both be specified.")
|
||||
|
||||
# This is used to ensure that the output type is equal to the accumulation
|
||||
# type whenever an algorithm is specified.
|
||||
preferred_element_type = algorithm.accumulation_type
|
||||
|
||||
return _maybe_upcast(result_dtype, preferred_element_type)
|
||||
has_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset))
|
||||
return _maybe_upcast(result_dtype, preferred_element_type,
|
||||
check_bit_width=not has_algorithm)
|
||||
|
||||
def _bit_width(d):
|
||||
if dtypes.issubdtype(d, np.inexact): return dtypes.finfo(d).bits
|
||||
@ -3189,12 +3258,12 @@ def _bit_width(d):
|
||||
elif d == np.dtype('bool'): return 1
|
||||
else: assert False, d # should be unreachable, open an issue!
|
||||
|
||||
def _maybe_upcast(result_dtype, preferred_element_type):
|
||||
def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
|
||||
# replicates the logic in shape_inference.cc's MaybeUpcast
|
||||
if (preferred_element_type is None or
|
||||
result_dtype == preferred_element_type):
|
||||
return result_dtype
|
||||
if (not dtypes.issubdtype(result_dtype, np.floating) and
|
||||
if (check_bit_width and not dtypes.issubdtype(result_dtype, np.floating) and
|
||||
_bit_width(preferred_element_type) < _bit_width(result_dtype)):
|
||||
raise TypeError("`preferred_element_type` must not be narrower than the "
|
||||
"original type, got preferred_element_type of "
|
||||
@ -3204,8 +3273,6 @@ def _maybe_upcast(result_dtype, preferred_element_type):
|
||||
|
||||
def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
algorithm: _DotAlgorithmLike = None,
|
||||
transpose_algorithm: DotTransposeAlgorithm | None = None,
|
||||
swap_ans=False):
|
||||
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
||||
x_ndim = x.aval.ndim
|
||||
@ -3218,36 +3285,20 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
|
||||
dims = ((ans_y, y_kept), (ans_batch, y_batch))
|
||||
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
|
||||
out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
|
||||
if algorithm is not None:
|
||||
if transpose_algorithm is None or transpose_algorithm[0] is None:
|
||||
raise ValueError(
|
||||
"When a dot_general algorithm is specified on the forward pass, "
|
||||
"transpose_algorithm must be specified for the backward pass.")
|
||||
lhs_alg, rhs_alg = transpose_algorithm
|
||||
transpose_algorithm = (algorithm, rhs_alg)
|
||||
algorithm = lhs_alg
|
||||
x_bar = transpose(dot_general(g, y, dims, precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
algorithm=algorithm,
|
||||
transpose_algorithm=transpose_algorithm),
|
||||
preferred_element_type=preferred_element_type),
|
||||
tuple(out_axes))
|
||||
if x_bar.dtype != x.aval.dtype:
|
||||
x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
|
||||
return x_bar
|
||||
|
||||
def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
algorithm: _DotAlgorithmLike = None,
|
||||
transpose_algorithm: DotTransposeAlgorithm | None = None):
|
||||
preferred_element_type: DTypeLike | None):
|
||||
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
||||
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
|
||||
transpose_algorithm = None if transpose_algorithm is None else (
|
||||
transpose_algorithm[1], transpose_algorithm[0])
|
||||
y_bar = _dot_general_transpose_lhs(
|
||||
g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
|
||||
preferred_element_type=preferred_element_type, algorithm=algorithm,
|
||||
transpose_algorithm=transpose_algorithm,
|
||||
swap_ans=True)
|
||||
preferred_element_type=preferred_element_type, swap_ans=True)
|
||||
if y_bar.dtype != y.aval.dtype:
|
||||
y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
|
||||
return y_bar
|
||||
@ -3263,8 +3314,6 @@ def _dot_batch_rule(
|
||||
dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type: DTypeLike | None,
|
||||
algorithm: _DotAlgorithmLike = None,
|
||||
transpose_algorithm: DotTransposeAlgorithm | None = None,
|
||||
**_,
|
||||
):
|
||||
|
||||
@ -3298,8 +3347,6 @@ def _dot_batch_rule(
|
||||
new_dimension_numbers,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
algorithm=algorithm,
|
||||
transpose_algorithm=transpose_algorithm,
|
||||
)
|
||||
result_batch_dim = batching.shape_as_bdim(
|
||||
result_stack_dim,
|
||||
@ -3415,7 +3462,7 @@ pe.padding_rules[dot_general_p] = _dot_general_padding_rule
|
||||
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
|
||||
|
||||
def precision_attr(precision: Precision) -> ir.ArrayAttr:
|
||||
if precision is None:
|
||||
if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
|
||||
full_precision = (Precision.DEFAULT, Precision.DEFAULT)
|
||||
elif not isinstance(precision, tuple):
|
||||
full_precision = (precision, precision)
|
||||
@ -3425,19 +3472,16 @@ def precision_attr(precision: Precision) -> ir.ArrayAttr:
|
||||
[hlo.PrecisionAttr.get(str(p)) for p in full_precision])
|
||||
|
||||
|
||||
def dot_algorithm_attr(algorithm: _DotAlgorithmLike, lhs_dtype: DTypeLike,
|
||||
def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
|
||||
rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
|
||||
if algorithm is None:
|
||||
if not isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
|
||||
return None
|
||||
return algorithm._convert_to_hlo_attr(lhs_dtype, rhs_dtype)
|
||||
return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype)
|
||||
|
||||
|
||||
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
precision, preferred_element_type: np.dtype | None,
|
||||
algorithm: _DotAlgorithmLike = None,
|
||||
transpose_algorithm: DotTransposeAlgorithm | None = None,
|
||||
platform: str = "default"):
|
||||
del transpose_algorithm # unused
|
||||
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
|
||||
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
|
||||
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
|
||||
@ -3446,63 +3490,87 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
lhs_aval, rhs_aval = ctx.avals_in
|
||||
lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype
|
||||
aval_out, = ctx.avals_out
|
||||
accumulation_aval = aval_out
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
|
||||
# TODO(b/...): JAX's dot_general primitive accepts the same input dtype
|
||||
# combinations that are accepted in XLA's shape_inference.cc (the canonical
|
||||
# reference for the HLO type system), but actually different XLA platforms
|
||||
# fail on codegen for different accepted cases. To handle those cases, we
|
||||
# insert ConvertOps on the input, in a platform-dependent way.
|
||||
if lhs_dtype != rhs_dtype:
|
||||
if platform == "tpu":
|
||||
handled = lambda dt: (dtypes.issubdtype(dt, np.floating) or
|
||||
dtypes.issubdtype(dt, np.integer))
|
||||
if not (handled(lhs_dtype) and handled(rhs_dtype)):
|
||||
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
|
||||
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
|
||||
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
|
||||
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
|
||||
lhs_dtype = rhs_dtype = aval_out.dtype
|
||||
else: # cpu and gpu
|
||||
# Do not convert mixed fp8 types to output type.
|
||||
if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
|
||||
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
|
||||
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
|
||||
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
|
||||
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
|
||||
lhs_dtype = rhs_dtype = aval_out.dtype
|
||||
|
||||
|
||||
dot_dnums = hlo.DotDimensionNumbers.get(
|
||||
lhs_batching_dimensions=list(lhs_batch),
|
||||
rhs_batching_dimensions=list(rhs_batch),
|
||||
lhs_contracting_dimensions=list(lhs_contracting),
|
||||
rhs_contracting_dimensions=list(rhs_contracting))
|
||||
|
||||
if algorithm is not None and precision not in {
|
||||
None, Precision.DEFAULT, (Precision.DEFAULT, Precision.DEFAULT)}:
|
||||
raise ValueError(
|
||||
"The dot_general precision must be None or DEFAULT when an algorithm "
|
||||
"is specified.")
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
if algorithm is not None:
|
||||
algorithm_kwarg = {}
|
||||
if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
|
||||
# The CPU backend silently ignores the algorithm spec, so we check here to
|
||||
# make sure that the selected algorithm is supported. We could be a little
|
||||
# bit more liberal here (any algorithm where the input and output types
|
||||
# match and all the other parameters have default values should work), but
|
||||
# it's probably sufficient to just check the presets here.
|
||||
if platform == "cpu" and precision not in {
|
||||
DotAlgorithmPreset.DEFAULT, DotAlgorithmPreset.F16_F16_F16,
|
||||
DotAlgorithmPreset.F32_F32_F32, DotAlgorithmPreset.F64_F64_F64,
|
||||
}:
|
||||
raise ValueError(
|
||||
"The dot_general algorithm parameter is only supported for jaxlib "
|
||||
"versions larger than 0.4.33.")
|
||||
algorithm_kwargs = {}
|
||||
f"The precision '{precision}' is not supported by dot_general on CPU")
|
||||
|
||||
# If an explicit algorithm was specified, we always cast the input types to
|
||||
# the correct types.
|
||||
def maybe_convert_dtype(operand, operand_aval, target_dtype):
|
||||
if target_dtype is None:
|
||||
return operand, operand_aval.dtype
|
||||
if not isinstance(target_dtype, tuple):
|
||||
target_dtype = (target_dtype,)
|
||||
if any(operand_aval.dtype == d for d in target_dtype):
|
||||
return operand, operand_aval.dtype
|
||||
aval = core.ShapedArray(operand_aval.shape, target_dtype[0])
|
||||
return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0]
|
||||
|
||||
lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type)
|
||||
rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type)
|
||||
accumulation_type = precision.accumulation_type
|
||||
if accumulation_type is not None:
|
||||
accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type)
|
||||
|
||||
if precision != DotAlgorithmPreset.DEFAULT:
|
||||
algorithm_kwarg = {
|
||||
"algorithm": dot_algorithm_attr(precision, lhs_dtype, rhs_dtype)
|
||||
}
|
||||
else:
|
||||
algorithm_kwargs = {"algorithm": dot_algorithm_attr(algorithm, lhs_dtype,
|
||||
rhs_dtype)}
|
||||
return [
|
||||
hlo.dot_general(
|
||||
mlir.aval_to_ir_type(aval_out),
|
||||
lhs,
|
||||
rhs,
|
||||
dot_dnums,
|
||||
precision_config=precision_attr(precision),
|
||||
**algorithm_kwargs,
|
||||
)
|
||||
]
|
||||
# TODO(b/...): JAX's dot_general primitive accepts the same input dtype
|
||||
# combinations that are accepted in XLA's shape_inference.cc (the canonical
|
||||
# reference for the HLO type system), but actually different XLA platforms
|
||||
# fail on codegen for different accepted cases. To handle those cases, we
|
||||
# insert ConvertOps on the input, in a platform-dependent way.
|
||||
if lhs_dtype != rhs_dtype:
|
||||
if platform == "tpu":
|
||||
handled = lambda dt: (dtypes.issubdtype(dt, np.floating) or
|
||||
dtypes.issubdtype(dt, np.integer))
|
||||
if not (handled(lhs_dtype) and handled(rhs_dtype)):
|
||||
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
|
||||
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
|
||||
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
|
||||
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
|
||||
lhs_dtype = rhs_dtype = aval_out.dtype
|
||||
else: # cpu and gpu
|
||||
# Do not convert mixed fp8 types to output type.
|
||||
if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
|
||||
lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
|
||||
core.ShapedArray(lhs_aval.shape, aval_out.dtype))
|
||||
rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
|
||||
core.ShapedArray(rhs_aval.shape, aval_out.dtype))
|
||||
lhs_dtype = rhs_dtype = aval_out.dtype
|
||||
|
||||
result = hlo.dot_general(
|
||||
mlir.aval_to_ir_type(accumulation_aval),
|
||||
lhs,
|
||||
rhs,
|
||||
dot_dnums,
|
||||
precision_config=precision_attr(precision),
|
||||
**algorithm_kwarg,
|
||||
)
|
||||
if accumulation_aval.dtype != aval_out.dtype:
|
||||
result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
|
||||
return [result]
|
||||
|
||||
mlir.register_lowering(dot_general_p, _dot_general_lower)
|
||||
|
||||
@ -3556,8 +3624,7 @@ def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
|
||||
raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
|
||||
# defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
|
||||
return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
|
||||
precision=precision, preferred_element_type=preferred_element_type,
|
||||
algorithm=None, transpose_algorithm=None)
|
||||
precision=precision, preferred_element_type=preferred_element_type)
|
||||
|
||||
|
||||
def _ragged_dot_jvp_rule(
|
||||
@ -3680,12 +3747,7 @@ def _ragged_dot_invoke_prim(
|
||||
new_dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type,
|
||||
algorithm,
|
||||
transpose_algorithm,
|
||||
):
|
||||
assert algorithm is None
|
||||
assert transpose_algorithm is None
|
||||
|
||||
return ragged_dot(
|
||||
lhs,
|
||||
rhs,
|
||||
@ -5779,7 +5841,7 @@ def remaining(original, *removed_lists):
|
||||
return [i for i in original if i not in removed]
|
||||
|
||||
|
||||
def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None:
|
||||
def canonicalize_precision(precision: PrecisionLike) -> CanonicalPrecision:
|
||||
"""Turns an API precision specification into a pair of enumeration values.
|
||||
|
||||
The API can take the precision as a string, or int, and either as a single
|
||||
@ -5789,56 +5851,44 @@ def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precisi
|
||||
if config.default_matmul_precision.value is None:
|
||||
return None
|
||||
try:
|
||||
return (
|
||||
Precision(config.default_matmul_precision.value),
|
||||
Precision(config.default_matmul_precision.value),
|
||||
)
|
||||
except TypeError:
|
||||
return canonicalize_precision(config.default_matmul_precision.value)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"jax_default_matmul_precision flag must be set to None or a value in "
|
||||
f"{list(_precision_strings)}, but got {config.default_matmul_precision.value}"
|
||||
"jax_default_matmul_precision flag must be set to None, a value in "
|
||||
f"{list(_precision_strings)}, or the name of a lax.DotAlgorithmPreset, "
|
||||
f"but got {config.default_matmul_precision.value}"
|
||||
) from None
|
||||
elif isinstance(precision, str) and precision in _precision_strings:
|
||||
return Precision(precision), Precision(precision)
|
||||
elif isinstance(precision, str):
|
||||
if precision in _precision_strings:
|
||||
return Precision(precision), Precision(precision)
|
||||
else:
|
||||
try:
|
||||
return DotAlgorithmPreset[precision]
|
||||
except KeyError:
|
||||
pass
|
||||
elif isinstance(precision, Precision):
|
||||
return precision, precision
|
||||
elif isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
|
||||
return precision
|
||||
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
|
||||
all(isinstance(p, Precision) for p in precision)):
|
||||
return type_cast(tuple[Precision, Precision], precision)
|
||||
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
|
||||
all(isinstance(s, str) for s in precision)):
|
||||
s1, s2 = precision
|
||||
s1, s2 = type_cast(tuple[str, str], precision)
|
||||
p1 = type_cast(tuple[Precision, Precision], canonicalize_precision(s1))[0]
|
||||
p2 = type_cast(tuple[Precision, Precision], canonicalize_precision(s2))[0]
|
||||
return (p1, p2)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Precision argument must be None, a string in {list(_precision_strings)}, "
|
||||
"a lax.Precision value or a tuple of two lax.Precision values or "
|
||||
f"strings; got {precision}.")
|
||||
raise ValueError(
|
||||
"Precision argument must be one of:\n"
|
||||
"- None,\n"
|
||||
f"- a string in {list(_precision_strings)},\n"
|
||||
"- a lax.Precision value,\n"
|
||||
"- a tuple of two lax.Precision values or strings,\n"
|
||||
"- a lax.DotAlgorithmPreset or the name of one of these presets, or\n"
|
||||
"- a lax.DotAlgorithm value;\n"
|
||||
f"but got {precision}.")
|
||||
|
||||
def canonicalize_dot_algorithm(algorithm: DotAlgorithmLike) -> _DotAlgorithmLike:
|
||||
if isinstance(algorithm, str):
|
||||
algorithm = DotAlgorithm.Preset[algorithm]
|
||||
if algorithm is None or algorithm == DotAlgorithm.Preset.DEFAULT:
|
||||
return None
|
||||
return algorithm
|
||||
|
||||
def canonicalize_dot_transpose_algorithm(
|
||||
algorithm: DotTransposeAlgorithmLike) -> DotTransposeAlgorithm | None:
|
||||
if algorithm is None:
|
||||
return None
|
||||
elif isinstance(algorithm, DotAlgorithm):
|
||||
return (algorithm, algorithm)
|
||||
elif isinstance(algorithm, tuple):
|
||||
if len(algorithm) != 2:
|
||||
raise ValueError(
|
||||
"The transpose_algorithm argument must be a single value or a tuple "
|
||||
f"of two values; got {algorithm}.")
|
||||
return (canonicalize_dot_algorithm(algorithm[0]),
|
||||
canonicalize_dot_algorithm(algorithm[1]))
|
||||
algorithm = canonicalize_dot_algorithm(algorithm)
|
||||
return (algorithm, algorithm)
|
||||
|
||||
def _balanced_eq(x, z, y):
|
||||
return div(select(_eq_meet(x, z), _ones(z), _zeros(z)),
|
||||
|
@ -205,11 +205,13 @@ def conv_general_dilated_local(
|
||||
lhs_array = lax.asarray(lhs)
|
||||
|
||||
c_precision = lax.canonicalize_precision(precision)
|
||||
lhs_precision = (
|
||||
c_precision[0]
|
||||
if (isinstance(c_precision, tuple) and len(c_precision) == 2)
|
||||
else c_precision
|
||||
)
|
||||
if c_precision is None:
|
||||
lhs_precision = None
|
||||
elif isinstance(c_precision, tuple) and len(c_precision) == 2:
|
||||
lhs_precision = c_precision[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported precision for conv_general_dilated_local: {precision}")
|
||||
|
||||
patches = conv_general_dilated_patches(
|
||||
lhs=lhs_array,
|
||||
|
@ -2090,10 +2090,8 @@ def _dot_general_lowering(
|
||||
dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type,
|
||||
algorithm,
|
||||
transpose_algorithm,
|
||||
):
|
||||
del preferred_element_type, algorithm, transpose_algorithm # Unused.
|
||||
del preferred_element_type # Unused.
|
||||
((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
|
||||
assert batch_dims == ((), ())
|
||||
|
||||
|
@ -364,15 +364,12 @@ tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
|
||||
def _dot_general(lhs, rhs, *, dimension_numbers,
|
||||
precision: tuple[PrecisionType, PrecisionType] | None,
|
||||
preferred_element_type: DType | None,
|
||||
algorithm: Any, transpose_algorithm: Any,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray):
|
||||
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
|
||||
# Unused arguments.
|
||||
del precision
|
||||
del preferred_element_type
|
||||
del algorithm
|
||||
del transpose_algorithm
|
||||
|
||||
lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype(
|
||||
lhs, _in_avals[0], rhs, _in_avals[1], _out_aval)
|
||||
|
@ -2183,12 +2183,9 @@ tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
|
||||
def _dot_general(lhs, rhs, *, dimension_numbers,
|
||||
precision: tuple[PrecisionType, PrecisionType] | None,
|
||||
preferred_element_type: DType | None,
|
||||
algorithm: Any, transpose_algorithm: Any,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray):
|
||||
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
|
||||
del algorithm, transpose_algorithm # unused
|
||||
|
||||
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
|
||||
# failures. Since we are going to turn on native serializaton soon, wait
|
||||
# until then to turn on this check.
|
||||
|
@ -84,7 +84,7 @@ TODO:
|
||||
"""
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Any
|
||||
from typing import cast, Any
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
@ -228,10 +228,10 @@ def _lstm_cudnn_allow_tf32(precision: lax.PrecisionLike) -> bool:
|
||||
#
|
||||
# but we prefer to still invoke it here for consistency
|
||||
precision = lax.canonicalize_precision(precision)
|
||||
if precision is None:
|
||||
if precision is None or not (isinstance(precision, tuple) and len(precision) == 2):
|
||||
return True
|
||||
# cuDNN allows only one precision specifier per RNN op
|
||||
precision, _ = precision
|
||||
precision, _ = cast(tuple[lax.Precision, lax.Precision], precision)
|
||||
if precision == lax.Precision.HIGHEST:
|
||||
return False
|
||||
elif precision == lax.Precision.HIGH:
|
||||
|
@ -609,8 +609,7 @@ mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
|
||||
bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
|
||||
|
||||
def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers,
|
||||
precision: None = None, preferred_element_type: None = None,
|
||||
algorithm: None = None, transpose_algorithm: None = None) -> BCOO | Array:
|
||||
precision: None = None, preferred_element_type: None = None) -> BCOO | Array:
|
||||
"""A general contraction operation.
|
||||
|
||||
Args:
|
||||
@ -621,8 +620,6 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
|
||||
(lhs_batch_dims, rhs_batch_dims))`.
|
||||
precision: unused
|
||||
preferred_element_type: unused
|
||||
algorithm: unused
|
||||
transpose_algorithm: unused
|
||||
|
||||
Returns:
|
||||
An ndarray or BCOO-format sparse array containing the result. If both inputs
|
||||
@ -630,7 +627,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
|
||||
the result will be dense, of type ndarray.
|
||||
"""
|
||||
# TODO(jakevdp) make use of these?
|
||||
del precision, algorithm, transpose_algorithm # unused
|
||||
del precision # unused
|
||||
if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
|
||||
shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
|
||||
dimension_numbers)
|
||||
@ -1056,9 +1053,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
|
||||
indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
|
||||
kwds = {'dimension_numbers': dimension_numbers,
|
||||
'precision': None,
|
||||
'preferred_element_type': None,
|
||||
'algorithm': None,
|
||||
'transpose_algorithm': None}
|
||||
'preferred_element_type': None}
|
||||
A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
|
||||
return A, B, indices
|
||||
|
||||
|
@ -463,9 +463,7 @@ bcsr_dot_general_p = core.Primitive('bcsr_dot_general')
|
||||
def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
|
||||
dimension_numbers: DotDimensionNumbers,
|
||||
precision: None = None,
|
||||
preferred_element_type: None = None,
|
||||
algorithm: None = None,
|
||||
transpose_algorithm: None = None) -> Array:
|
||||
preferred_element_type: None = None) -> Array:
|
||||
"""A general contraction operation.
|
||||
|
||||
Args:
|
||||
@ -476,15 +474,13 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
|
||||
(lhs_batch_dims, rhs_batch_dims))`.
|
||||
precision: unused
|
||||
preferred_element_type: unused
|
||||
algorithm: unused
|
||||
transpose_algorithm: unused
|
||||
|
||||
Returns:
|
||||
An ndarray or BCSR-format sparse array containing the result. If both inputs
|
||||
are sparse, the result will be sparse, of type BCSR. If either input is
|
||||
dense, the result will be dense, of type ndarray.
|
||||
"""
|
||||
del precision, algorithm, transpose_algorithm # unused
|
||||
del precision # unused
|
||||
if isinstance(rhs, (np.ndarray, jax.Array)):
|
||||
if isinstance(lhs, (np.ndarray, jax.Array)):
|
||||
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
|
||||
|
@ -113,5 +113,4 @@ def _dot_general_validated_shape(
|
||||
rhs = core.ShapedArray(rhs_shape, np.float32)
|
||||
return _dot_general_shape_rule(
|
||||
lhs, rhs, dimension_numbers=dimension_numbers,
|
||||
precision=None, preferred_element_type=None, algorithm=None,
|
||||
transpose_algorithm=None)
|
||||
precision=None, preferred_element_type=None)
|
||||
|
@ -20,8 +20,7 @@ from jax._src.lax.lax import (
|
||||
Precision as Precision,
|
||||
PrecisionLike as PrecisionLike,
|
||||
DotAlgorithm as DotAlgorithm,
|
||||
DotAlgorithmLike as DotAlgorithmLike,
|
||||
DotTransposeAlgorithmLike as DotTransposeAlgorithmLike,
|
||||
DotAlgorithmPreset as DotAlgorithmPreset,
|
||||
RandomAlgorithm as RandomAlgorithm,
|
||||
RoundingMethod as RoundingMethod,
|
||||
abs as abs,
|
||||
|
@ -1061,19 +1061,19 @@ class LaxTest(jtu.JaxTestCase):
|
||||
accumulation_type=np.float32,
|
||||
), [np.float16]),
|
||||
("F16_F16_F32", [np.float16]),
|
||||
(lax.DotAlgorithm.Preset.DEFAULT, lax_test_util.float_dtypes),
|
||||
(lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes),
|
||||
(lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes),
|
||||
(lax.DotAlgorithm.Preset.F16_F16_F16, [np.float16]),
|
||||
(lax.DotAlgorithm.Preset.F16_F16_F32, [np.float16]),
|
||||
(lax.DotAlgorithm.Preset.BF16_BF16_BF16, [dtypes.bfloat16]),
|
||||
(lax.DotAlgorithm.Preset.BF16_BF16_F32, [dtypes.bfloat16]),
|
||||
(lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, [np.float32]),
|
||||
(lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, [np.float32]),
|
||||
(lax.DotAlgorithm.Preset.TF32_TF32_F32, [np.float32]),
|
||||
(lax.DotAlgorithm.Preset.TF32_TF32_F32_X3, [np.float32]),
|
||||
(lax.DotAlgorithm.Preset.F32_F32_F32, [np.float32]),
|
||||
(lax.DotAlgorithm.Preset.F64_F64_F64, [np.float64]),
|
||||
(lax.DotAlgorithmPreset.DEFAULT, lax_test_util.float_dtypes),
|
||||
(lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes),
|
||||
(lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes),
|
||||
(lax.DotAlgorithmPreset.F16_F16_F16, [np.float16]),
|
||||
(lax.DotAlgorithmPreset.F16_F16_F32, [np.float16]),
|
||||
(lax.DotAlgorithmPreset.BF16_BF16_BF16, [dtypes.bfloat16]),
|
||||
(lax.DotAlgorithmPreset.BF16_BF16_F32, [dtypes.bfloat16]),
|
||||
(lax.DotAlgorithmPreset.BF16_BF16_F32_X3, [np.float32]),
|
||||
(lax.DotAlgorithmPreset.BF16_BF16_F32_X6, [np.float32]),
|
||||
(lax.DotAlgorithmPreset.TF32_TF32_F32, [np.float32]),
|
||||
(lax.DotAlgorithmPreset.TF32_TF32_F32_X3, [np.float32]),
|
||||
(lax.DotAlgorithmPreset.F32_F32_F32, [np.float32]),
|
||||
(lax.DotAlgorithmPreset.F64_F64_F64, [np.float64]),
|
||||
] for dtype in test_dtypes
|
||||
if jtu.dtypes.supported([dtype])
|
||||
])
|
||||
@ -1084,26 +1084,35 @@ class LaxTest(jtu.JaxTestCase):
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
if algorithm not in {
|
||||
lax.DotAlgorithmPreset.DEFAULT,
|
||||
lax.DotAlgorithmPreset.F16_F16_F16,
|
||||
lax.DotAlgorithmPreset.F32_F32_F32,
|
||||
lax.DotAlgorithmPreset.F64_F64_F64,
|
||||
}:
|
||||
raise SkipTest(
|
||||
f"The dot algorithm '{algorithm}' is not supported on CPU.")
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
# GPU algorithm support is a little spotty. It is checked in
|
||||
# xla/service/algorithm_util.cc and the logic is copied here.
|
||||
if algorithm in {
|
||||
lax.DotAlgorithm.Preset.F16_F16_F32,
|
||||
lax.DotAlgorithm.Preset.TF32_TF32_F32,
|
||||
lax.DotAlgorithm.Preset.BF16_BF16_F32,
|
||||
lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, # Must have f32 input
|
||||
lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, # Must have f32 input
|
||||
lax.DotAlgorithmPreset.F16_F16_F32,
|
||||
lax.DotAlgorithmPreset.TF32_TF32_F32,
|
||||
lax.DotAlgorithmPreset.BF16_BF16_F32,
|
||||
lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
|
||||
lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
|
||||
}:
|
||||
if not jtu.is_cuda_compute_capability_at_least("8.0"):
|
||||
raise SkipTest(
|
||||
f"The dot algorithm '{algorithm}' requires CUDA compute "
|
||||
"capability >= 8.0.")
|
||||
elif algorithm not in {
|
||||
lax.DotAlgorithm.Preset.DEFAULT,
|
||||
lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
|
||||
lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
|
||||
lax.DotAlgorithm.Preset.F32_F32_F32,
|
||||
lax.DotAlgorithm.Preset.F64_F64_F64,
|
||||
lax.DotAlgorithmPreset.DEFAULT,
|
||||
lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32,
|
||||
lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
|
||||
lax.DotAlgorithmPreset.F32_F32_F32,
|
||||
lax.DotAlgorithmPreset.F64_F64_F64,
|
||||
}:
|
||||
raise SkipTest(
|
||||
f"The dot algorithm '{algorithm}' is not supported on GPU.")
|
||||
@ -1111,12 +1120,8 @@ class LaxTest(jtu.JaxTestCase):
|
||||
rhs_shape = (4, 3)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
self._CompileAndCheck(partial(lax.dot, algorithm=algorithm), args_maker)
|
||||
# Check that accumulation type sets the output type
|
||||
output = lax.dot(*args_maker(), algorithm=algorithm)
|
||||
algorithm = lax_internal.canonicalize_dot_algorithm(algorithm)
|
||||
expected_dtype = dtype if algorithm is None else algorithm.accumulation_type
|
||||
self.assertEqual(output.dtype, expected_dtype)
|
||||
self._CompileAndCheck(partial(lax.dot, precision=algorithm), args_maker)
|
||||
self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype)
|
||||
|
||||
def testDotAlgorithmInvalidFloat8Type(self):
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
@ -1125,95 +1130,29 @@ class LaxTest(jtu.JaxTestCase):
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
raise SkipTest("Not supported on CPU.")
|
||||
lhs_shape = (3, 4)
|
||||
rhs_shape = (4, 3)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn)
|
||||
with self.assertRaisesRegex(ValueError, "The dot algorithm"):
|
||||
lax.dot(lhs, rhs, algorithm="ANY_F8_ANY_F8_F32")
|
||||
lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32")
|
||||
|
||||
@parameterized.parameters([
|
||||
({"precision": lax.Precision.HIGHEST}, "The dot_general precision must be None or DEFAULT"),
|
||||
({"preferred_element_type": np.float32}, "The preferred_element_type and algorithm arguments"),
|
||||
])
|
||||
def testDotAlgorithmInvalidParameters(self, kwargs, pattern):
|
||||
def testDotAlgorithmCasting(self):
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is not supported by PJRT C API.")
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
|
||||
def fun(lhs, rhs):
|
||||
return lax.dot(lhs, rhs, precision="F32_F32_F32")
|
||||
lhs_shape = (3, 4)
|
||||
rhs_shape = (4, 3)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
|
||||
with self.assertRaisesRegex(ValueError, pattern):
|
||||
lax.dot(lhs, rhs, algorithm="F32_F32_F32", **kwargs)
|
||||
|
||||
def testDotAlgorithmTransposeRequired(self):
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is not supported by PJRT C API.")
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
|
||||
lhs_shape = (3, 4)
|
||||
rhs_shape = (4, 3)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
|
||||
fun = partial(lax.dot, algorithm="F32_F32_F32")
|
||||
out = fun(lhs, rhs)
|
||||
_, vjp_fun = jax.vjp(fun, lhs, rhs)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "When a dot_general algorithm is specified"):
|
||||
vjp_fun(out)
|
||||
|
||||
@parameterized.parameters([
|
||||
("F32_F32_F32", "F16_F16_F32"),
|
||||
("F32_F32_F32", ("F16_F16_F32", "F64_F64_F64")),
|
||||
])
|
||||
def testDotAlgorithmTranspose(self, algorithm, transpose_algorithm):
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is not supported by PJRT C API.")
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
raise SkipTest(
|
||||
"The dot algorithm attribute is only supported for jaxlib >0.4.33.")
|
||||
def fun(x, y):
|
||||
return lax.dot(x, y, algorithm=algorithm,
|
||||
transpose_algorithm=transpose_algorithm)
|
||||
|
||||
algorithm_ = lax_internal.canonicalize_dot_algorithm(algorithm)
|
||||
lhs_alg, rhs_alg = lax_internal.canonicalize_dot_transpose_algorithm(
|
||||
transpose_algorithm)
|
||||
|
||||
lhs_shape = (3, 4)
|
||||
rhs_shape = (4, 3)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32)
|
||||
out = fun(lhs, rhs)
|
||||
|
||||
def check_transpose_algorithm(f, arg, alg, trans_alg, trans_trans_alg):
|
||||
fun_trans = jax.linear_transpose(f, arg)
|
||||
jaxpr = jax.make_jaxpr(fun_trans)(out)
|
||||
eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr.eqns))
|
||||
self.assertEqual(eqn.params["algorithm"], alg)
|
||||
self.assertEqual(eqn.params["transpose_algorithm"], trans_alg)
|
||||
|
||||
fun_ = jax.linear_transpose(lambda x: fun_trans(x)[0], out)
|
||||
jaxpr_ = jax.make_jaxpr(fun_)(arg)
|
||||
eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr_.eqns))
|
||||
self.assertEqual(eqn.params["algorithm"], algorithm_)
|
||||
|
||||
# Note that transposing the RHS of a dot_general introduce extra
|
||||
# transposes on the input and output, so we don't actually end up with
|
||||
# the same `transpose_algorithm` parameter after 2 transposes.
|
||||
self.assertEqual(eqn.params["transpose_algorithm"], trans_trans_alg)
|
||||
|
||||
check_transpose_algorithm(partial(fun, y=rhs), lhs, lhs_alg,
|
||||
(algorithm_, rhs_alg), (lhs_alg, rhs_alg))
|
||||
check_transpose_algorithm(partial(fun, lhs), rhs, rhs_alg,
|
||||
(algorithm_, lhs_alg), (rhs_alg, lhs_alg))
|
||||
lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
|
||||
self.assertEqual(fun(lhs, rhs).dtype, np.float16)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user