2022-09-22 12:26:48 -07:00
|
|
|
|
# Copyright 2018 The JAX Authors.
|
2020-10-17 14:33:26 -04:00
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
2023-08-11 08:06:51 -07:00
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
import builtins
|
2024-06-26 14:44:52 -04:00
|
|
|
|
from collections.abc import Callable, Sequence
|
2025-03-10 12:24:38 -07:00
|
|
|
|
import dataclasses
|
2021-08-11 17:32:36 -04:00
|
|
|
|
import enum
|
2020-10-17 14:33:26 -04:00
|
|
|
|
import functools
|
2021-09-13 17:24:44 -04:00
|
|
|
|
from functools import partial
|
2020-10-17 14:33:26 -04:00
|
|
|
|
import itertools
|
2023-02-28 12:40:30 -08:00
|
|
|
|
import math
|
2020-10-17 14:33:26 -04:00
|
|
|
|
import operator
|
2024-09-25 06:16:22 -07:00
|
|
|
|
from typing import Any, NamedTuple, TypeVar, Union, cast as type_cast, overload
|
2020-10-17 14:33:26 -04:00
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax import tree_util
|
2024-01-22 09:27:47 -08:00
|
|
|
|
from jax.sharding import Sharding
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax.tree_util import tree_map
|
|
|
|
|
|
2021-06-07 14:51:04 -07:00
|
|
|
|
from jax._src import ad_util
|
2021-04-13 09:42:54 -07:00
|
|
|
|
from jax._src import api
|
2021-09-08 09:00:23 -07:00
|
|
|
|
from jax._src import api_util
|
2022-11-30 15:25:21 -08:00
|
|
|
|
from jax._src import array
|
2023-10-09 07:28:18 -07:00
|
|
|
|
from jax._src import config
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src import core
|
2022-09-08 08:49:12 -07:00
|
|
|
|
from jax._src import dispatch
|
2021-04-07 19:35:17 -07:00
|
|
|
|
from jax._src import dtypes
|
2023-02-01 17:50:00 -08:00
|
|
|
|
from jax._src import effects
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src import linear_util as lu
|
2024-09-01 07:49:49 -07:00
|
|
|
|
from jax._src import pjit
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src import pretty_printer as pp
|
2022-01-20 22:58:09 -08:00
|
|
|
|
from jax._src import source_info_util
|
2024-04-04 14:33:06 -04:00
|
|
|
|
from jax._src import state
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src import util
|
2023-12-22 15:53:48 -08:00
|
|
|
|
from jax._src.abstract_arrays import array_types
|
2024-10-31 14:06:08 -07:00
|
|
|
|
from jax._src.core import (Primitive, UnshapedArray, ShapedArray,
|
2024-11-05 07:16:32 -08:00
|
|
|
|
abstract_token, canonicalize_shape)
|
2025-01-21 13:28:08 -08:00
|
|
|
|
from jax._src.errors import UnexpectedTracerError
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src.interpreters import ad
|
2023-02-09 15:11:20 -08:00
|
|
|
|
from jax._src.interpreters import batching
|
|
|
|
|
from jax._src.interpreters import mlir
|
2023-03-27 13:29:59 -07:00
|
|
|
|
from jax._src.interpreters import partial_eval as pe
|
2023-03-21 05:13:21 -07:00
|
|
|
|
from jax._src.interpreters import pxla
|
2023-03-27 13:29:59 -07:00
|
|
|
|
from jax._src.interpreters import xla
|
2023-05-05 15:25:42 -04:00
|
|
|
|
from jax._src.interpreters.batching import RaggedAxis
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src.lax import slicing
|
[sharding_in_types] Make the typing checks and sharding rule checks a little bit less strict when the current or aval mesh is empty/unset. Also some more changes as listed below:
* get_aval is not context dependent
* canonicalization does not happen for avals on an empty mesh
* jax.jit does not set abstract mesh context anymore before tracing
* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode
* Even if use_mesh is not used in explicit sharding mode, computation follows data works!
* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)
* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the aval has an empty mesh.
As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.
PiperOrigin-RevId: 726097292
2025-02-12 10:02:13 -08:00
|
|
|
|
from jax._src import mesh as mesh_lib
|
2021-11-23 16:34:33 -08:00
|
|
|
|
from jax._src.lax.utils import (
|
2023-06-14 18:30:52 -07:00
|
|
|
|
_input_dtype, dtype_to_string, standard_abstract_eval,
|
2024-07-25 00:02:55 +00:00
|
|
|
|
standard_multi_result_abstract_eval, standard_primitive)
|
2023-02-06 22:51:50 -08:00
|
|
|
|
from jax._src.lib.mlir import ir
|
|
|
|
|
from jax._src.lib.mlir.dialects import chlo
|
|
|
|
|
from jax._src.lib.mlir.dialects import hlo
|
2024-10-09 21:23:57 -07:00
|
|
|
|
from jax._src.sharding_impls import (PmapSharding, NamedSharding,
|
2025-03-21 17:35:37 -07:00
|
|
|
|
ShardingContext, SPMDAxisContext,
|
2025-01-08 11:10:37 -08:00
|
|
|
|
PartitionSpec as P, canonicalize_sharding)
|
2024-05-29 14:23:03 +03:00
|
|
|
|
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
|
2024-12-18 19:37:58 -08:00
|
|
|
|
from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
|
2025-03-14 07:42:07 -07:00
|
|
|
|
safe_map, safe_zip, split_list, weakref_lru_cache,
|
|
|
|
|
foreach)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# Private aliases for the builtin ``max``/``min`` and ``functools.reduce``.
_max = builtins.max
_min = builtins.min
_reduce = functools.reduce

T = TypeVar("T")

# From here on ``map``/``zip`` refer to ``safe_map``/``safe_zip`` from
# ``jax._src.util``; the builtins remain reachable as ``unsafe_map`` and
# ``unsafe_zip``.
unsafe_map = map
map = safe_map
unsafe_zip = zip
zip = safe_zip

# Decorator applied to the public functions defined in this module.
export = util.set_module("jax.lax")
|
|
|
|
|
|
2024-09-01 07:49:49 -07:00
|
|
|
|
def _matrix_transpose(x: Array) -> Array:
|
|
|
|
|
assert x.ndim >= 2
|
|
|
|
|
return transpose(x, [*range(x.ndim - 2), x.ndim - 1, x.ndim - 2])
|
|
|
|
|
|
2024-07-31 10:12:31 +02:00
|
|
|
|
def _clip_int_to_valid_range(val: DimSize, dtype, where: str) -> int:
|
2023-11-09 10:23:20 -08:00
|
|
|
|
info = np.iinfo(dtype)
|
2024-07-31 10:12:31 +02:00
|
|
|
|
val = core.concrete_dim_or_error(val, where)
|
2024-05-29 14:23:03 +03:00
|
|
|
|
return core.max_dim(info.min, core.min_dim(val, info.max))
|
2023-11-09 10:23:20 -08:00
|
|
|
|
|
2022-02-15 15:03:33 +01:00
|
|
|
|
def _validate_shapes(shapes: Sequence[Shape]):
  """Check that every shape in ``shapes`` contains only non-negative entries.

  The check is skipped entirely when ``config.dynamic_shapes`` is enabled.

  Args:
    shapes: a non-empty sequence of shapes.

  Raises:
    TypeError: if a canonicalized shape contains a negative entry.
  """
  def _check_static_shape(shape: Shape):
    canonical = canonicalize_shape(shape)
    all_non_negative = all(dim >= 0 for dim in canonical)
    if all_non_negative:
      return
    raise TypeError("Only non-negative indices are allowed when broadcasting"
                    f" static shapes, but got shape {shape!r}.")

  assert shapes
  if config.dynamic_shapes.value:
    # pass dynamic shapes through unchecked
    return
  foreach(_check_static_shape, shapes)
|
2022-02-15 15:03:33 +01:00
|
|
|
|
|
2024-12-11 15:03:39 -08:00
|
|
|
|
def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
|
|
|
|
|
"""
|
|
|
|
|
Attempt to broadcast shapes, raising a TypeError if broadcasting fails.
|
|
|
|
|
"""
|
|
|
|
|
if not shapes:
|
|
|
|
|
raise TypeError(f"{name}: At least one shape is required.")
|
2023-07-27 13:39:20 -07:00
|
|
|
|
ranks = {len(shape) for shape in shapes}
|
2024-12-11 15:03:39 -08:00
|
|
|
|
if len(ranks) != 1:
|
|
|
|
|
raise TypeError(f'{name}: arrays must have the same number of dimensions,'
|
|
|
|
|
f' got {ranks}')
|
2022-06-29 13:55:30 -07:00
|
|
|
|
result_shape = []
|
2024-12-11 15:03:39 -08:00
|
|
|
|
for ds in zip(*shapes):
|
2022-06-29 13:55:30 -07:00
|
|
|
|
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
|
|
|
|
|
# if all axes are identical objects, the resulting size is the object
|
|
|
|
|
result_shape.append(ds[0])
|
|
|
|
|
else:
|
2024-12-11 15:03:39 -08:00
|
|
|
|
# if all dims are equal (or 1), the result is the non-1 size
|
2023-06-10 17:33:27 -04:00
|
|
|
|
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
|
2022-06-29 13:55:30 -07:00
|
|
|
|
if not non_1s:
|
|
|
|
|
result_shape.append(1)
|
2023-06-10 17:33:27 -04:00
|
|
|
|
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
|
2022-06-29 13:55:30 -07:00
|
|
|
|
result_shape.append(non_1s[0])
|
|
|
|
|
else:
|
2024-12-11 15:03:39 -08:00
|
|
|
|
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
|
|
|
|
|
f'{", ".join(map(str, map(tuple, shapes)))}.')
|
2020-11-13 14:55:04 -08:00
|
|
|
|
return tuple(result_shape)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2023-03-30 10:21:55 -07:00
|
|
|
|
def asarray(x: ArrayLike) -> Array:
  """Convert an ArrayLike value to an Array with minimal overhead.

  Args:
    x: an ``Array``, a NumPy array or scalar, or a Python scalar.

  Returns:
    ``x`` itself when it already is an ``Array``; otherwise the result of
    ``_convert_element_type``, with ``weak_type=True`` only for Python
    ``int``/``float``/``complex`` scalars.

  Raises:
    TypeError: if ``x`` is none of the accepted kinds.
  """
  if isinstance(x, Array):
    return x
  # ``bool`` must be tested before ``int`` because it is a subclass of it.
  if isinstance(x, (bool, np.ndarray, np.generic)):
    return _convert_element_type(x, weak_type=False)  # pytype: disable=bad-return-type
  if isinstance(x, (int, float, builtins.complex)):
    coerced = dtypes.coerce_to_array(x)
    return _convert_element_type(coerced, weak_type=True)
  raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
|
|
|
|
|
|
2022-10-04 15:50:29 -07:00
|
|
|
|
@overload
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...

@overload
def broadcast_shapes(*shapes: tuple[int | core.Tracer, ...]
                     ) -> tuple[int | core.Tracer, ...]: ...

@export
def broadcast_shapes(*shapes):
  """Returns the shape that results from NumPy broadcasting of `shapes`.

  This follows the rules of `NumPy broadcasting`_.

  Args:
    shapes: one or more tuples of integers containing the shapes of arrays
      to be broadcast.

  Returns:
    A tuple of integers representing the broadcasted shape.

  Raises:
    ValueError: if shapes are not broadcast-compatible.

  See Also:
    - :func:`jax.numpy.broadcast_shapes`: similar API in the JAX NumPy namespace

  Examples:
    Some examples of broadcasting compatible shapes:

    >>> jnp.broadcast_shapes((1,), (4,))
    (4,)
    >>> jnp.broadcast_shapes((3, 1), (4,))
    (3, 4)
    >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
    (5, 3, 4)

    Error when attempting to broadcast incompatible shapes:

    >>> jnp.broadcast_shapes((3, 1), (4, 1))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]

  .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
  """
  # NOTE: We have both cached and uncached versions to handle Tracers in shapes.
  try:
    return _broadcast_shapes_cached(*shapes)
  except Exception:
    # Any failure of the cached path is retried uncached; a genuine
    # broadcasting error is simply raised again by the uncached call.
    # ``Exception`` (rather than a bare ``except``) lets KeyboardInterrupt and
    # SystemExit propagate instead of triggering a silent retry.
    return _broadcast_shapes_uncached(*shapes)
|
|
|
|
|
|
|
|
|
|
@cache()
def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
  """Memoized front end for ``_broadcast_shapes_uncached``."""
  broadcasted = _broadcast_shapes_uncached(*shapes)
  return broadcasted
|
|
|
|
|
|
|
|
|
|
def _broadcast_shapes_uncached(*shapes):
  """Compute the NumPy-style broadcast of ``shapes`` without caching.

  Args:
    *shapes: one or more shapes.

  Returns:
    The single input shape when only one is given; otherwise the broadcasted
    shape.

  Raises:
    ValueError: if the shapes cannot be broadcast together.
  """
  _validate_shapes(shapes)
  first, *others = shapes
  if not others:
    return first

  # Cheap path: rank promotion alone is enough when every shape is a suffix of
  # the highest-rank shape (no singleton-broadcasting needed).
  longest = _max(shapes, key=len)
  ndim = len(longest)

  def _is_suffix_of_longest(shape):
    return core.definitely_equal_shape(longest[ndim - len(shape):], shape)

  if ndim == 0 or all(_is_suffix_of_longest(s) for s in shapes):
    return longest

  # General path: left-pad every shape with 1s up to the common rank, then
  # apply singleton-broadcasting.
  padded = tuple((1,) * (ndim - len(s)) + tuple(s) for s in shapes)
  try:
    return _try_broadcast_shapes(*padded, name='broadcast_shapes')
  except TypeError as err:
    # Raise ValueError here for backward compatibility.
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-01-28 11:04:05 -08:00
|
|
|
|
def broadcast_shardings(*avals):
  """Return the sharding that results from broadcasting ``avals`` together.

  Args:
    *avals: one or more abstract values exposing ``ndim``, ``shape`` and
      ``sharding`` (with a ``spec``).

  Returns:
    The sharding of the only aval when a single one is given, the sharding of
    the highest-rank aval when rank promotion alone suffices, and otherwise
    the result of ``broadcasting_sharding_rule`` on the rank-padded avals.
  """
  first, *others = avals
  if not others:
    return first.sharding

  # Cheap path: rank promotion alone is enough when every spec is a suffix of
  # the highest-rank aval's spec (no singleton-broadcasting needed).
  widest = _max(avals, key=lambda aval: aval.ndim)
  ndim = widest.ndim
  if ndim == 0 or all(
      widest.sharding.spec[ndim - aval.ndim:] == aval.sharding.spec
      for aval in avals):
    return widest.sharding

  # General path: left-pad each aval to the common rank with size-1 axes and
  # ``None`` spec entries, then defer to the broadcasting sharding rule.
  def _pad_to_rank(aval):
    missing = ndim - aval.ndim
    padded_spec = P(*((None,) * missing + aval.sharding.spec))
    padded_shape = (1,) * missing + aval.shape
    return aval.update(shape=padded_shape,
                       sharding=aval.sharding.with_spec(padded_spec))

  return broadcasting_sharding_rule(
      'broadcast_shardings', *[_pad_to_rank(aval) for aval in avals])
|
|
|
|
|
|
2025-02-21 04:50:35 -08:00
|
|
|
|
def _identity(x, **_): return x
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-06-29 13:55:30 -07:00
|
|
|
|
def _extract_tracers_dyn_shape(
    shape: Sequence[int | core.Tracer]
) -> tuple[list[core.Tracer], list[int | None]]:
  """Split ``shape`` into its Tracer entries and a static template.

  Args:
    shape: sequence of dimension sizes, some of which may be Tracers.

  Returns:
    A pair ``(dyn_shape, static_shape)``. With ``config.dynamic_shapes``
    enabled, ``dyn_shape`` lists the Tracers in order and ``static_shape`` is
    ``shape`` with each Tracer replaced by ``None``. Otherwise ``dyn_shape`` is
    empty and ``static_shape`` is ``list(shape)``.
  """
  if not config.dynamic_shapes.value:
    return [], list(shape)  # type: ignore

  # We must gate this behavior under a flag because otherwise the errors
  # raised are different (and have worse source provenance information).
  tracers = []
  template = []
  for dim in shape:
    if isinstance(dim, core.Tracer):
      tracers.append(dim)
      template.append(None)
    else:
      template.append(dim)
  return tracers, template
|
2021-11-16 11:17:42 +02:00
|
|
|
|
|
2022-06-29 13:55:30 -07:00
|
|
|
|
def _merge_dyn_shape(
|
2023-12-08 12:09:04 +00:00
|
|
|
|
static_shape: Sequence[int | None],
|
2022-06-29 13:55:30 -07:00
|
|
|
|
dyn_shape: Sequence[Any],
|
2023-12-08 12:09:04 +00:00
|
|
|
|
) -> tuple[int | mlir.Value | core.Tracer, ...]:
|
2022-06-29 13:55:30 -07:00
|
|
|
|
# Replace Nones in static_shape with elements of dyn_shape, in order
|
2022-06-27 16:46:46 +03:00
|
|
|
|
dyn_shape_it = iter(dyn_shape)
|
|
|
|
|
shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
|
|
|
|
|
assert next(dyn_shape_it, None) is None
|
2021-11-16 11:17:42 +02:00
|
|
|
|
return shape
|
|
|
|
|
|
2022-06-29 13:55:30 -07:00
|
|
|
|
def _dyn_shape_staging_rule(trace, prim, out_aval, *args, **params):
  """Stage one single-output equation for ``prim`` into ``trace``'s frame.

  Args:
    trace: the trace whose frame receives the new equation.
    prim: the primitive being staged.
    out_aval: abstract value of the output tracer.
    *args: input tracers; each is mapped to a variable via ``trace.getvar``.
    **params: primitive parameters stored on the equation.

  Returns:
    The newly created ``pe.DynamicJaxprTracer`` for the output.
  """
  src = source_info_util.current()
  result = pe.DynamicJaxprTracer(trace, out_aval, src)
  invars = [trace.getvar(arg) for arg in args]
  outvars = [trace.makevar(result)]
  equation = pe.new_jaxpr_eqn(invars, outvars, prim, params,
                              core.no_effects, src)
  trace.frame.add_eqn(equation)
  return result
|
|
|
|
|
|
2022-06-29 13:55:30 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
### traceables
|
|
|
|
|
|
2025-01-31 09:30:56 -08:00
|
|
|
|
@export
def neg(x: ArrayLike) -> Array:
  r"""Elementwise negation: :math:`-x`.

  Lowers directly to the `stablehlo.negate`_ operation.

  Args:
    x: input array.

  Returns:
    An array with the same shape and dtype as ``x`` holding the negated
    elements.

  Notes:
    Unsigned integer inputs wrap around: the result is ``2 ** nbits - x``,
    with ``nbits`` the number of bits in the integer representation.

  .. _stablehlo.negate: https://openxla.org/stablehlo/spec#negate
  """
  negated = neg_p.bind(x)
  return negated
|
|
|
|
|
|
2025-01-31 09:30:56 -08:00
|
|
|
|
@export
def sign(x: ArrayLike) -> Array:
  r"""Elementwise sign.

  Lowers directly to the `stablehlo.sign`_ operation.

  Args:
    x: input array.

  Returns:
    An array with the same shape and dtype as ``x`` whose entries are the sign
    of the corresponding inputs, as defined in Notes below.

  Notes:
    For floating-point inputs the result is

    .. math::

       \mathrm{sign}(x) = \begin{cases}
         -1 & x < 0\\
         -0 & x = -0\\
         \mathit{NaN} & x = \mathit{NaN}\\
         +0 & x = +0\\
         1 & x > 0
       \end{cases}

    For signed integer inputs the result is

    .. math::

       \mathrm{sign}(x) = \begin{cases}
         -1 & x < 0\\
         0 & x = 0\\
         1 & x > 0
       \end{cases}

    For complex inputs the result is the complex phase, i.e.
    :math:`\mathrm{sign}(x) = x / |x|`.

  .. _stablehlo.sign: https://openxla.org/stablehlo/spec#sign
  """
  signs = sign_p.bind(x)
  return signs
|
|
|
|
|
|
2025-01-31 09:30:56 -08:00
|
|
|
|
@export
def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
  """Returns the next representable value after ``x1`` in the direction of ``x2``.

  Lowers directly to the ``chlo.next_after`` operation.

  Args:
    x1, x2: input arrays. Must have matching floating-point dtypes. Unless one
      of them is a scalar, they must have the same number of dimensions and be
      broadcast-compatible.

  Returns:
    An array with the dtype of the inputs and their broadcasted shape, holding
    the next representable floating-point value after ``x1`` in the direction
    of ``x2``.

  Notes:
    In some environments flush-denormal-to-zero semantics is used.
    This means that, around zero, this function returns strictly non-zero
    values which appear as zero in any operations. Consider this example::

      >>> from jax import lax
      >>> lax.nextafter(0.0, 1.0)  # denormal numbers are representable
      Array(1.e-45, dtype=float32, weak_type=True)
      >>> lax.nextafter(0.0, 1.0) * 1  # but are flushed to zero
      Array(0., dtype=float32, weak_type=True)

    For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
  """
  start, toward = core.standard_insert_pbroadcast(x1, x2)
  return nextafter_p.bind(start, toward)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-03 10:19:22 -08:00
|
|
|
|
@export
def floor(x: ArrayLike) -> Array:
  r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.

  Lowers directly to the `stablehlo.floor`_ operation.

  Args:
    x: input array. Must have floating-point type.

  Returns:
    An array with the same shape and dtype as ``x`` in which every value is
    rounded to the next integer toward negative infinity.

  See also:
    - :func:`jax.lax.ceil`: round to the next integer toward positive infinity
    - :func:`jax.lax.round`: round to the nearest integer

  .. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor
  """
  floored = floor_p.bind(x)
  return floored
|
|
|
|
|
|
2025-02-03 10:19:22 -08:00
|
|
|
|
@export
def ceil(x: ArrayLike) -> Array:
  r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.

  Lowers directly to the `stablehlo.ceil`_ operation.

  Args:
    x: input array. Must have floating-point type.

  Returns:
    An array with the same shape and dtype as ``x`` in which every value is
    rounded to the next integer toward positive infinity.

  See also:
    - :func:`jax.lax.floor`: round to the next integer toward negative infinity
    - :func:`jax.lax.round`: round to the nearest integer

  .. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil
  """
  ceiled = ceil_p.bind(x)
  return ceiled
|
|
|
|
|
|
2021-11-23 16:34:33 -08:00
|
|
|
|
class RoundingMethod(enum.IntEnum):
  """Strategies used by :func:`jax.lax.round` to resolve halfway values
  (e.g., 0.5).
  """

  AWAY_FROM_ZERO = 0
  """Halfway values are rounded away from zero (e.g., 0.5 -> 1, -0.5 -> -1)."""

  TO_NEAREST_EVEN = 1
  """Halfway values are rounded to the nearest even integer, also known as
  "banker's rounding" (e.g., 0.5 -> 0, 1.5 -> 2).
  """
|
2020-11-23 15:33:04 +01:00
|
|
|
|
|
2025-02-03 10:19:22 -08:00
|
|
|
|
@export
def round(x: ArrayLike,
          rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
          ) -> Array:
  r"""Elementwise round.

  Rounds each value to the nearest integer. Lowers directly to the
  `stablehlo.round`_ operation.

  Args:
    x: an array or scalar value to round. Must have floating-point type.
    rounding_method: how halfway values (e.g., ``0.5``) are resolved. See
      :class:`jax.lax.RoundingMethod` for the possible values.

  Returns:
    An array with the same shape and dtype as ``x`` holding the elementwise
    rounding of ``x``.

  See also:
    - :func:`jax.lax.floor`: round to the next integer toward negative infinity
    - :func:`jax.lax.ceil`: round to the next integer toward positive infinity

  Examples:
    >>> import jax.numpy as jnp
    >>> from jax import lax
    >>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
    >>> lax.round(x)  # default method is AWAY_FROM_ZERO
    Array([-2., -1., -1.,  0.,  1.,  1.,  2.], dtype=float32)
    >>> lax.round(x, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)
    Array([-2., -1., -0.,  0.,  0.,  1.,  2.], dtype=float32)

  .. _stablehlo.round: https://openxla.org/stablehlo/spec#round
  """
  # Normalize plain integers to the enum; values that are not members raise.
  method = RoundingMethod(rounding_method)
  return round_p.bind(x, rounding_method=method)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
def is_finite(x: ArrayLike) -> Array:
  r"""Elementwise :math:`\mathrm{isfinite}`.

  Lowers directly to the `stablehlo.is_finite`_ operation.

  Args:
    x: input array. Must have floating-point type.

  Returns:
    A boolean array with the same shape as ``x`` that is ``False`` wherever
    ``x`` is :math:`\pm\infty` or :math:`\mathit{NaN}`, and ``True`` elsewhere.

  See also:
    - :func:`jax.numpy.isinf`: return True where array is infinite.
    - :func:`jax.numpy.isnan`: return True where array is NaN.

  .. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite
  """
  finite_mask = is_finite_p.bind(x)
  return finite_mask
|
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
class Tolerance:
  """Tolerances used when computing unary functions.

  At most two tolerances are meant to be specified: (atol and rtol) or
  (atol and ulps).

  NOTE(review): the constructor only rejects negative values and the case
  where all three are zero; it does not reject ``rtol`` and ``ulps`` being
  set together.
  """

  def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0):
    tolerances = (atol, rtol, ulps)
    if any(tol < 0.0 for tol in tolerances):
      raise ValueError('Tolerances must be non-negative.')
    if all(tol == 0 for tol in tolerances):
      raise ValueError('At least one of atol, rtol, or ulps must be set.')

    self.atol, self.rtol, self.ulps = tolerances
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AccuracyMode(enum.Enum):
  """Named accuracy modes accepted, alongside ``Tolerance``, by the
  ``accuracy`` argument of unary functions such as :func:`jax.lax.exp`.
  """
  # NOTE(review): how each mode is interpreted is decided by the lowering,
  # which is not visible in this part of the file.
  HIGHEST = 1
  DEFAULT = 2
|
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
def exp(x: ArrayLike, accuracy=None) -> Array:
  r"""Elementwise exponential: :math:`e^x`.

  Lowers directly to the `stablehlo.exponential`_ operation.

  Args:
    x: input array. Must have floating-point or complex type.
    accuracy: optional ``lax.Tolerance`` or ``lax.AccuracyMode`` object that
      selects the implementation of the op according to the requested
      accuracy. The compiler returns an error if the implementation cannot
      satisfy the requested tolerance. If a mode is given and only a single
      implementation is available, the default implementation is used.

  Returns:
    An array with the same shape and dtype as ``x`` holding the element-wise
    exponential.

  See also:
    - :func:`jax.lax.exp2`: elementwise base-2 exponential: :math:`2^x`.
    - :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
  """
  result = exp_p.bind(x, accuracy=accuracy)
  return result
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
|
|
|
|
|
@export
def exp2(x: ArrayLike, accuracy=None) -> Array:
  r"""Elementwise base-2 exponential: :math:`2^x`.

  This function is implemented in terms of the `stablehlo.exponential`_
  and `stablehlo.multiply`_ operations.

  Args:
    x: input array. Must have floating-point or complex type.
    accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
      selects the implementation of the op based on the requested accuracy. If
      the implementation cannot satisfy the requested tolerance, the
      compiler will return an error. If a mode is specified and only a single
      implementation is available, the default implementation will be used.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    base-2 exponential.

  See also:
    - :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
    - :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
  .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
  """
  # ``@export`` brings this function in line with the other public functions
  # of this module (``exp``, ``expm1``, ``log``, ...), which all carry it.
  return exp2_p.bind(x, accuracy=accuracy)
|
2023-07-28 09:58:25 -07:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
def expm1(x: ArrayLike, accuracy=None) -> Array:
  r"""Elementwise :math:`e^{x} - 1`.

  Lowers directly to the `stablehlo.exponential_minus_one`_ operation. It is
  more accurate for ``x`` near zero than the naive expression
  ``lax.exp(x) - 1``.

  Args:
    x: input array. Must have floating-point or complex type.
    accuracy: optional ``lax.Tolerance`` or ``lax.AccuracyMode`` object that
      selects the implementation of the op according to the requested
      accuracy. The compiler returns an error if the implementation cannot
      satisfy the requested tolerance. If a mode is given and only a single
      implementation is available, the default implementation is used.

  Returns:
    An array with the same shape and dtype as ``x`` holding the element-wise
    exponential minus 1.

  See also:
    - :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
    - :func:`jax.lax.log1p`: elementwise :math:`\mathrm{log}(1 + x)`.

  .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one
  """
  result = expm1_p.bind(x, accuracy=accuracy)
  return result
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def log(x: ArrayLike, accuracy=None) -> Array:
|
2025-02-04 09:33:52 -08:00
|
|
|
|
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.log`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: input array. Must have floating-point or complex type.
|
2025-03-27 17:12:08 -07:00
|
|
|
|
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
|
|
|
|
selects the implementation of the op based on the requested accuracy. If
|
|
|
|
|
the implementation cannot satisfy the requested tolerance, the
|
|
|
|
|
compiler will return an error. If mode is specified and there are no
|
|
|
|
|
multiple implementations available, the default implementation will be
|
|
|
|
|
used.
|
2025-02-04 09:33:52 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape and dtype as ``x`` containing the element-wise
|
2025-02-06 11:09:42 -08:00
|
|
|
|
natural logarithm.
|
2025-02-04 09:33:52 -08:00
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.log: https://openxla.org/stablehlo/spec#log
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return log_p.bind(x, accuracy=accuracy)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def log1p(x: ArrayLike, accuracy=None) -> Array:
|
2025-02-06 11:09:42 -08:00
|
|
|
|
r"""Elementwise :math:`\mathrm{log}(1 + x)`.
|
2025-02-04 09:33:52 -08:00
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.log_plus_one`_ operation.
|
|
|
|
|
Compared to the naive expression ``lax.log(1 + x)``, it is more accurate
|
|
|
|
|
for ``x`` near zero.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: input array. Must have floating-point or complex type.
|
2025-03-27 17:12:08 -07:00
|
|
|
|
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
|
|
|
|
selects the implementation of the op based on the requested accuracy. If
|
|
|
|
|
the implementation cannot satisfy the requested tolerance, the
|
|
|
|
|
compiler will return an error. If mode is specified and there are no
|
|
|
|
|
multiple implementations available, the default implementation will be
|
|
|
|
|
used.
|
2025-02-04 09:33:52 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape and dtype as ``x`` containing the element-wise
|
2025-02-06 11:09:42 -08:00
|
|
|
|
natural logarithm of ``x + 1``.
|
2025-02-04 09:33:52 -08:00
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
|
|
|
|
|
- :func:`jax.lax.log`: elementwise natural logarithm :math:`\mathrm{log}(x)`.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return log1p_p.bind(x, accuracy=accuracy)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def tanh(x: ArrayLike, accuracy=None) -> Array:
|
2025-02-07 09:33:25 -08:00
|
|
|
|
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.tanh`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: input array. Must have floating-point or complex type.
|
2025-03-27 17:12:08 -07:00
|
|
|
|
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
|
|
|
|
selects the implementation of the op based on the requested accuracy. If
|
|
|
|
|
the implementation cannot satisfy the requested tolerance, the
|
|
|
|
|
compiler will return an error. If mode is specified and there are no
|
|
|
|
|
multiple implementations available, the default implementation will be
|
|
|
|
|
used.
|
2025-02-07 09:33:25 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape and dtype as ``x`` containing the element-wise
|
|
|
|
|
hyperbolic tangent.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent.
|
|
|
|
|
- :func:`jax.lax.cosh`: elementwise hyperbolic cosine.
|
|
|
|
|
- :func:`jax.lax.sinh`: elementwise hyperbolic sine.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return tanh_p.bind(x, accuracy=accuracy)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-03-18 08:55:38 -07:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
|
|
|
|
|
def logistic(x: ArrayLike, accuracy=None) -> Array:
|
2025-03-18 08:55:38 -07:00
|
|
|
|
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
|
|
|
|
|
|
|
|
|
|
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
|
|
|
|
|
of HLO arithmetic operations.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
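x: input array. Must have floating-point or complex dtype.
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
selects the implementation of the op based on the requested accuracy.
It is forwarded unchanged to the underlying primitive.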
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape and dtype as ``x`` containing the element-wise
|
|
|
|
|
logistic/sigmoid function.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return logistic_p.bind(x, accuracy=accuracy)
|
2022-09-07 06:06:22 -07:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def sin(x: ArrayLike, accuracy=None) -> Array:
|
2025-02-06 11:09:42 -08:00
|
|
|
|
r"""Elementwise sine: :math:`\mathrm{sin}(x)`.
|
|
|
|
|
|
|
|
|
|
For floating-point inputs, this function lowers directly to the
|
|
|
|
|
`stablehlo.sine`_ operation. For complex inputs, it lowers to a
|
|
|
|
|
sequence of HLO operations implementing the complex sine.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: input array. Must have floating-point or complex type.
|
2025-03-27 17:12:08 -07:00
|
|
|
|
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
|
|
|
|
selects the implementation of the op based on the requested accuracy. If
|
|
|
|
|
the implementation cannot satisfy the requested tolerance, the
|
|
|
|
|
compiler will return an error. If mode is specified and there are no
|
|
|
|
|
multiple implementations available, the default implementation will be
|
|
|
|
|
used.
|
2025-02-06 11:09:42 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape and dtype as ``x`` containing the element-wise
|
|
|
|
|
sine.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.cos`: elementwise cosine.
|
|
|
|
|
- :func:`jax.lax.tan`: elementwise tangent.
|
|
|
|
|
- :func:`jax.lax.asin`: elementwise arc sine.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return sin_p.bind(x, accuracy=accuracy)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def cos(x: ArrayLike, accuracy=None) -> Array:
|
2025-02-06 11:09:42 -08:00
|
|
|
|
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.
|
|
|
|
|
|
|
|
|
|
For floating-point inputs, this function lowers directly to the
|
|
|
|
|
`stablehlo.cosine`_ operation. For complex inputs, it lowers to a
|
|
|
|
|
sequence of HLO operations implementing the complex cosine.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: input array. Must have floating-point or complex type.
|
2025-03-27 17:12:08 -07:00
|
|
|
|
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
|
|
|
|
selects the implementation of the op based on the requested accuracy. If
|
|
|
|
|
the implementation cannot satisfy the requested tolerance, the
|
|
|
|
|
compiler will return an error. If mode is specified and there are no
|
|
|
|
|
multiple implementations available, the default implementation will be
|
|
|
|
|
used.
|
2025-02-06 11:09:42 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape and dtype as ``x`` containing the element-wise
|
|
|
|
|
cosine.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.sin`: elementwise sine.
|
|
|
|
|
- :func:`jax.lax.tan`: elementwise tangent.
|
|
|
|
|
- :func:`jax.lax.acos`: elementwise arc cosine.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return cos_p.bind(x, accuracy=accuracy)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
|
2025-02-06 11:09:42 -08:00
|
|
|
|
r"""Elementwise two-term arc tangent: :math:`\mathrm{atan}({x \over y})`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.atan2`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x, y: input arrays. Must have matching floating-point or complex dtypes. If
|
|
|
|
|
neither is a scalar, the two arrays must have the same number of dimensions
|
|
|
|
|
and be broadcast-compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape and dtype as ``x`` and ``y`` containing the element-wise
|
|
|
|
|
arc tangent of :math:`x \over y`, respecting the quadrant indicated by the sign
|
|
|
|
|
of each input.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.tan`: elementwise tangent.
|
|
|
|
|
- :func:`jax.lax.atan`: elementwise one-term arc tangent.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2
|
|
|
|
|
"""
|
2025-03-25 17:02:45 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return atan2_p.bind(x, y)
|
|
|
|
|
|
2025-02-11 14:12:22 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def real(x: ArrayLike) -> Array:
|
2020-10-17 14:33:26 -04:00
|
|
|
|
r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.
|
|
|
|
|
|
2025-02-11 14:12:22 -08:00
|
|
|
|
This function lowers directly to the `stablehlo.real`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: input array. Must have complex dtype.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape as ``x`` containing its real part. Will have dtype
|
|
|
|
|
float32 if ``x.dtype == complex64``, or float64 if ``x.dtype == complex128``.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.complex`: elementwise construct complex number.
|
|
|
|
|
- :func:`jax.lax.imag`: elementwise extract imaginary part.
|
|
|
|
|
- :func:`jax.lax.conj`: elementwise complex conjugate.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.real: https://openxla.org/stablehlo/spec#real
|
2020-10-17 14:33:26 -04:00
|
|
|
|
"""
|
|
|
|
|
return real_p.bind(x)
|
|
|
|
|
|
2025-02-11 14:12:22 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def imag(x: ArrayLike) -> Array:
|
2020-10-17 14:33:26 -04:00
|
|
|
|
r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.
|
|
|
|
|
|
2025-02-11 14:12:22 -08:00
|
|
|
|
This function lowers directly to the `stablehlo.imag`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: input array. Must have complex dtype.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape as ``x`` containing its imaginary part. Will have dtype
|
|
|
|
|
float32 if ``x.dtype == complex64``, or float64 if ``x.dtype == complex128``.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.complex`: elementwise construct complex number.
|
|
|
|
|
- :func:`jax.lax.real`: elementwise extract real part.
|
|
|
|
|
- :func:`jax.lax.conj`: elementwise complex conjugate.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
|
2020-10-17 14:33:26 -04:00
|
|
|
|
"""
|
|
|
|
|
return imag_p.bind(x)
|
|
|
|
|
|
2025-02-11 14:12:22 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def complex(x: ArrayLike, y: ArrayLike) -> Array:
|
2020-10-17 14:33:26 -04:00
|
|
|
|
r"""Elementwise make complex number: :math:`x + jy`.
|
|
|
|
|
|
2025-02-11 14:12:22 -08:00
|
|
|
|
This function lowers directly to the `stablehlo.complex`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x, y: input arrays. Must have matching floating-point dtypes. If
|
|
|
|
|
neither is a scalar, the two arrays must have the same number
|
|
|
|
|
of dimensions and be broadcast-compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
The complex array with the real part given by ``x``, and the imaginary
|
|
|
|
|
part given by ``y``. For inputs of dtype float32 or float64, the result
|
|
|
|
|
will have dtype complex64 or complex128 respectively.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.real`: elementwise extract real part.
|
|
|
|
|
- :func:`jax.lax.imag`: elementwise extract imaginary part.
|
|
|
|
|
- :func:`jax.lax.conj`: elementwise complex conjugate.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
|
2020-10-17 14:33:26 -04:00
|
|
|
|
"""
|
2025-03-25 17:02:45 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2021-04-15 15:16:29 -07:00
|
|
|
|
return complex_p.bind(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-11 14:12:22 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def conj(x: ArrayLike) -> Array:
|
2025-02-11 14:12:22 -08:00
|
|
|
|
r"""Elementwise complex conjugate function: :math:`\overline{x}`.
|
|
|
|
|
|
|
|
|
|
This function lowers to a combination of `stablehlo.real`_, `stablehlo.imag`_,
|
|
|
|
|
and `stablehlo.complex`_.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: input array. Must have complex dtype.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Array of the same shape and dtype as ``x`` containing its complex conjugate.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.complex`: elementwise construct complex number.
|
|
|
|
|
- :func:`jax.lax.real`: elementwise extract real part.
|
|
|
|
|
- :func:`jax.lax.imag`: elementwise extract imaginary part.
|
|
|
|
|
- :func:`jax.lax.abs`: elementwise absolute value / complex magnitude.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.real: https://openxla.org/stablehlo/spec#real
|
|
|
|
|
.. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
|
|
|
|
|
.. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
|
|
|
|
|
"""
|
2023-04-04 20:45:21 -07:00
|
|
|
|
# TODO(mattjj): remove input_dtype, not needed anymore
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return conj_p.bind(x, input_dtype=_dtype(x))
|
|
|
|
|
|
2025-02-11 14:12:22 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def abs(x: ArrayLike) -> Array:
|
2025-02-11 14:12:22 -08:00
|
|
|
|
r"""Elementwise absolute value: :math:`|x|`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.abs`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array. Must have signed integer, floating, or complex dtype.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same shape as ``x`` containing the elementwise absolute value.
|
|
|
|
|
For complex valued input, :math:`a + ib`, ``abs(x)`` returns :math:`\sqrt{a^2+b^2}`.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.numpy.abs`: a more flexible NumPy-style ``abs`` implementation.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs
|
|
|
|
|
"""
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return abs_p.bind(x)
|
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def pow(x: ArrayLike, y: ArrayLike) -> Array:
|
2025-02-14 08:40:19 -08:00
|
|
|
|
r"""Elementwise power: :math:`x^y`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.pow`_ operation, along with
|
|
|
|
|
a `stablehlo.convert`_ when the argument dtypes do not match.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array giving the base value. Must have floating or complex type.
|
|
|
|
|
y: Input array giving the exponent value. Must have integer, floating, or
|
|
|
|
|
complex type. Its dtype will be cast to that of ``x.dtype`` if necessary.
|
|
|
|
|
If neither ``x`` nor ``y`` is a scalar, then ``x`` and ``y`` must have
|
|
|
|
|
the same number of dimensions and be broadcast-compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same dtype as ``x`` containing the elementwise power.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
:func:`jax.lax.integer_pow`: Elementwise power where ``y`` is a static integer.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
|
|
|
|
|
.. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow
|
|
|
|
|
"""
|
2025-03-25 17:02:45 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return pow_p.bind(x, y)
|
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def integer_pow(x: ArrayLike, y: int) -> Array:
|
2025-02-14 08:40:19 -08:00
|
|
|
|
r"""Elementwise power: :math:`x^y`, where :math:`y` is a static integer.
|
|
|
|
|
|
|
|
|
|
This will lower to a sequence of :math:`O[\log_2(y)]` repetitions of
|
|
|
|
|
`stablehlo.multiply`_.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array giving the base value. Must have numerical dtype.
|
|
|
|
|
y: Static scalar integer giving the exponent.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same shape and dtype as ``x`` containing the elementwise power.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
:func:`jax.lax.pow`: Elementwise power where ``y`` is an array.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
|
|
|
|
|
"""
|
2021-01-19 11:36:39 -08:00
|
|
|
|
return integer_pow_p.bind(x, y=y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def sqrt(x: ArrayLike, accuracy=None) -> Array:
|
2025-02-14 08:40:19 -08:00
|
|
|
|
r"""Elementwise square root: :math:`\sqrt{x}`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.sqrt`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array. Must have floating or complex dtype.
|
2025-03-27 17:12:08 -07:00
|
|
|
|
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
|
|
|
|
selects the implementation of the op based on the requested accuracy. If
|
|
|
|
|
the implementation cannot satisfy the requested tolerance, the
|
|
|
|
|
compiler will return an error. If mode is specified and there are no
|
|
|
|
|
multiple implementations available, the default implementation will be
|
|
|
|
|
used.
|
2025-02-14 08:40:19 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same shape and dtype as ``x`` containing the square root.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
:func:`jax.lax.pow`: Elementwise power.
|
|
|
|
|
:func:`jax.lax.cbrt`: Elementwise cube root.
|
|
|
|
|
:func:`jax.lax.rsqrt`: Elementwise reciprocal square root.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return sqrt_p.bind(x, accuracy=accuracy)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def rsqrt(x: ArrayLike, accuracy=None) -> Array:
|
2025-02-14 08:40:19 -08:00
|
|
|
|
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.rsqrt`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array. Must have floating or complex dtype.
|
2025-03-27 17:12:08 -07:00
|
|
|
|
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
|
|
|
|
selects the implementation of the op based on the requested accuracy. If
|
|
|
|
|
the implementation cannot satisfy the requested tolerance, the
|
|
|
|
|
compiler will return an error. If mode is specified and there are no
|
|
|
|
|
multiple implementations available, the default implementation will be
|
|
|
|
|
used.
|
2025-02-14 08:40:19 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same shape and dtype as ``x`` containing the
|
|
|
|
|
reciprocal square root.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
:func:`jax.lax.pow`: Elementwise power.
|
|
|
|
|
:func:`jax.lax.sqrt`: Elementwise square root.
|
|
|
|
|
:func:`jax.lax.cbrt`: Elementwise cube root.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return rsqrt_p.bind(x, accuracy=accuracy)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def cbrt(x: ArrayLike, accuracy=None) -> Array:
|
2025-02-14 08:40:19 -08:00
|
|
|
|
r"""Elementwise cube root: :math:`\sqrt[3]{x}`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.cbrt`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array. Must have floating or complex dtype.
|
2025-03-27 17:12:08 -07:00
|
|
|
|
accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that
|
|
|
|
|
selects the implementation of the op based on the requested accuracy. If
|
|
|
|
|
the implementation cannot satisfy the requested tolerance, the
|
|
|
|
|
compiler will return an error. If mode is specified and there are no
|
|
|
|
|
multiple implementations available, the default implementation will be
|
|
|
|
|
used.
|
2025-02-14 08:40:19 -08:00
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same shape and dtype as ``x`` containing the cube root.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
:func:`jax.lax.pow`: Elementwise power.
|
|
|
|
|
:func:`jax.lax.sqrt`: Elementwise square root.
|
|
|
|
|
:func:`jax.lax.rsqrt`: Elementwise reciprocal square root.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt
|
|
|
|
|
"""
|
2025-03-27 17:12:08 -07:00
|
|
|
|
return cbrt_p.bind(x, accuracy=accuracy)
|
2021-07-22 14:00:52 -07:00
|
|
|
|
|
2025-02-14 14:17:30 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def bitwise_not(x: ArrayLike) -> Array:
|
2025-02-14 14:17:30 -08:00
|
|
|
|
r"""Elementwise NOT: :math:`\neg x`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.not`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array. Must have boolean or integer dtype.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same shape and dtype as ``x`` containing the bitwise
|
|
|
|
|
inversion of each entry.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.numpy.invert`: NumPy wrapper for this API, also accessible
|
|
|
|
|
via the ``~x`` operator on JAX arrays.
|
|
|
|
|
- :func:`jax.lax.bitwise_and`: Elementwise AND.
|
|
|
|
|
- :func:`jax.lax.bitwise_or`: Elementwise OR.
|
|
|
|
|
- :func:`jax.lax.bitwise_xor`: Elementwise exclusive OR.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.not: https://openxla.org/stablehlo/spec#not
|
|
|
|
|
"""
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return not_p.bind(x)
|
|
|
|
|
|
2025-02-14 14:17:30 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array:
|
2025-02-14 14:17:30 -08:00
|
|
|
|
r"""Elementwise AND: :math:`x \wedge y`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.and`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x, y: Input arrays. Must have matching boolean or integer dtypes.
|
|
|
|
|
If neither is a scalar, ``x`` and ``y`` must have the same number
|
|
|
|
|
of dimensions and be broadcast compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same dtype as ``x`` and ``y`` containing the bitwise
|
|
|
|
|
AND of each pair of broadcasted entries.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.numpy.bitwise_and`: NumPy wrapper for this API, also accessible
|
|
|
|
|
via the ``x & y`` operator on JAX arrays.
|
|
|
|
|
- :func:`jax.lax.bitwise_not`: Elementwise NOT.
|
|
|
|
|
- :func:`jax.lax.bitwise_or`: Elementwise OR.
|
|
|
|
|
- :func:`jax.lax.bitwise_xor`: Elementwise exclusive OR.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.and: https://openxla.org/stablehlo/spec#and
|
|
|
|
|
"""
|
2025-03-25 17:02:45 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return and_p.bind(x, y)
|
|
|
|
|
|
2025-02-14 14:17:30 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array:
|
2025-02-14 14:17:30 -08:00
|
|
|
|
r"""Elementwise OR: :math:`x \vee y`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.or`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x, y: Input arrays. Must have matching boolean or integer dtypes.
|
|
|
|
|
If neither is a scalar, ``x`` and ``y`` must have the same number
|
|
|
|
|
of dimensions and be broadcast compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same dtype as ``x`` and ``y`` containing the bitwise
|
|
|
|
|
OR of each pair of broadcasted entries.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.numpy.bitwise_or`: NumPy wrapper for this API, also accessible
|
|
|
|
|
via the ``x | y`` operator on JAX arrays.
|
|
|
|
|
- :func:`jax.lax.bitwise_not`: Elementwise NOT.
|
|
|
|
|
- :func:`jax.lax.bitwise_and`: Elementwise AND.
|
|
|
|
|
- :func:`jax.lax.bitwise_xor`: Elementwise exclusive OR.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.or: https://openxla.org/stablehlo/spec#or
|
|
|
|
|
"""
|
2025-03-25 17:02:45 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return or_p.bind(x, y)
|
|
|
|
|
|
2025-02-14 14:17:30 -08:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
|
2025-02-14 14:17:30 -08:00
|
|
|
|
r"""Elementwise exclusive OR: :math:`x \oplus y`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.xor`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x, y: Input arrays. Must have matching boolean or integer dtypes.
|
|
|
|
|
If neither is a scalar, ``x`` and ``y`` must have the same number
|
|
|
|
|
of dimensions and be broadcast compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same dtype as ``x`` and ``y`` containing the bitwise
|
|
|
|
|
XOR of each pair of broadcasted entries.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.numpy.bitwise_xor`: NumPy wrapper for this API, also accessible
|
|
|
|
|
via the ``x ^ y`` operator on JAX arrays.
|
|
|
|
|
- :func:`jax.lax.bitwise_not`: Elementwise NOT.
|
|
|
|
|
- :func:`jax.lax.bitwise_and`: Elementwise AND.
|
|
|
|
|
- :func:`jax.lax.bitwise_or`: Elementwise OR.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor
|
|
|
|
|
"""
|
2025-03-25 17:02:45 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return xor_p.bind(x, y)
|
|
|
|
|
|
2025-03-18 08:55:38 -07:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def population_count(x: ArrayLike) -> Array:
|
2025-03-18 08:55:38 -07:00
|
|
|
|
r"""Elementwise popcount, count the number of set bits in each element.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.popcnt`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array. Must have integer dtype.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same shape and dtype as ``x``, containing the number of
|
|
|
|
|
set bits in the input.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.clz`: Elementwise count leading zeros.
|
|
|
|
|
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
|
|
|
|
|
"""
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return population_count_p.bind(x)
|
|
|
|
|
|
2025-03-18 08:55:38 -07:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def clz(x: ArrayLike) -> Array:
|
2025-03-18 08:55:38 -07:00
|
|
|
|
r"""Elementwise count-leading-zeros.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: Input array. Must have integer dtype.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same shape and dtype as ``x``, containing the number of
|
|
|
|
|
leading zero bits in each element of the input.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
|
|
|
|
|
"""
|
2021-03-19 22:35:31 -07:00
|
|
|
|
return clz_p.bind(x)
|
|
|
|
|
|
2025-03-15 11:49:51 -07:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def add(x: ArrayLike, y: ArrayLike) -> Array:
|
2025-03-15 11:49:51 -07:00
|
|
|
|
r"""Elementwise addition: :math:`x + y`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.add`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x, y: Input arrays. Must have matching numerical dtypes. If neither
|
|
|
|
|
is a scalar, ``x`` and ``y`` must have the same number of dimensions
|
|
|
|
|
and be broadcast compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same dtype as ``x`` and ``y`` containing the sum
|
|
|
|
|
of each pair of broadcasted entries.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.numpy.add`: NumPy-style addition supporting inputs
|
|
|
|
|
with mixed dtypes and ranks.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.add: https://openxla.org/stablehlo/spec#add
|
|
|
|
|
"""
|
2025-03-25 17:02:45 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return add_p.bind(x, y)
|
|
|
|
|
|
2025-03-15 11:49:51 -07:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def sub(x: ArrayLike, y: ArrayLike) -> Array:
|
2025-03-15 11:49:51 -07:00
|
|
|
|
r"""Elementwise subtraction: :math:`x - y`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.subtract`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x, y: Input arrays. Must have matching numerical dtypes. If neither
|
|
|
|
|
is a scalar, ``x`` and ``y`` must have the same number of dimensions
|
|
|
|
|
and be broadcast compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same dtype as ``x`` and ``y`` containing the difference
|
|
|
|
|
of each pair of broadcasted entries.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.numpy.subtract`: NumPy-style subtraction supporting
|
|
|
|
|
inputs with mixed dtypes and ranks.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract
|
|
|
|
|
"""
|
2025-03-25 17:02:45 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return sub_p.bind(x, y)
|
|
|
|
|
|
2025-03-15 11:49:51 -07:00
|
|
|
|
@export
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def mul(x: ArrayLike, y: ArrayLike) -> Array:
|
2025-03-15 11:49:51 -07:00
|
|
|
|
r"""Elementwise multiplication: :math:`x \times y`.
|
|
|
|
|
|
|
|
|
|
This function lowers directly to the `stablehlo.multiply`_ operation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x, y: Input arrays. Must have matching numerical dtypes. If neither
|
|
|
|
|
is a scalar, ``x`` and ``y`` must have the same number of dimensions
|
|
|
|
|
and be broadcast compatible.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
An array of the same dtype as ``x`` and ``y`` containing the product
|
|
|
|
|
of each pair of broadcasted entries.
|
|
|
|
|
|
|
|
|
|
See also:
|
|
|
|
|
- :func:`jax.numpy.multiply`: NumPy-style multiplication supporting
|
|
|
|
|
inputs with mixed dtypes and ranks.
|
|
|
|
|
|
|
|
|
|
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
|
|
|
|
|
"""
|
2025-03-21 10:25:38 -07:00
|
|
|
|
x, y = core.standard_insert_pbroadcast(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return mul_p.bind(x, y)
|
|
|
|
|
|
2025-03-15 11:49:51 -07:00
|
|
|
|
@export
def div(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Divide two arrays elementwise: :math:`x \over y`.

  Lowers directly to the `stablehlo.divide`_ operation.

  Integer division overflow (division by zero, or signed division of
  INT_SMIN by -1) yields an implementation defined value.

  Args:
    x, y: Input arrays whose numerical dtypes must match. If neither is a
      scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    An array with the same dtype as ``x`` and ``y`` holding the quotient of
    each pair of broadcasted entries. For integer inputs, any fractional
    part is discarded.

  See also:
    - :func:`jax.numpy.divide`: NumPy-style true division that accepts
      inputs with mixed dtypes and ranks.
    - :func:`jax.numpy.floor_divide`: NumPy-style floor division that
      accepts inputs with mixed dtypes and ranks.

  .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return div_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-03-18 08:55:38 -07:00
|
|
|
|
@export
def rem(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Compute the elementwise remainder: :math:`x \bmod y`.

  Lowers directly to the `stablehlo.remainder`_ operation. The result takes
  its sign from the dividend, and its absolute value is always less than the
  absolute value of the divisor.

  Integer division overflow (remainder by zero, or remainder of INT_SMIN
  with -1) yields an implementation defined value.

  Args:
    x, y: Input arrays whose int or float dtypes must match. If neither is
      a scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    An array with the same dtype as ``x`` and ``y`` holding the remainder.

  See also:
    - :func:`jax.numpy.remainder`: NumPy-style remainder with different
      sign semantics.

  .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return rem_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-03-18 08:55:38 -07:00
|
|
|
|
@export
def max(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Take the elementwise maximum: :math:`\mathrm{max}(x, y)`.

  For non-complex inputs this lowers directly to the `stablehlo.maximum`_
  operation. Complex numbers are compared lexicographically on their
  `(real, imaginary)` pairs.

  Args:
    x, y: Input arrays whose dtypes must match. If neither is a scalar,
      ``x`` and ``y`` must have the same rank and be broadcast compatible.

  Returns:
    An array with the same dtype as ``x`` and ``y`` holding the elementwise
    maximum.

  See also:
    - :func:`jax.numpy.maximum`: more flexible NumPy-style maximum.
    - :func:`jax.lax.reduce_max`: maximum along an axis of an array.
    - :func:`jax.lax.min`: elementwise minimum.

  .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return max_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-03-18 08:55:38 -07:00
|
|
|
|
@export
def min(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Take the elementwise minimum: :math:`\mathrm{min}(x, y)`.

  For non-complex inputs this lowers directly to the `stablehlo.minimum`_
  operation. Complex numbers are compared lexicographically on their
  `(real, imaginary)` pairs.

  Args:
    x, y: Input arrays whose dtypes must match. If neither is a scalar,
      ``x`` and ``y`` must have the same rank and be broadcast compatible.

  Returns:
    An array with the same dtype as ``x`` and ``y`` holding the elementwise
    minimum.

  See also:
    - :func:`jax.numpy.minimum`: more flexible NumPy-style minimum.
    - :func:`jax.lax.reduce_min`: minimum along an axis of an array.
    - :func:`jax.lax.max`: elementwise maximum.

  .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return min_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-14 14:17:30 -08:00
|
|
|
|
@export
def shift_left(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Shift bits left elementwise: :math:`x \ll y`.

  Lowers directly to the `stablehlo.shift_left`_ operation.

  Args:
    x, y: Input arrays whose integer dtypes must match. If neither is a
      scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    An array with the same dtype as ``x`` and ``y`` holding the element-wise
    left shift of each pair of broadcasted entries.

  See also:
    - :func:`jax.numpy.left_shift`: NumPy wrapper for this API, also
      accessible via the ``x << y`` operator on JAX arrays.
    - :func:`jax.lax.shift_right_arithmetic`: Elementwise arithmetic right shift.
    - :func:`jax.lax.shift_right_logical`: Elementwise logical right shift.

  .. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return shift_left_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-14 14:17:30 -08:00
|
|
|
|
@export
def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Shift bits right arithmetically, elementwise: :math:`x \gg y`.

  Lowers directly to the `stablehlo.shift_right_arithmetic`_ operation.

  Args:
    x, y: Input arrays whose integer dtypes must match. If neither is a
      scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    An array with the same dtype as ``x`` and ``y`` holding the element-wise
    arithmetic right shift of each pair of broadcasted entries.

  See also:
    - :func:`jax.numpy.right_shift`: NumPy wrapper for this API when applied
      to signed integers, also accessible via the ``x >> y`` operator on JAX
      arrays with signed integer dtype.
    - :func:`jax.lax.shift_left`: Elementwise left shift.
    - :func:`jax.lax.shift_right_logical`: Elementwise logical right shift.

  .. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return shift_right_arithmetic_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-14 14:17:30 -08:00
|
|
|
|
@export
def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Shift bits right logically, elementwise: :math:`x \gg y`.

  Lowers directly to the `stablehlo.shift_right_logical`_ operation.

  Args:
    x, y: Input arrays whose integer dtypes must match. If neither is a
      scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    An array with the same dtype as ``x`` and ``y`` holding the element-wise
    logical right shift of each pair of broadcasted entries.

  See also:
    - :func:`jax.numpy.right_shift`: NumPy wrapper for this API when applied
      to unsigned integers, also accessible via the ``x >> y`` operator on
      JAX arrays with unsigned integer dtype.
    - :func:`jax.lax.shift_left`: Elementwise left shift.
    - :func:`jax.lax.shift_right_arithmetic`: Elementwise arithmetic right shift.

  .. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return shift_right_logical_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-18 13:48:59 -08:00
|
|
|
|
@export
def eq(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Compare two arrays elementwise for equality: :math:`x = y`.

  Lowers directly to the `stablehlo.compare`_ operation with
  ``comparison_direction=EQ``; ``compare_type`` is chosen from the input
  dtype.

  Args:
    x, y: Input arrays whose dtypes must match. If neither is a scalar,
      ``x`` and ``y`` must have the same number of dimensions and be
      broadcast compatible.

  Returns:
    A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
    holding the elementwise equal comparison.

  See also:
    - :func:`jax.numpy.equal`: NumPy wrapper for this API, also accessible
      via the ``x == y`` operator on JAX arrays.
    - :func:`jax.lax.ne`: elementwise not-equal
    - :func:`jax.lax.ge`: elementwise greater-than-or-equal
    - :func:`jax.lax.gt`: elementwise greater-than
    - :func:`jax.lax.le`: elementwise less-than-or-equal
    - :func:`jax.lax.lt`: elementwise less-than

  .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return eq_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-18 13:48:59 -08:00
|
|
|
|
@export
def ne(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Compare two arrays elementwise for inequality: :math:`x \neq y`.

  Lowers directly to the `stablehlo.compare`_ operation with
  ``comparison_direction=NE``; ``compare_type`` is chosen from the input
  dtype.

  Args:
    x, y: Input arrays whose dtypes must match. If neither is a scalar,
      ``x`` and ``y`` must have the same number of dimensions and be
      broadcast compatible.

  Returns:
    A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
    holding the elementwise not-equal comparison.

  See also:
    - :func:`jax.numpy.not_equal`: NumPy wrapper for this API, also
      accessible via the ``x != y`` operator on JAX arrays.
    - :func:`jax.lax.eq`: elementwise equal
    - :func:`jax.lax.ge`: elementwise greater-than-or-equal
    - :func:`jax.lax.gt`: elementwise greater-than
    - :func:`jax.lax.le`: elementwise less-than-or-equal
    - :func:`jax.lax.lt`: elementwise less-than

  .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return ne_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-18 13:48:59 -08:00
|
|
|
|
@export
def ge(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Compare elementwise for greater-than-or-equal: :math:`x \geq y`.

  Lowers directly to the `stablehlo.compare`_ operation with
  ``comparison_direction=GE``; ``compare_type`` is chosen from the input
  dtype.

  Args:
    x, y: Input arrays whose non-complex dtypes must match. If neither is a
      scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
    holding the elementwise greater-than-or-equal comparison.

  See also:
    - :func:`jax.numpy.greater_equal`: NumPy wrapper for this API, also
      accessible via the ``x >= y`` operator on JAX arrays.
    - :func:`jax.lax.eq`: elementwise equal
    - :func:`jax.lax.ne`: elementwise not-equal
    - :func:`jax.lax.gt`: elementwise greater-than
    - :func:`jax.lax.le`: elementwise less-than-or-equal
    - :func:`jax.lax.lt`: elementwise less-than

  .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return ge_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-18 13:48:59 -08:00
|
|
|
|
@export
def gt(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Compare elementwise for greater-than: :math:`x > y`.

  Lowers directly to the `stablehlo.compare`_ operation with
  ``comparison_direction=GT``; ``compare_type`` is chosen from the input
  dtype.

  Args:
    x, y: Input arrays whose non-complex dtypes must match. If neither is a
      scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
    holding the elementwise greater-than comparison.

  See also:
    - :func:`jax.numpy.greater`: NumPy wrapper for this API, also accessible
      via the ``x > y`` operator on JAX arrays.
    - :func:`jax.lax.eq`: elementwise equal
    - :func:`jax.lax.ne`: elementwise not-equal
    - :func:`jax.lax.ge`: elementwise greater-than-or-equal
    - :func:`jax.lax.le`: elementwise less-than-or-equal
    - :func:`jax.lax.lt`: elementwise less-than

  .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return gt_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-18 13:48:59 -08:00
|
|
|
|
@export
def le(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Compare elementwise for less-than-or-equal: :math:`x \leq y`.

  Lowers directly to the `stablehlo.compare`_ operation with
  ``comparison_direction=LE``; ``compare_type`` is chosen from the input
  dtype.

  Args:
    x, y: Input arrays whose non-complex dtypes must match. If neither is a
      scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
    holding the elementwise less-than-or-equal comparison.

  See also:
    - :func:`jax.numpy.less_equal`: NumPy wrapper for this API, also
      accessible via the ``x <= y`` operator on JAX arrays.
    - :func:`jax.lax.eq`: elementwise equal
    - :func:`jax.lax.ne`: elementwise not-equal
    - :func:`jax.lax.ge`: elementwise greater-than-or-equal
    - :func:`jax.lax.gt`: elementwise greater-than
    - :func:`jax.lax.lt`: elementwise less-than

  .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return le_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-02-18 13:48:59 -08:00
|
|
|
|
@export
def lt(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Compare elementwise for less-than: :math:`x < y`.

  Lowers directly to the `stablehlo.compare`_ operation with
  ``comparison_direction=LT``; ``compare_type`` is chosen from the input
  dtype.

  Args:
    x, y: Input arrays whose non-complex dtypes must match. If neither is a
      scalar, ``x`` and ``y`` must have the same number of dimensions and
      be broadcast compatible.

  Returns:
    A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
    holding the elementwise less-than comparison.

  See also:
    - :func:`jax.numpy.less`: NumPy wrapper for this API, also accessible
      via the ``x < y`` operator on JAX arrays.
    - :func:`jax.lax.eq`: elementwise equal
    - :func:`jax.lax.ne`: elementwise not-equal
    - :func:`jax.lax.ge`: elementwise greater-than-or-equal
    - :func:`jax.lax.gt`: elementwise greater-than
    - :func:`jax.lax.le`: elementwise less-than-or-equal

  .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
  """
  lhs, rhs = core.standard_insert_pbroadcast(x, y)
  return lt_p.bind(lhs, rhs)
|
|
|
|
|
|
2025-03-18 08:55:38 -07:00
|
|
|
|
@export
def convert_element_type(operand: ArrayLike,
                         new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
  """Cast the elements of an array to a new dtype.

  Lowers directly to the `stablehlo.convert`_ operation, which converts each
  element from one type to another, similar to a C++ ``static_cast``.

  Args:
    operand: an array or scalar value to be cast.
    new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar
      type, or a valid dtype name) naming the target dtype.

  Returns:
    An array with the same shape as ``operand``, cast elementwise to
    ``new_dtype``.

  .. note::

     If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
     the appropriate 32-bit type will be used in its place.

     If the input is a JAX array and the input dtype and output dtype match,
     then the input array will be returned unmodified.

  See also:
    - :func:`jax.numpy.astype`: NumPy-style dtype casting API.
    - :meth:`jax.Array.astype`: dtype casting as an array method.
    - :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.

  .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
  .. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
  """
  # The public entry point always requests a strongly-typed result.
  return _convert_element_type(operand, new_dtype=new_dtype, weak_type=False)  # type: ignore[unused-ignore,bad-return-type]
|
2021-02-08 13:37:25 -08:00
|
|
|
|
|
2024-07-11 05:25:49 -07:00
|
|
|
|
def _convert_element_type(
    operand: ArrayLike,
    new_dtype: DTypeLike | dtypes.ExtendedDType | None = None,
    weak_type: bool = False,
    sharding: Sharding | None = None,
    warn_on_complex_to_real_cast: bool = True):
  """Shared implementation behind dtype / weak-type / sharding conversion.

  Args:
    operand: value to convert. Objects exposing ``__jax_array__`` are
      unwrapped through that method first.
    new_dtype: target dtype. ``None`` keeps the operand's current dtype
      (after canonicalization). May be an extended dtype.
    weak_type: weak-type flag requested for the result.
    sharding: optional target sharding; must be a
      :class:`jax.sharding.Sharding` on the non-extended path.
    warn_on_complex_to_real_cast: when true, emit ``NumpyComplexWarning`` if
      a complex input is cast to a non-complex dtype.

  Returns:
    ``operand`` itself when no conversion is needed, otherwise the result of
    binding ``to_edtype_p``, ``from_edtype_p`` or ``convert_element_type_p``.

  Raises:
    NotImplementedError: if an extended dtype is involved and either
      ``sharding`` or ``weak_type`` is set (and the no-op early return did
      not apply).
    ValueError: if both the old and new dtypes are extended dtypes, or if
      ``sharding`` is not a ``Sharding`` instance.
  """
  if hasattr(operand, '__jax_array__'):
    operand = operand.__jax_array__()

  # Don't canonicalize old_dtype because x64 context might cause
  # un-canonicalized operands to be passed in.
  old_dtype = dtypes.dtype(operand, canonicalize=False)

  # Extended dtypes use dedicated primitives rather than
  # convert_element_type_p.
  if (isinstance(new_dtype, dtypes.ExtendedDType) or
      isinstance(old_dtype, dtypes.ExtendedDType)):
    if new_dtype == old_dtype:
      # Same extended dtype: a no-op unless a different sharding is requested.
      if sharding is None:
        return operand
      if isinstance(operand, core.Tracer) and operand.aval.sharding == sharding:
        return operand
    # Resharding and weak typing are not supported for extended dtypes.
    if sharding is not None or weak_type:
      raise NotImplementedError
    if (isinstance(new_dtype, dtypes.ExtendedDType) and
        isinstance(old_dtype, dtypes.ExtendedDType)):
      # The representation dtypes are looked up only to build the error text.
      old_rep_dtype = core.physical_element_aval(old_dtype).dtype
      new_rep_dtype = core.physical_element_aval(new_dtype).dtype
      raise ValueError(
          "cannot directly convert between extended dtypes: from "
          f"{dtype_to_string(old_dtype)} to {dtype_to_string(new_dtype)}. "
          "Instead, convert to and from their representation dtypes, e.g.:\n"
          f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} "
          f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}")

    # Exactly one side is extended at this point.
    if isinstance(new_dtype, dtypes.ExtendedDType):
      return to_edtype_p.bind(operand, edtype=new_dtype)
    return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype))

  # Extended dtypes were handled above; narrow the static type for checkers.
  new_dtype = type_cast(DTypeLike | None, new_dtype)

  old_weak_type = dtypes.is_weakly_typed(operand)
  if new_dtype is None:
    new_dtype = old_dtype
  else:
    new_dtype = np.dtype(new_dtype)
  new_dtype = dtypes.dtype(new_dtype, canonicalize=True)

  if sharding is not None and not isinstance(sharding, Sharding):
    raise ValueError(f'{sharding=} must be an instance of jax.sharding.Sharding')

  if (warn_on_complex_to_real_cast and
      dtypes.issubdtype(old_dtype, np.complexfloating) and
      not dtypes.issubdtype(new_dtype, np.complexfloating)):
    msg = "Casting complex values to real discards the imaginary part"
    warnings.warn(msg, NumpyComplexWarning, stacklevel=2)

  # Python has big integers, but convert_element_type(2 ** 100, np.float32) need
  # not be an error since the target dtype fits the value. Handle this case by
  # converting to a NumPy array before calling bind. Without this step, we'd
  # first canonicalize the input to a value of dtype int32 or int64, leading to
  # an overflow error.
  if type(operand) is int:
    operand = np.asarray(operand).astype(new_dtype)
    old_weak_type = False

  # Return the operand untouched only when dtype and weak-type already match,
  # it is an Array that is not a concrete tracer, and either no sharding was
  # requested or the requested concrete sharding equals the operand's own.
  if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
      isinstance(operand, Array) and
      not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and
      (sharding is None or
       (sharding._is_concrete and getattr(operand, 'sharding', None) == sharding))):
    return operand
  else:
    return convert_element_type_p.bind(
        operand, new_dtype=new_dtype, weak_type=bool(weak_type),
        sharding=sharding)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-03-18 08:55:38 -07:00
|
|
|
|
@export
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
  """Reinterpret the bits of an array as a different dtype.

  Lowers directly to the `stablehlo.bitcast_convert`_ operation.

  The output shape is determined by the relative sizes of the input and
  output dtypes, following this logic::

    if new_dtype.itemsize == operand.dtype.itemsize:
      output_shape = operand.shape
    if new_dtype.itemsize < operand.dtype.itemsize:
      output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize)
    if new_dtype.itemsize > operand.dtype.itemsize:
      assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize
      output_shape = operand.shape[:-1]

  Args:
    operand: an array or scalar value to be cast
    new_dtype: the new type. Should be a NumPy type.

  Returns:
    An array of shape `output_shape` (see above) and type `new_dtype`,
    built from the same bits as operand.

  See also:
    - :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
    - :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.

  .. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
  """
  canonical_dtype = dtypes.canonicalize_dtype(new_dtype)
  return bitcast_convert_type_p.bind(operand, new_dtype=canonical_dtype)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
  r"""Elementwise clamp of ``x`` between ``min`` and ``max``.

  Returns :math:`\mathrm{clamp}(x) = \begin{cases}
  \mathit{min} & \text{if } x < \mathit{min},\\
  \mathit{max} & \text{if } x > \mathit{max},\\
  x & \text{otherwise}
  \end{cases}`.
  """
  # All three operands go through standard_insert_pbroadcast together
  # before being bound to the primitive.
  lower, value, upper = core.standard_insert_pbroadcast(min, x, max)
  return clamp_p.bind(lower, value, upper)
|
|
|
|
|
|
2024-12-18 19:37:58 -08:00
|
|
|
|
|
|
|
|
|
@weakref_lru_cache
def _trace_composite_to_jaxpr(fun: Callable,
                              in_tree: tree_util.PyTreeDef,
                              in_avals: Sequence[core.AbstractValue],
                              name: str,
                              debug_info: core.DebugInfo):
  """Trace a composite's decomposition function into a closed jaxpr.

  Args:
    fun: the decomposition callable to trace.
    in_tree: pytree structure of the positional arguments.
    in_avals: abstract values of the flattened arguments.
    name: name of the composite op; only used in the error message.
    debug_info: debug info attached to the wrapped function.

  Returns:
    A ``(closed_jaxpr, out_tree)`` pair, where ``out_tree`` is the output-tree
    thunk returned by ``api_util.flatten_fun_nokwargs``.

  Raises:
    UnexpectedTracerError: if any constant captured while tracing is a
      ``core.Tracer``.
  """
  wrapped_fun = lu.wrap_init(fun, debug_info=debug_info)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
  jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
  # A tracer among the captured constants means the decomposition closed over
  # a value that is part of an enclosing transformation.
  for const in consts:
    if isinstance(const, core.Tracer):
      raise UnexpectedTracerError(
          "Found a JAX Tracer as a constant in the decomposition for the "
          f"composite op '{name}'. This means that the decomposition function "
          "closes over a value that is involved in a JAX transformation. "
          "Any values that aren't explicitly known at compile time must be "
          "explicitly passed as arguments to the composite.")
  return core.ClosedJaxpr(jaxpr, consts), out_tree
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def composite(
    decomposition: Callable,
    name: str,
    version: int = 0,
):
  """Composite with semantics defined by the decomposition function.

  A composite is a higher-order JAX function that encapsulates an operation
  made up (composed) of other JAX functions. The ``decomposition`` function
  implements the semantics of the op, so the composite can always be replaced
  by its decomposed implementation without changing the meaning of the
  encapsulated operation.

  The compiler can recognize specific composite operations by their ``name``,
  ``version``, ``kwargs``, and dtypes to emit more efficient code, potentially
  leveraging hardware-specific instructions or optimizations. If the compiler
  doesn't recognize the composite, it falls back to compiling the
  ``decomposition`` function.

  Consider a "tangent" composite operation. Its ``decomposition`` function
  could be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could
  recognize the "tangent" composite and emit a single ``tangent`` instruction
  instead of three separate instructions (``sin``, ``divide``, and ``cos``).
  For hardware without dedicated tangent support, it would fall back to
  compiling the decomposition.

  This is useful for preserving high-level abstractions that would otherwise
  be lost while lowering, which allows for easier pattern-matching in
  low-level IR.

  Args:
    decomposition: function that implements the semantics of the composite op.
    name: name of the encapsulated operation.
    version: optional int to indicate semantic changes to the composite.

  Returns:
    Callable: Returns a composite function. Note that positional arguments to
    this function should be interpreted as inputs and keyword arguments should
    be interpreted as attributes of the op. Any keyword arguments that are
    passed with ``None`` as a value will be omitted from the
    ``composite_attributes``.

  Examples:
    Tangent kernel:

    >>> def my_tangent_composite(x):
    ...   return lax.composite(
    ...     lambda x: lax.sin(x) / lax.cos(x), name="my.tangent"
    ...   )(x)
    >>>
    >>> pi = jnp.pi
    >>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
    >>> with jnp.printoptions(precision=3, suppress=True):
    ...   print(my_tangent_composite(x))
    ...   print(lax.tan(x))
    [ 0.  1. -1.  0.]
    [ 0.  1. -1.  0.]

    The recommended way to create composites is via a decorator. Use ``/`` and
    ``*`` in the function signature to be explicit about positional and keyword
    arguments, respectively:

    >>> @partial(lax.composite, name="my.softmax")
    ... def my_softmax_composite(x, /, *, axis):
    ...   return jax.nn.softmax(x, axis)
  """
  @functools.wraps(decomposition)
  def _decorator(*args, **kwargs):
    debug_info = api_util.debug_info("composite", decomposition,
                                     args, kwargs)
    # Positional arguments are the op's inputs: flatten them and record avals.
    flat_args, in_tree = tree_util.tree_flatten(args)
    in_avals = tuple(core.get_aval(leaf) for leaf in flat_args)
    # Keyword arguments are the op's attributes and must not be tracers.
    for attr_value in kwargs.values():
      if isinstance(attr_value, core.Tracer):
        raise UnexpectedTracerError(
            "Found a JAX Tracer as an attribute in the decomposition for the "
            f"composite op '{name}'. This means that the decomposition function "
            "closes over a value that is involved in a JAX transformation. "
            "Any values that aren't explicitly known at compile time must be "
            "explicitly passed as arguments to the composite."
            "\n\nNote: If you are passing jax arrays as attributes, use numpy "
            "arrays instead.")
    # The attributes are baked into the traced function via ``partial``.
    closed_jaxpr, out_tree = _trace_composite_to_jaxpr(
        partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info
    )
    flat_args = core.standard_insert_pbroadcast(*flat_args)
    flat_results = composite_p.bind(
        *flat_args,
        name=name,
        attributes=tuple(kwargs.items()),
        version=version,
        jaxpr=closed_jaxpr,
    )
    return tree_util.tree_unflatten(out_tree(), flat_results)

  return _decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _composite_lowering(
    ctx: mlir.LoweringRuleContext,
    *args: Any,
    name: str,
    attributes: Sequence[tuple[str, Any]],
    version: int,
    jaxpr: core.ClosedJaxpr,
):
  """Lower a composite primitive to a ``stablehlo.composite`` op.

  The decomposition jaxpr is lowered as a called computation, and the emitted
  composite op refers to it by symbol name.

  Args:
    ctx: The lowering rule context.
    *args: The arguments to the composite.
    name: The name of the composite.
    attributes: The attributes of the composite, as ``(key, value)`` pairs.
    version: The version of the composite.
    jaxpr: The jaxpr of the underlying composite.

  Returns:
    The results of the composite.
  """
  func_op, _, _ = mlir.lower_called_computation(
      name,
      ctx.name_stack,
      jaxpr,
      ctx.module_context,
      ctx.avals_out,
      ctx.tokens_in,
  )
  # Attributes whose value is None are left out of the emitted op.
  composite_attrs = {}
  for attr_name, attr_value in attributes:
    if attr_value is not None:
      composite_attrs[attr_name] = mlir.ir_attribute(attr_value)
  symbol_name = func_op.name.value
  composite_op = hlo.CompositeOp(
      func_op.type.results,
      mlir.flatten_ir_values(args),
      name=ir.StringAttr.get(name),
      decomposition=ir.FlatSymbolRefAttr.get(symbol_name),
      composite_attributes=ir.DictAttr.get(composite_attrs),
      version=mlir.i32_attr(version),
  )
  return composite_op.results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _composite_impl(*args, jaxpr, **_):
  """Implementation rule for ``composite_p``: run the decomposition jaxpr.

  All keyword parameters other than ``jaxpr`` are ignored.
  """
  decomposition_fn = core.jaxpr_as_fun(jaxpr)
  return decomposition_fn(*args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _composite_abstract_eval(*args, jaxpr, **_):
|
|
|
|
|
del args
|
|
|
|
|
return jaxpr.out_avals
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def composite_jvp(*args, **_):
  """JVP rule for ``composite_p``; always raises.

  Raises:
    ValueError: unconditionally, pointing the user at ``jax.custom_jvp``.
  """
  del args  # unused
  message = (
      "JVP rule for composite not implemented. You can use `jax.custom_jvp` to "
      "add support. See "
      "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html"
  )
  raise ValueError(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def composite_transpose(*args, **_):
  """Transpose rule for ``composite_p``; always raises.

  Args:
    *args: ignored.
    **_: ignored primitive parameters.

  Raises:
    ValueError: unconditionally, pointing the user at ``jax.custom_jvp`` /
      ``jax.custom_vjp``.
  """
  del args  # unused
  # Adjacent string literals are concatenated verbatim, so the first fragment
  # must end with a space to keep "use" and "`jax.custom_jvp`" apart.
  raise ValueError(
      "Transpose rule for composite not implemented. You can use "
      "`jax.custom_jvp` or `jax.custom_vjp` to add support. See "
      "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html"
  )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The composite primitive. It returns a flat sequence of results.
composite_p = core.Primitive("composite")
composite_p.multiple_results = True

# Evaluation rules: both are driven by the decomposition jaxpr.
composite_p.def_impl(_composite_impl)
composite_p.def_abstract_eval(_composite_abstract_eval)

# Lowering to a ``stablehlo.composite`` op.
mlir.register_lowering(composite_p, _composite_lowering)

# Autodiff rules: both raise, directing users to custom derivative rules.
ad.primitive_jvps[composite_p] = composite_jvp
ad.primitive_transposes[composite_p] = composite_transpose
|
|
|
|
|
|
|
|
|
|
|
2023-12-08 12:09:04 +00:00
|
|
|
|
def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
  """Concatenates a sequence of arrays along `dimension`.

  Wraps XLA's `Concatenate
  <https://www.tensorflow.org/xla/operation_semantics#concatenate>`_
  operator.

  Args:
    operands: a sequence of arrays to concatenate. The arrays must have equal
      shapes, except in the `dimension` axis.
    dimension: the dimension along which to concatenate the arrays.

  Returns:
    An array containing the concatenation.

  Raises:
    ValueError: if ``operands`` is empty.
  """
  # Use len() rather than truthiness: ``operands`` may itself be an Array.
  operand_count = len(operands)
  if operand_count == 0:
    raise ValueError("concatenate requires a non-empty sequences of arrays")
  if operand_count == 1:
    (only_operand,) = operands
    # A single operand that is already an Array is returned as-is.
    if isinstance(only_operand, Array):
      return only_operand
  broadcasted = core.standard_insert_pbroadcast(*operands)
  return concatenate_p.bind(*broadcasted, dimension=dimension)
|
|
|
|
|
|
2021-11-08 14:15:31 -08:00
|
|
|
|
|
2024-12-17 10:05:58 -08:00
|
|
|
|
def split(operand: ArrayLike, sizes: Sequence[int],
          axis: int = 0) -> Sequence[Array]:
  """Splits an array along ``axis``.

  Args:
    operand: an array to split
    sizes: the sizes of the split arrays. The sum of the sizes must be equal
      to the size of the ``axis`` dimension of ``operand``.
    axis: the axis along which to split the array.

  Returns:
    A sequence of ``len(sizes)`` arrays. If ``sizes`` is
    ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``,
    taken along ``axis``.
  """
  array = asarray(operand)
  # The primitive takes the sizes as a tuple and a canonicalized axis.
  chunk_sizes = tuple(sizes)
  canonical_axis = canonicalize_axis(axis, array.ndim)
  return split_p.bind(array, sizes=chunk_sizes, axis=canonical_axis)
|
|
|
|
|
|
|
|
|
|
|
2024-03-04 14:00:36 -08:00
|
|
|
|
# Alias table consulted by ``Precision._missing_``: maps accepted alias values
# to ``Precision`` members. It starts empty because ``Precision`` is defined
# below; the entries are registered after the class definition.
_precision_strings: dict[Any, Precision] = {}
|
2021-11-08 14:15:31 -08:00
|
|
|
|
|
2024-08-12 12:49:18 +01:00
|
|
|
|
class Precision(enum.Enum):
  """Precision enum for lax matrix multiply related functions.

  The device-dependent `precision` argument to JAX functions generally
  controls the tradeoff between speed and accuracy for array computations on
  accelerator backends, (i.e. TPU and GPU). Has no impact on CPU backends.
  This only has an effect on float32 computations, and does not affect the
  input/output datatypes. Members are:

  DEFAULT:
    Fastest mode, but least accurate. On TPU: performs float32 computations in
    bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100
    GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases:
    ``'default'``, ``'fastest'``.
  HIGH:
    Slower but more accurate. On TPU: performs float32 computations in 3
    bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise
    float32. Aliases: ``'high'``.
  HIGHEST:
    Slowest but most accurate. On TPU: performs float32 computations in 6
    bfloat16 passes. On GPU: uses float32. Aliases: ``'highest'``.
  """

  DEFAULT = 0
  HIGH = 1
  HIGHEST = 2

  @classmethod
  def _missing_(cls, value: object) -> Precision | None:
    # Values that are not member values are looked up in the module-level
    # alias table; a miss yields None.
    alias_match = _precision_strings.get(value)
    return alias_match

  def __repr__(self) -> str:
    return '.'.join((type(self).__name__, self.name))

  def __str__(self) -> str:
    member_name = self.name
    return member_name
|
2024-03-04 14:00:36 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the aliases accepted in place of a ``Precision`` member.
_precision_strings.update({
    'highest': Precision.HIGHEST,
    'float32': Precision.HIGHEST,
    'high': Precision.HIGH,
    'bfloat16_3x': Precision.HIGH,
    'tensorfloat32': Precision.HIGH,
    'default': Precision.DEFAULT,
    'bfloat16': Precision.DEFAULT,
    'fastest': Precision.DEFAULT,
    None: Precision.DEFAULT,
})
|
2021-11-08 14:15:31 -08:00
|
|
|
|
|
|
|
|
|
|
2024-09-25 06:16:22 -07:00
|
|
|
|
class DotAlgorithm(NamedTuple):
|
|
|
|
|
"""Specify the algorithm used for computing dot products.
|
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
When used to specify the ``precision`` input to :func:`~jax.lax.dot`,
|
|
|
|
|
:func:`~jax.lax.dot_general`, and other dot product functions, this data
|
|
|
|
|
structure is used for controlling the properties of the algorithm used for
|
|
|
|
|
computing the dot product. This API controls the precision used for the
|
|
|
|
|
computation, and allows users to access hardware-specific accelerations.
|
2024-09-25 06:16:22 -07:00
|
|
|
|
|
|
|
|
|
Support for these algorithms is platform dependent, and using an unsupported
|
|
|
|
|
algorithm will raise a Python exception when the computation is compiled. The
|
|
|
|
|
algorithms that are known to be supported on at least some platforms are
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
listed in the :class:`~jax.lax.DotAlgorithmPreset` enum, and these are a
|
2024-09-25 06:16:22 -07:00
|
|
|
|
good starting point for experimenting with this API.
|
|
|
|
|
|
|
|
|
|
A "dot algorithm" is specified by the following parameters:
|
|
|
|
|
|
|
|
|
|
* ``lhs_precision_type`` and ``rhs_precision_type``, the data types that the
|
|
|
|
|
LHS and RHS of the operation are rounded to.
|
|
|
|
|
* ``accumulation_type`` the data type used for accumulation.
|
|
|
|
|
* ``lhs_component_count``, ``rhs_component_count``, and
|
|
|
|
|
``num_primitive_operations`` apply to algorithms that decompose the LHS
|
|
|
|
|
and/or RHS into multiple components and execute multiple operations on
|
|
|
|
|
those values, usually to emulate a higher precision. For algorithms with no
|
|
|
|
|
decomposition, these values should be set to ``1``.
|
|
|
|
|
* ``allow_imprecise_accumulation`` to specify if accumulation in lower
|
|
|
|
|
precision is permitted for some steps (e.g.
|
|
|
|
|
``CUBLASLT_MATMUL_DESC_FAST_ACCUM``).
|
|
|
|
|
|
|
|
|
|
The `StableHLO spec <https://openxla.org/stablehlo/spec#dot_general>`_ for
|
|
|
|
|
the dot operation doesn't require that the precision types be the same as the
|
|
|
|
|
storage types for the inputs or outputs, but some plaforms may require that
|
|
|
|
|
these types match. Furthermore, the return type of
|
|
|
|
|
:func:`~jax.lax.dot_general` is always defined by the ``accumulation_type``
|
|
|
|
|
parameter of the input algorithm, if specified.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
|
|
Accumulate two 16-bit floats using a 32-bit float accumulator:
|
|
|
|
|
|
|
|
|
|
>>> algorithm = DotAlgorithm(
|
|
|
|
|
... lhs_precision_type=np.float16,
|
|
|
|
|
... rhs_precision_type=np.float16,
|
|
|
|
|
... accumulation_type=np.float32,
|
|
|
|
|
... )
|
|
|
|
|
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
|
|
|
|
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
|
|
|
|
|
array([ 1., 4., 9., 16.], dtype=float16)
|
2024-09-25 06:16:22 -07:00
|
|
|
|
|
|
|
|
|
Or, equivalently, using a preset:
|
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
|
|
|
|
|
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
|
|
|
|
|
array([ 1., 4., 9., 16.], dtype=float16)
|
|
|
|
|
|
|
|
|
|
Presets can also be specified by name:
|
|
|
|
|
|
|
|
|
|
>>> dot(lhs, rhs, precision="F16_F16_F32") # doctest: +SKIP
|
|
|
|
|
array([ 1., 4., 9., 16.], dtype=float16)
|
|
|
|
|
|
|
|
|
|
The ``preferred_element_type`` parameter can be used to return the output
|
|
|
|
|
without downcasting the accumulation type:
|
|
|
|
|
|
|
|
|
|
>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32) # doctest: +SKIP
|
2024-09-25 06:16:22 -07:00
|
|
|
|
array([ 1., 4., 9., 16.], dtype=float32)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
lhs_precision_type: DTypeLike
|
|
|
|
|
rhs_precision_type: DTypeLike
|
|
|
|
|
accumulation_type: DTypeLike
|
|
|
|
|
lhs_component_count: int = 1
|
|
|
|
|
rhs_component_count: int = 1
|
|
|
|
|
num_primitive_operations: int = 1
|
|
|
|
|
allow_imprecise_accumulation: bool = False
|
|
|
|
|
|
|
|
|
|
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
                         rhs_dtype: DTypeLike) -> hlo.DotAlgorithm:
  """Build the ``hlo.DotAlgorithm`` attribute described by this object.

  Args:
    lhs_dtype: dtype of the left-hand operand. Accepted but not used here.
    rhs_dtype: dtype of the right-hand operand. Accepted but not used here.

  Returns:
    The value of ``hlo.DotAlgorithm.get`` called with, in order: the IR types
    of ``lhs_precision_type``, ``rhs_precision_type`` and
    ``accumulation_type``, followed by ``lhs_component_count``,
    ``rhs_component_count``, ``num_primitive_operations`` and
    ``allow_imprecise_accumulation``.
  """
  # NOTE(review): the operand dtypes are presumably part of the signature so
  # that callers can treat this like other algorithm-to-attribute converters;
  # confirm against the call sites.
  del lhs_dtype, rhs_dtype  # unused
  return hlo.DotAlgorithm.get(
      mlir.dtype_to_ir_type(dtypes.dtype(self.lhs_precision_type)),
      mlir.dtype_to_ir_type(dtypes.dtype(self.rhs_precision_type)),
      mlir.dtype_to_ir_type(dtypes.dtype(self.accumulation_type)),
      self.lhs_component_count,
      self.rhs_component_count,
      self.num_primitive_operations,
      self.allow_imprecise_accumulation,
  )
|
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
|
|
|
|
|
class DotAlgorithmPreset(enum.Enum):
|
|
|
|
|
"""An enum of known algorithms for computing dot products.
|
|
|
|
|
|
|
|
|
|
This ``Enum`` provides a named set of :class:`~jax.lax.DotAlgorithm` objects
|
|
|
|
|
that are known to be supported on at least one platform. See the
|
|
|
|
|
:class:`~jax.lax.DotAlgorithm` documentation for more details about the
|
|
|
|
|
behavior of these algorithms.
|
|
|
|
|
|
|
|
|
|
An algorithm can be selected from this list when calling :func:`~jax.lax.dot`,
|
|
|
|
|
:func:`~jax.lax.dot_general`, or most other JAX dot product functions, by
|
|
|
|
|
passing either a member of this ``Enum`` or its name as a string using the
|
|
|
|
|
``precision`` argument.
|
|
|
|
|
|
|
|
|
|
For example, users can specify the preset using this ``Enum`` directly:
|
|
|
|
|
|
|
|
|
|
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
|
|
|
|
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
|
|
|
|
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
|
|
|
|
|
>>> dot(lhs, rhs, precision=algorithm) # doctest: +SKIP
|
|
|
|
|
array([ 1., 4., 9., 16.], dtype=float16)
|
|
|
|
|
|
|
|
|
|
or, equivalently, they can be specified by name:
|
|
|
|
|
|
|
|
|
|
>>> dot(lhs, rhs, precision="F16_F16_F32") # doctest: +SKIP
|
|
|
|
|
array([ 1., 4., 9., 16.], dtype=float16)
|
|
|
|
|
|
|
|
|
|
The names of the presets are typically ``LHS_RHS_ACCUM`` where ``LHS`` and
|
|
|
|
|
``RHS`` are the element types of the ``lhs`` and ``rhs`` inputs
|
|
|
|
|
respectively, and ``ACCUM`` is the element type of the accumulator. Some
|
|
|
|
|
presets have an extra suffix, and the meaning of each of these is
|
|
|
|
|
documented below. The supported presets are:
|
|
|
|
|
"""
|
|
|
|
|
DEFAULT = enum.auto()
|
|
|
|
|
"""An algorithm will be selected based on input and output types."""
|
|
|
|
|
|
|
|
|
|
ANY_F8_ANY_F8_F32 = enum.auto()
|
|
|
|
|
"""Accepts any float8 input types and accumulates into float32."""
|
|
|
|
|
|
|
|
|
|
ANY_F8_ANY_F8_F32_FAST_ACCUM = enum.auto()
|
|
|
|
|
"""Like ``ANY_F8_ANY_F8_F32``, but using faster accumulation with the cost
|
|
|
|
|
of lower accuracy.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
ANY_F8_ANY_F8_ANY = enum.auto()
|
|
|
|
|
"""Like ``ANY_F8_ANY_F8_F32``, but the accumulation type is controlled by
|
|
|
|
|
``preferred_element_type``.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
ANY_F8_ANY_F8_ANY_FAST_ACCUM = enum.auto()
|
|
|
|
|
"""Like ``ANY_F8_ANY_F8_F32_FAST_ACCUM``, but the accumulation type is
|
|
|
|
|
controlled by ``preferred_element_type``.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
F16_F16_F16 = enum.auto()
|
|
|
|
|
F16_F16_F32 = enum.auto()
|
|
|
|
|
BF16_BF16_BF16 = enum.auto()
|
|
|
|
|
BF16_BF16_F32 = enum.auto()
|
|
|
|
|
BF16_BF16_F32_X3 = enum.auto()
|
|
|
|
|
"""The ``_X3`` suffix indicates that the algorithm uses 3 operations to
|
|
|
|
|
emulate higher precision.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
BF16_BF16_F32_X6 = enum.auto()
|
|
|
|
|
"""Like ``BF16_BF16_F32_X3``, but using 6 operations instead of 3."""
|
|
|
|
|
|
2025-03-14 02:57:12 -07:00
|
|
|
|
BF16_BF16_F32_X9 = enum.auto()
|
|
|
|
|
"""Like ``BF16_BF16_F32_X3``, but using 9 operations instead of 3."""
|
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
TF32_TF32_F32 = enum.auto()
|
|
|
|
|
TF32_TF32_F32_X3 = enum.auto()
|
|
|
|
|
"""The ``_X3`` suffix indicates that the algorithm uses 3 operations to
|
|
|
|
|
emulate higher precision.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
F32_F32_F32 = enum.auto()
|
|
|
|
|
F64_F64_F64 = enum.auto()
|
|
|
|
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
|
|
return f'{self.__class__.__name__}.{self.name}'
|
|
|
|
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
|
|
return self.name
|
|
|
|
|
|
|
|
|
|
@property
|
2024-12-03 06:25:55 -08:00
|
|
|
|
def supported_lhs_types(self) -> tuple[DTypeLike, ...] | None:
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
match self:
|
|
|
|
|
case (
|
2024-11-20 08:26:12 -08:00
|
|
|
|
DotAlgorithmPreset.DEFAULT
|
|
|
|
|
| DotAlgorithmPreset.ANY_F8_ANY_F8_F32
|
|
|
|
|
| DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
|
|
|
|
|
| DotAlgorithmPreset.ANY_F8_ANY_F8_ANY
|
|
|
|
|
| DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
):
|
2024-09-25 06:16:22 -07:00
|
|
|
|
return None
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32:
|
2024-12-03 06:25:55 -08:00
|
|
|
|
return (np.float16,)
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
case (
|
|
|
|
|
DotAlgorithmPreset.BF16_BF16_BF16 |
|
|
|
|
|
DotAlgorithmPreset.BF16_BF16_F32
|
|
|
|
|
):
|
|
|
|
|
# These algorithms support either f32 or bf32 input storage types.
|
|
|
|
|
# If either of those types are provided as input, we use the provided
|
|
|
|
|
# type. If not, we explicitly cast to bfloat16.
|
|
|
|
|
return (dtypes.bfloat16, np.float32)
|
|
|
|
|
case DotAlgorithmPreset.F64_F64_F64:
|
2024-12-03 06:25:55 -08:00
|
|
|
|
return (np.float64,)
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
case _:
|
2024-12-03 06:25:55 -08:00
|
|
|
|
return (np.float32,)
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
|
|
|
|
|
@property
|
2024-12-03 06:25:55 -08:00
|
|
|
|
def supported_rhs_types(self) -> tuple[DTypeLike, ...] | None:
|
|
|
|
|
return self.supported_lhs_types
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
|
|
|
|
|
@property
def accumulation_type(self) -> DTypeLike | None:
  """Return the dtype this preset accumulates in, or ``None`` if unspecified.

  ``None`` is returned for ``DEFAULT`` and for the two ``ANY_F8_ANY_F8_ANY``
  presets. The three presets whose name ends in a 16-bit or 64-bit type
  (``F16_F16_F16``, ``BF16_BF16_BF16``, ``F64_F64_F64``) return that type.
  Every remaining preset returns ``np.float32``.
  """
  preset = DotAlgorithmPreset

  # Presets that do not pin an accumulation dtype.
  if self in (
      preset.DEFAULT,
      preset.ANY_F8_ANY_F8_ANY,
      preset.ANY_F8_ANY_F8_ANY_FAST_ACCUM,
  ):
    return None

  if self == preset.F16_F16_F16:
    return np.float16
  if self == preset.BF16_BF16_BF16:
    return dtypes.bfloat16
  if self == preset.F64_F64_F64:
    return np.float64

  # Everything else falls through to float32.
  return np.float32
|
|
|
|
|
|
2024-12-05 07:00:58 -08:00
|
|
|
|
def supported_output_types(
    self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike
) -> tuple[DTypeLike, ...] | None:
  """Return the output dtypes allowed for this preset and these input dtypes.

  Args:
    lhs_dtype: dtype of the left-hand operand.
    rhs_dtype: dtype of the right-hand operand.

  Returns:
    A tuple of dtypes, or ``None`` when ``self.accumulation_type`` is ``None``
    and the preset is not one of the special cases handled below.
  """
  preset = DotAlgorithmPreset

  if self in (
      preset.ANY_F8_ANY_F8_F32,
      preset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
  ):
    return (
        np.float32,
        np.float16,
        dtypes.bfloat16,
        dtypes.float8_e4m3fn,
        dtypes.float8_e5m2,
        dtypes.float8_e5m2fnuz,
        dtypes.float8_e4m3fnuz,
        dtypes.float8_e4m3b11fnuz,
    )

  if self in (preset.F16_F16_F32, preset.BF16_BF16_F32):
    # The 16-bit output is only offered when the promoted input dtype is that
    # same 16-bit type; otherwise only float32 is allowed.
    narrow = np.float16 if self == preset.F16_F16_F32 else dtypes.bfloat16
    if dtypes.promote_types(lhs_dtype, rhs_dtype) == narrow:
      return (np.float32, narrow)
    return (np.float32,)

  # Remaining presets: the only supported output is the accumulation dtype,
  # when there is one.
  acc = self.accumulation_type
  if acc is None:
    return None
  return (acc,)
|
2024-09-25 06:16:22 -07:00
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
                         rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
  """Build the ``hlo.DotAlgorithm`` attribute for this preset.

  Args:
    lhs_dtype: dtype of the left-hand operand; only inspected for the
      ``ANY_F8_*`` presets.
    rhs_dtype: dtype of the right-hand operand; only inspected for the
      ``ANY_F8_*`` presets.

  Returns:
    The ``hlo.DotAlgorithm`` attribute, or ``None`` for presets that have no
    entry below (e.g. ``DEFAULT``).

  Raises:
    ValueError: if an ``ANY_F8_*`` preset is used and either input dtype is
      not one of the recognised float8 dtypes.
  """
  preset = DotAlgorithmPreset

  # The IR element types are created up front, before any dispatch, exactly
  # as in the original implementation.
  f16 = ir.F16Type.get()
  f32 = ir.F32Type.get()
  f64 = ir.F64Type.get()
  bf16 = ir.BF16Type.get()
  tf32 = ir.FloatTF32Type.get()

  if self in (
      preset.ANY_F8_ANY_F8_F32,
      preset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
      preset.ANY_F8_ANY_F8_ANY,
      preset.ANY_F8_ANY_F8_ANY_FAST_ACCUM,
  ):
    fp8_dtypes = [
        np.dtype(fp8)
        for fp8 in (
            dtypes.float8_e4m3b11fnuz,
            dtypes.float8_e4m3fn,
            dtypes.float8_e4m3fnuz,
            dtypes.float8_e5m2,
            dtypes.float8_e5m2fnuz,
        )
    ]
    # These float8 variants may be ``None`` in ``dtypes``; include only the
    # ones that are present.
    for maybe_fp8 in (
        dtypes.float8_e3m4,
        dtypes.float8_e4m3,
        dtypes.float8_e8m0fnu,
    ):
      if maybe_fp8 is not None:
        fp8_dtypes.append(np.dtype(maybe_fp8))

    if not (lhs_dtype in fp8_dtypes and rhs_dtype in fp8_dtypes):
      raise ValueError(
          f"The dot algorithm '{self}' requires both inputs to have float8 "
          f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.'
      )

    # The operand types follow the actual inputs; accumulation is float32.
    # NOTE(review): the last argument is True only for
    # ANY_F8_ANY_F8_F32_FAST_ACCUM; ANY_F8_ANY_F8_ANY_FAST_ACCUM passes False.
    # Confirm that this is intended.
    return hlo.DotAlgorithm.get(
        mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype)),
        mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype)),
        ir.F32Type.get(),
        1,
        1,
        1,
        self == preset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
    )

  # Fixed presets: (lhs type, rhs type, accumulation type, sixth argument).
  # The fourth and fifth arguments are always 1 and the last is always False.
  fixed_presets = {
      preset.F16_F16_F16: (f16, f16, f16, 1),
      preset.F16_F16_F32: (f16, f16, f32, 1),
      preset.BF16_BF16_BF16: (bf16, bf16, bf16, 1),
      preset.BF16_BF16_F32: (bf16, bf16, f32, 1),
      preset.BF16_BF16_F32_X3: (bf16, bf16, f32, 3),
      preset.BF16_BF16_F32_X6: (bf16, bf16, f32, 6),
      preset.BF16_BF16_F32_X9: (bf16, bf16, f32, 9),
      preset.TF32_TF32_F32: (tf32, tf32, f32, 1),
      preset.TF32_TF32_F32_X3: (tf32, tf32, f32, 3),
      preset.F32_F32_F32: (f32, f32, f32, 1),
      preset.F64_F64_F64: (f64, f64, f64, 1),
  }
  spec = fixed_presets.get(self)
  if spec is None:
    return None
  lhs_type, rhs_type, acc_type, num_ops = spec
  return hlo.DotAlgorithm.get(
      lhs_type, rhs_type, acc_type, 1, 1, num_ops, False
  )
|
2024-09-25 06:16:22 -07:00
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
|
|
|
|
|
PrecisionLike = Union[
|
2024-09-25 06:16:22 -07:00
|
|
|
|
None,
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
str,
|
|
|
|
|
Precision,
|
|
|
|
|
tuple[str, str],
|
|
|
|
|
tuple[Precision, Precision],
|
2024-09-25 06:16:22 -07:00
|
|
|
|
DotAlgorithm,
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
DotAlgorithmPreset,
|
2024-09-25 06:16:22 -07:00
|
|
|
|
]
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
# Precision specification restricted to ``None``, a pair of ``Precision``
# values, or an explicit dot algorithm / algorithm preset.
# NOTE(review): presumably this is the normalized form produced by
# ``canonicalize_precision`` -- confirm against that function.
CanonicalPrecision = Union[
    None,
    tuple[Precision, Precision],
    DotAlgorithm,
    DotAlgorithmPreset,
]
|
|
|
|
|
|
|
|
|
|
|
2020-12-10 02:29:40 +00:00
|
|
|
|
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
        preferred_element_type: DTypeLike | None = None) -> Array:
  """Vector/vector, matrix/vector, and matrix/matrix multiplication.

  Wraps XLA's `Dot <https://www.tensorflow.org/xla/operation_semantics#dot>`_
  operator.

  For more general contraction, see the :func:`jax.lax.dot_general` operator.

  Args:
    lhs: an array of dimension 1 or 2.
    rhs: an array of dimension 1 or 2.
    precision: Optional. This parameter controls the numerics of the
      computation, and it can be one of the following:

      - ``None``, which means the default precision for the current backend,
      - a :class:`~jax.lax.Precision` enum value or a tuple of two
        :class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and
        ``rhs``, or
      - a :class:`~jax.lax.DotAlgorithm` or a
        :class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that
        must be used to accumulate the dot product.

    preferred_element_type: Optional. This parameter controls the data type
      output by the dot product. By default, the output element type of this
      operation will match the ``lhs`` and ``rhs`` input element types under
      the usual type promotion rules. Setting ``preferred_element_type`` to a
      specific ``dtype`` will mean that the operation returns that element type.
      When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or
      :class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides
      a hint to the compiler to accumulate the dot product using this data type.

  Returns:
    An array containing the product.

  Raises:
    TypeError: if either operand does not have 1 or 2 dimensions, or if the
      last dimension of ``lhs`` is not definitely equal to the first dimension
      of ``rhs``.
  """
  # The rank checks come first so that the size comparison is only attempted
  # on operands that actually have a leading/trailing axis to compare.
  shapes_compatible = (
      1 <= lhs.ndim <= 2
      and 1 <= rhs.ndim <= 2
      and core.definitely_equal(lhs.shape[-1], rhs.shape[0])
  )
  if not shapes_compatible:
    raise TypeError(
        f"Incompatible shapes for dot: got {lhs.shape} and {rhs.shape}.")
  # Contract the last axis of ``lhs`` with the first axis of ``rhs``; there
  # are no batch dimensions.
  return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
                     precision=precision,
                     preferred_element_type=preferred_element_type)
|
|
|
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
|
# Type of the ``dimension_numbers`` argument of ``dot_general``:
# ``((lhs_contracting_dims, rhs_contracting_dims),
#    (lhs_batch_dims, rhs_batch_dims))``.
DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
                            tuple[Sequence[int], Sequence[int]]]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                precision: PrecisionLike = None,
                preferred_element_type: DTypeLike | None = None,
                out_sharding=None) -> Array:
  """General dot product/contraction operator.

  Wraps XLA's `DotGeneral
  <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
  operator.

  The semantics of ``dot_general`` are complicated, but most users should not have to
  use it directly. Instead, you can use higher-level functions like :func:`jax.numpy.dot`,
  :func:`jax.numpy.matmul`, :func:`jax.numpy.tensordot`, :func:`jax.numpy.einsum`,
  and others which will construct appropriate calls to ``dot_general`` under the hood.
  If you really want to understand ``dot_general`` itself, we recommend reading XLA's
  `DotGeneral <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
  operator documentation.

  Args:
    lhs: an array
    rhs: an array
    dimension_numbers: a tuple of tuples of sequences of ints of the form
      ``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
      rhs_batch_dims))``
    precision: Optional. This parameter controls the numerics of the
      computation, and it can be one of the following:

      - ``None``, which means the default precision for the current backend,
      - a :class:`~jax.lax.Precision` enum value or a tuple of two
        :class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and
        ``rhs``, or
      - a :class:`~jax.lax.DotAlgorithm` or a
        :class:`~jax.lax.DotAlgorithmPreset` indicating the algorithm that
        must be used to accumulate the dot product.

    preferred_element_type: Optional. This parameter controls the data type
      output by the dot product. By default, the output element type of this
      operation will match the ``lhs`` and ``rhs`` input element types under
      the usual type promotion rules. Setting ``preferred_element_type`` to a
      specific ``dtype`` will mean that the operation returns that element type.
      When ``precision`` is not a :class:`~jax.lax.DotAlgorithm` or
      :class:`~jax.lax.DotAlgorithmPreset`, ``preferred_element_type`` provides
      a hint to the compiler to accumulate the dot product using this data type.
    out_sharding: Optional. Must be ``None`` or a ``NamedSharding`` instance;
      the value is canonicalized and forwarded to the underlying primitive.
      (Presumably the requested sharding of the result -- confirm against the
      primitive's sharding rule.)

  Returns:
    An array whose first dimensions are the (shared) batch dimensions, followed
    by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
    non-contracting/non-batch dimensions.

  Raises:
    NotImplementedError: if ``out_sharding`` is neither ``None`` nor a
      ``NamedSharding``.
  """
  # Reject unsupported sharding types before doing any other work.
  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
    raise NotImplementedError(
        '`out_sharding` argument of `dot_general` only supports NamedSharding '
        'instances. Please file a bug if this is not enough for your use case.')
  out_sharding = canonicalize_sharding(out_sharding, 'dot_general')
  # Normalize the user-supplied dimension numbers to index tuples so the
  # primitive always receives the same representation.
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  cdims = (api_util._ensure_index_tuple(lhs_contract),
           api_util._ensure_index_tuple(rhs_contract))
  bdims = (api_util._ensure_index_tuple(lhs_batch),
           api_util._ensure_index_tuple(rhs_batch))
  preferred_element_type = (
      None if preferred_element_type is None else
      dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
  lhs, rhs = core.standard_insert_pbroadcast(lhs, rhs)
  return dot_general_p.bind(lhs, rhs,
                            dimension_numbers=(cdims, bdims),
                            precision=canonicalize_precision(precision),
                            preferred_element_type=preferred_element_type,
                            out_sharding=out_sharding)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-05-11 06:40:18 -07:00
|
|
|
|
|
|
|
|
|
def ragged_dot(
    lhs: Array,
    rhs: Array,
    group_sizes: Array,
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    group_offset: Array | None = None,
) -> Array:
  """Ragged matrix multiplication.

  This is the basic case of :func:`ragged_dot_general`: it forwards to that
  function with the module's basic ragged dot dimension numbers.

  Args:
    lhs: (m, k) shaped array.
    rhs: (g, k, n) shaped array.
    group_sizes: (g,) shaped array with integer element type, where g denotes
      number of groups. The ith element indicates the size of ith group.
    precision: Optional. Consistent with precision argument for
      :func:`jax.lax.dot`.
    preferred_element_type: Optional. Consistent with preferred_element_type
      argument for :func:`jax.lax.dot`.
    group_offset: Optional. (1,) shaped array that indicates the group in
      group_sizes to start computing from. If not specified, defaults to [0].

  Returns:
    (m, n) shaped array with preferred_element_type element type.
  """
  return ragged_dot_general(
      lhs,
      rhs,
      group_sizes,
      ragged_dot_dimension_numbers=_BASIC_RAGGED_DOT_DIMENSION_NUMBERS,
      precision=canonicalize_precision(precision),
      preferred_element_type=preferred_element_type,
      group_offset=group_offset,
  )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class RaggedDotDimensionNumbers:
  """Describes ragged, group, and dot dimensions for ragged dot general.

  Every field is converted to (nested) tuples on construction, so instances
  built from lists and from tuples compare equal and are hashable.

  Args:
    dot_dimension_numbers: a tuple of tuples of sequences of ints of the form
      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
      rhs_batch_dims))`.
    lhs_ragged_dimensions: a sequence of ints indicating the 'lhs' ragged
      dimensions.
    rhs_group_dimensions: a sequence of ints indicating the 'rhs' group
      dimensions.
  """
  dot_dimension_numbers: DotDimensionNumbers
  lhs_ragged_dimensions: Sequence[int]
  rhs_group_dimensions: Sequence[int]

  def __init__(
      self, dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions
  ):
    # The dataclass is frozen, so plain attribute assignment would raise;
    # fields are written through the parent class's __setattr__ instead.
    assign = super().__setattr__
    canonical_dot_dims = tuple(
        tuple(tuple(dims) for dims in pair) for pair in dot_dimension_numbers
    )
    assign('dot_dimension_numbers', canonical_dot_dims)
    assign('lhs_ragged_dimensions', tuple(lhs_ragged_dimensions))
    assign('rhs_group_dimensions', tuple(rhs_group_dimensions))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _from_maybe_ragged(
    dot_dimension_numbers: RaggedDotDimensionNumbers | DotDimensionNumbers,
) -> DotDimensionNumbers:
  """Returns plain dot dimension numbers from a possibly-ragged description.

  A ``RaggedDotDimensionNumbers`` is unwrapped to its ``dot_dimension_numbers``
  field; any other value is returned unchanged.
  """
  if isinstance(dot_dimension_numbers, RaggedDotDimensionNumbers):
    return dot_dimension_numbers.dot_dimension_numbers
  return dot_dimension_numbers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dimension numbers for the simple case handled by lax.ragged_dot: contract
# dimension 1 of lhs with dimension 1 of rhs, no batch dimensions, lhs ragged
# dimension 0 and rhs group dimension 0.
_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = RaggedDotDimensionNumbers(
    dot_dimension_numbers=(((1,), (1,)), ((), ())),
    lhs_ragged_dimensions=(0,),
    rhs_group_dimensions=(0,),
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ragged_dot_general(
    lhs: Array,
    rhs: Array,
    group_sizes: Array,
    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    group_offset: Array | None = None,
    ) -> Array:
  """Ragged matrix multiplication.

  Ragged dot takes three arrays---``lhs``, ``rhs``, and ``group_sizes``---and
  a ``ragged_dot_dimension_numbers`` argument. Like `dot_general`, ``lhs`` and
  ``rhs`` are allowed arbitrary batch and contracting dimensions. Additionally,
  ``lhs`` is required to have one ragged dimension, and ``rhs`` may have at
  most one group dimension.

  Let `g` be the number of groups in the lhs ragged dimension. Ragged dot has
  three modes, depending on the kind of the lhs ragged dimension:

  1. `[b...,m...,k...], [g,b...,k...,n...], [b...,x...,g] -> [b...,m...,n...]`.
     Here the ragged dimension is a non-contracting dimension (`m`) of ``lhs``,
     and `x...` are the lhs non-contracting dims outer to the ragged dim.
  2. `[b...,m...,k...], [b...,k...,n...], [b...,x...,g] -> [g,b...,m...,n...]`.
     Here the ragged dimension is a contracting dimension (`k`) of ``lhs`` and
     ``rhs``, and `x...` are the lhs contracting dims outer to the ragged dim.
  3. `[b...,m...,k...], [b...,k...,n...], [x...,g] -> [b...,m...,n...]`.
     Here the ragged dimension is a batch dimension (`b`) of ``lhs`` and
     ``rhs``, and `x...` are the lhs batch dims outer to the ragged dim.

  If ``group_sizes`` is passed-in with shape `[g]`, it is broadcasted according
  to the rules above.

  Args:
    lhs: an array
    rhs: an array
    group_sizes: an array with integer element type
    ragged_dot_dimension_numbers: a ``RaggedDotDimensionNumbers`` object to
      specify the dot dimension numbers, lhs ragged dimension, and rhs group
      dimension.
    precision: Optional. Consistent with precision argument for
      :func:`jax.lax.dot`.
    preferred_element_type: Optional. Consistent with preferred_element_type
      argument for :func:`jax.lax.dot`.
    group_offset: Optional. (1,) shaped array that indicates the group in
      group_sizes to start computing from. If not specified, defaults to [0].

  Returns:
    An array whose shape is the same as that produced by `dot_general`, with an
    extra leading dimension of size `g` in the case where the lhs ragged
    dimension is a contracting dimension.
  """
  lhs, rhs, group_sizes = core.standard_insert_pbroadcast(lhs, rhs, group_sizes)
  bind_params = dict(
      ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
      precision=canonicalize_precision(precision),
      preferred_element_type=preferred_element_type,
      group_offset=group_offset,
  )
  return ragged_dot_general_p.bind(lhs, rhs, group_sizes, **bind_params)
|
2024-05-11 06:40:18 -07:00
|
|
|
|
|
|
|
|
|
|
2025-02-12 13:58:38 -08:00
|
|
|
|
def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None
              ) -> Array:
  """Broadcasts an array, adding new leading dimensions

  Args:
    operand: an array
    sizes: a sequence of integers, giving the sizes of new leading dimensions
      to add to the front of the array.
    out_sharding: optional; forwarded unchanged to
      :func:`jax.lax.broadcast_in_dim`. When it is not ``None`` the call is
      always forwarded, even if ``sizes`` is empty.

  Returns:
    An array containing the result.

  See Also:
    jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
  """
  num_new_dims = len(sizes)
  if num_new_dims == 0 and out_sharding is None:
    # Nothing to add and no sharding requested: only convert to an array.
    return asarray(operand)
  # The operand's own dimensions land right after the new leading ones.
  operand_dims = tuple(range(num_new_dims, num_new_dims + np.ndim(operand)))
  target_shape = tuple(sizes) + np.shape(operand)
  return broadcast_in_dim(operand, target_shape, operand_dims,
                          out_sharding=out_sharding)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def broadcast_in_dim(operand: ArrayLike, shape: Shape,
                     broadcast_dimensions: Sequence[int], out_sharding=None
                     ) -> Array:
  """Wraps XLA's `BroadcastInDim
  <https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
  operator.

  Args:
    operand: an array
    shape: the shape of the target array
    broadcast_dimensions: to which dimension in the target shape each dimension
      of the operand shape corresponds to. That is, dimension i of the operand
      becomes dimension broadcast_dimensions[i] of the result.
    out_sharding: optional sharding for the result. It is canonicalized and
      passed to the primitive; when it canonicalizes to something other than
      ``None`` the no-op shortcut below is not taken.

  Returns:
    An array containing the result.

  See Also:
    jax.lax.broadcast : simpler interface to add new leading dimensions.
  """
  out_sharding = canonicalize_sharding(out_sharding, 'broadcast_in_dim')
  # No-op shortcut: the operand already has the target rank, no dimension
  # mapping was given, it is already an Array and no sharding was requested.
  if np.ndim(operand) == len(shape) and not len(broadcast_dimensions):
    if isinstance(operand, Array) and out_sharding is None:
      return operand
  if config.dynamic_shapes.value:
    # We must gate this behavior under a flag because otherwise the errors
    # raised are different (and have worse source provenance information).
    dynamic_dims, static_dims = _extract_tracers_dyn_shape(shape)
  else:
    dynamic_dims, static_dims = [], shape  # type: ignore
  return broadcast_in_dim_p.bind(
      operand,
      *dynamic_dims,
      shape=tuple(static_dims),
      broadcast_dimensions=tuple(broadcast_dimensions),
      sharding=out_sharding,
  )
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-07-24 11:00:55 -04:00
|
|
|
|
def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
  """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.

  When ``x`` already has rank ``rank`` it is only converted to an array;
  otherwise ``rank - ndim(x)`` leading dimensions of size 1 are added via
  :func:`broadcast`.
  """
  current_rank = np.ndim(x)
  if current_rank != rank:
    return broadcast(x, (1,) * (rank - current_rank))
  return asarray(x)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def reshape(operand: ArrayLike, new_sizes: Shape,
            dimensions: Sequence[int] | None = None,
            out_sharding: NamedSharding | P | None = None) -> Array:
  """Wraps XLA's `Reshape
  <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
  operator.

  For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
  ``lax.expand_dims``. These preserve information about axis identity that may
  be useful for advanced transformation rules.

  Args:
    operand: array to be reshaped.
    new_sizes: sequence of integers specifying the resulting shape. The size
      of the final array must match the size of the input.
    dimensions: optional sequence of integers specifying the permutation order of
      the input shape. If specified, the length must match ``operand.shape``.
    out_sharding: optional sharding for the result. It is canonicalized and
      passed to the reshape primitive. When it is given, the reshape is always
      bound, even if ``new_sizes`` equals the operand's current shape, so the
      request is never silently dropped.

  Returns:
    out: reshaped array.

  Examples:
    Simple reshaping from one to two dimensions:

    >>> x = jnp.arange(6)
    >>> y = reshape(x, (2, 3))
    >>> y
    Array([[0, 1, 2],
           [3, 4, 5]], dtype=int32)

    Reshaping back to one dimension:

    >>> reshape(y, (6,))
    Array([0, 1, 2, 3, 4, 5], dtype=int32)

    Reshaping to one dimension with permutation of dimensions:

    >>> reshape(y, (6,), (1, 0))
    Array([0, 3, 1, 4, 2, 5], dtype=int32)
  """
  new_sizes = canonicalize_shape(new_sizes)  # TODO
  new_sizes = tuple(new_sizes)
  same_shape = core.definitely_equal_shape(np.shape(operand), new_sizes)
  if dimensions is None:
    same_dims = True
    dims = None
  else:
    dims = api_util._ensure_index_tuple(dimensions)
    same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
  # Identity shortcut: a non-scalar Array whose shape and dimension order are
  # unchanged is returned as-is. This is only valid when no output sharding
  # was requested (same guard as broadcast_in_dim); otherwise returning the
  # operand would silently ignore ``out_sharding``.
  if (np.shape(operand) and same_shape and same_dims
      and isinstance(operand, Array) and out_sharding is None):
    return operand
  dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
  out_sharding = canonicalize_sharding(out_sharding, 'reshape')
  return reshape_p.bind(
      operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
      # An identity permutation is normalized to None for the primitive.
      dimensions=None if dims is None or same_dims else dims,
      sharding=out_sharding)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def pad(operand: ArrayLike, padding_value: ArrayLike,
        padding_config: Sequence[tuple[int, int, int]]) -> Array:
  """Applies low, high, and/or interior padding to an array.

  Wraps XLA's `Pad
  <https://www.tensorflow.org/xla/operation_semantics#pad>`_
  operator.

  Args:
    operand: an array to be padded.
    padding_value: the value to be inserted as padding. Must have the same dtype
      as ``operand``.
    padding_config: a sequence of ``(low, high, interior)`` tuples of integers,
      giving the amount of low, high, and interior (dilation) padding to insert
      in each dimension.

  Returns:
    The ``operand`` array with padding value ``padding_value`` inserted in each
    dimension according to the ``padding_config``.

  Examples:
    >>> from jax import lax
    >>> import jax.numpy as jnp

    Pad a 1-dimensional array with zeros. We'll specify two zeros in front and
    three at the end:

    >>> x = jnp.array([1, 2, 3, 4])
    >>> lax.pad(x, 0, [(2, 3, 0)])
    Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)

    Pad a 1-dimensional array with *interior* zeros; i.e. insert a single zero
    between each value:

    >>> lax.pad(x, 0, [(0, 0, 1)])
    Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)

    Pad a 2-dimensional array with the value ``-1`` at front and end, with a pad
    size of 2 in each dimension:

    >>> x = jnp.array([[1, 2, 3],
    ...                [4, 5, 6]])
    >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
    Array([[-1, -1, -1, -1, -1, -1, -1],
           [-1, -1, -1, -1, -1, -1, -1],
           [-1, -1,  1,  2,  3, -1, -1],
           [-1, -1,  4,  5,  6, -1, -1],
           [-1, -1, -1, -1, -1, -1, -1],
           [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
  """
  operand, padding_value = core.standard_insert_pbroadcast(operand, padding_value)
  # The primitive takes the configuration as a tuple parameter.
  config_tuple = tuple(padding_config)
  return pad_p.bind(operand, padding_value, padding_config=config_tuple)
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `Rev
  <https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_
  operator.

  Args:
    operand: an array.
    dimensions: sequence of integers naming the dimensions the operator is
      applied to; converted to a tuple before being passed to the primitive.

  Returns:
    The result of binding the ``rev`` primitive to ``operand``.
  """
  dims = tuple(dimensions)
  return rev_p.bind(operand, dimensions=dims)
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
  """Selects between two branches based on a boolean predicate.

  Wraps XLA's `Select
  <https://www.tensorflow.org/xla/operation_semantics#select>`_
  operator.

  In general :func:`~jax.lax.select` leads to evaluation of both branches, although
  the compiler may elide computations if possible. For a similar function that
  usually evaluates only a single branch, see :func:`~jax.lax.cond`.

  Args:
    pred: boolean array
    on_true: array containing entries to return where ``pred`` is True. Must have
      the same shape as ``pred``, and the same shape and dtype as ``on_false``.
    on_false: array containing entries to return where ``pred`` is False. Must have
      the same shape as ``pred``, and the same shape and dtype as ``on_true``.

  Returns:
    result: array with same shape and dtype as ``on_true`` and ``on_false``.
  """
  # Caution! The select_n_p primitive has the *opposite* order of arguments to
  # select(). This is because it implements `select_n`: case 0 (taken where
  # ``pred`` is False) comes first.
  operands = core.standard_insert_pbroadcast(pred, on_false, on_true)
  pred, on_false, on_true = operands
  return select_n_p.bind(pred, on_false, on_true)
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def select_n(which: ArrayLike, *cases: ArrayLike) -> Array:
  """Selects array values from multiple cases.

  Generalizes XLA's `Select
  <https://www.tensorflow.org/xla/operation_semantics#select>`_
  operator. Unlike XLA's version, the operator is variadic and can select
  from many cases using an integer `pred`.

  Args:
    which: determines which case should be returned. Must be an array containing
      either a boolean or integer values. May either be a scalar or have
      shape matching ``cases``. For each array element, the value of ``which``
      determines which of ``cases`` is taken. ``which`` must be in the range
      ``[0 .. len(cases))``; for values outside that range the behavior is
      implementation-defined.
    *cases: a non-empty list of array cases. All must have equal dtypes and
      equal shapes.

  Returns:
    An array with shape and dtype equal to the cases, whose values are chosen
    according to ``which``.

  Raises:
    ValueError: if no cases are given.
  """
  if not cases:
    raise ValueError("select_n() must have at least one case")
  selector, *branches = core.standard_insert_pbroadcast(which, *cases)
  return select_n_p.bind(selector, *branches)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2023-08-11 08:06:51 -07:00
|
|
|
|
def transpose(operand: ArrayLike,
|
|
|
|
|
permutation: Sequence[int] | np.ndarray) -> Array:
|
2020-10-17 14:33:26 -04:00
|
|
|
|
"""Wraps XLA's `Transpose
|
|
|
|
|
<https://www.tensorflow.org/xla/operation_semantics#transpose>`_
|
|
|
|
|
operator.
|
|
|
|
|
"""
|
2020-07-22 12:10:43 -07:00
|
|
|
|
permutation = tuple(operator.index(d) for d in permutation)
|
2022-12-05 16:13:33 -08:00
|
|
|
|
if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array):
|
2024-07-26 10:59:56 +01:00
|
|
|
|
return operand
|
2020-10-17 14:33:26 -04:00
|
|
|
|
else:
|
2025-03-27 16:55:45 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return transpose_p.bind(operand, permutation=permutation)
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def argmin(operand: ArrayLike, axis: int,
           index_dtype: DTypeLike) -> Array:
  """Computes the index of the minimum element along ``axis``.

  Args:
    operand: an array.
    axis: the single axis to reduce over.
    index_dtype: dtype requested for the indices; canonicalized before being
      passed to the primitive.
  """
  canonical_index_dtype = dtypes.canonicalize_dtype(index_dtype)
  return argmin_p.bind(operand, axes=(axis,),
                       index_dtype=canonical_index_dtype)
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def argmax(operand: ArrayLike, axis: int,
           index_dtype: DTypeLike) -> Array:
  """Computes the index of the maximum element along ``axis``.

  Args:
    operand: an array.
    axis: the single axis to reduce over.
    index_dtype: dtype requested for the indices; canonicalized before being
      passed to the primitive.
  """
  canonical_index_dtype = dtypes.canonicalize_dtype(index_dtype)
  return argmax_p.bind(operand, axes=(axis,),
                       index_dtype=canonical_index_dtype)
|
|
|
|
|
|
2021-12-09 22:42:04 -05:00
|
|
|
|
def reduce(operands: Any,
           init_values: Any,
           computation: Callable[[Any, Any], Any],
           dimensions: Sequence[int]) -> Any:
  """Wraps XLA's `Reduce
  <https://www.tensorflow.org/xla/operation_semantics#reduce>`_
  operator.

  ``init_values`` and ``computation`` together must form a `monoid
  <https://en.wikipedia.org/wiki/Monoid>`_
  for correctness. That is ``init_values`` must be an identity of
  ``computation``, and ``computation`` must be associative. XLA may exploit both
  of these properties during code generation; if either is violated the result
  is undefined.

  Args:
    operands: an array, or a pytree of arrays, to reduce.
    init_values: initial value(s); must have the same pytree structure as
      ``operands``.
    computation: binary function combining two values with the structure of
      ``init_values``.
    dimensions: the dimensions to reduce over.

  Returns:
    The reduced value(s). When ``computation`` and a single ``init_values``
    leaf are recognized as a known monoid, the matching specialized reduction
    is used; otherwise the general reduce primitive is bound.

  Raises:
    ValueError: if ``operands`` and ``init_values`` differ in tree structure
      or in number of leaves.
  """
  flat_operands, operand_tree = tree_util.tree_flatten(operands)
  comp_debug = api_util.debug_info("reduce comp", computation,
                                   (init_values, init_values), {})
  flat_init_values, init_value_tree = tree_util.tree_flatten(init_values)
  if operand_tree != init_value_tree:
    raise ValueError('Operands must have the same tree structure as init_values:'
                     f' {operand_tree} vs. {init_value_tree}')
  num_operands, num_inits = len(flat_operands), len(flat_init_values)
  if num_operands != num_inits:
    raise ValueError('Must have same total number of operands as init_values: '
                     f' {num_operands} vs. {num_inits}')

  monoid_reducer = _get_monoid_reducer(computation, flat_init_values)
  if not monoid_reducer:
    # General path: trace the computation to a jaxpr and bind reduce_p.
    init_avals = tuple(safe_map(core.get_aval, flat_init_values))
    closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
        computation, comp_debug, init_avals, init_value_tree)
    flat_out = reduce_p.bind(
        *flat_operands, *flat_init_values, computation=computation,
        jaxpr=closed_jaxpr, dimensions=tuple(dimensions))
    return tree_util.tree_unflatten(out_tree, flat_out)

  # monoid reducers bypass the weak_type_rule, so we set it explicitly.
  weak_type = (dtypes.is_weakly_typed(*flat_operands)
               and dtypes.is_weakly_typed(*flat_init_values))
  reduced = monoid_reducer(*flat_operands, dimensions)
  return _convert_element_type(reduced, weak_type=weak_type)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
@cache()
def _reduction_jaxpr(computation: Callable,
                     aval: core.AbstractValue):
  """Traces a binary reduction ``computation`` to a jaxpr.

  The computation is traced with two inputs of abstract value ``aval``.

  Returns:
    A pair ``(jaxpr, consts)`` with ``consts`` as a tuple.

  Raises:
    ValueError: (while tracing) if ``computation`` returns something that is
      neither a tracer nor a valid JAX type.
    NotImplementedError: if the traced computation closes over tracers.
  """
  def comp(x, y):
    result = computation(x, y)
    if not isinstance(result, core.Tracer) and not core.valid_jaxtype(result):
      raise ValueError(
          f"Invalid return type from reduction function: {type(result)}\n"
          f"Reduction functions should only return an array.\n"
          f"Full return value: {result}")
    # Wrap the single output in a 1-tuple before tracing.
    return (result,)

  in_avals = (aval, aval)
  dbg = api_util.debug_info("reduction_jaxpr", computation, in_avals, {})
  comp_wrapped = lu.wrap_init(comp, debug_info=dbg)
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp_wrapped, in_avals)
  for const in consts:
    if isinstance(const, core.Tracer):
      raise NotImplementedError(
          "Reduction computations can't close over Tracers. Please open an issue "
          "at https://github.com/jax-ml/jax.")
  return jaxpr, tuple(consts)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2020-11-10 15:57:19 -08:00
|
|
|
|
@cache()
def _variadic_reduction_jaxpr(computation: Callable[[Any, Any], Any],
                              debug_info: core.DebugInfo,
                              flat_avals,
                              aval_tree: tree_util.PyTreeDef):
  """Traces a pytree-valued binary reduction ``computation`` to a closed jaxpr.

  ``flat_avals`` and ``aval_tree`` describe one argument of ``computation``;
  the computation is traced with two such arguments.

  Returns:
    A pair ``(closed_jaxpr, out_tree)`` where ``out_tree`` is the pytree
    structure of the computation's output.

  Raises:
    NotImplementedError: if the traced computation closes over tracers.
  """
  arg_avals = tree_util.tree_unflatten(aval_tree, flat_avals)
  flat_in_avals, in_tree = tree_util.tree_flatten((arg_avals, arg_avals))
  wrapped = lu.wrap_init(computation, debug_info=debug_info)
  flat_comp, out_tree = api_util.flatten_fun_nokwargs(wrapped, in_tree)
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
      flat_comp, tuple(flat_in_avals))
  for const in consts:
    if isinstance(const, core.Tracer):
      raise NotImplementedError(
          "Reduction computations can't close over Tracers. Please open an issue "
          "at https://github.com/jax-ml/jax.")
  # out_tree is a thunk; it is called here, after tracing has run.
  return core.ClosedJaxpr(jaxpr, consts), out_tree()
|
2020-11-10 15:57:19 -08:00
|
|
|
|
|
2021-11-22 13:20:55 -08:00
|
|
|
|
def _get_monoid_reducer(monoid_op: Callable,
|
2023-12-08 12:09:04 +00:00
|
|
|
|
xs: Sequence[Array]) -> Callable | None:
|
2020-11-10 15:57:19 -08:00
|
|
|
|
if len(xs) != 1:
|
|
|
|
|
return None
|
|
|
|
|
x, = xs
|
2020-10-17 14:33:26 -04:00
|
|
|
|
aval = core.get_aval(x)
|
|
|
|
|
dtype = _dtype(x)
|
2024-10-31 14:06:08 -07:00
|
|
|
|
if core.is_concrete(x) and aval.shape == ():
|
|
|
|
|
val = core.to_concrete_value(x)
|
2022-06-15 21:27:42 +02:00
|
|
|
|
# allow bitwise reductions for boolean and integer types
|
|
|
|
|
_is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
if monoid_op is add:
|
2025-02-11 16:00:03 -08:00
|
|
|
|
return reduce_sum if np.equal(val, 0) else None
|
2020-12-15 10:15:49 -08:00
|
|
|
|
elif monoid_op is mul:
|
2025-02-11 16:00:03 -08:00
|
|
|
|
return reduce_prod if np.equal(val, 1) else None
|
2022-06-15 21:27:42 +02:00
|
|
|
|
elif monoid_op is bitwise_or and _is_intlike:
|
2025-02-11 16:00:03 -08:00
|
|
|
|
return reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None
|
2022-06-15 21:27:42 +02:00
|
|
|
|
elif monoid_op is bitwise_and and _is_intlike:
|
2025-02-11 16:00:03 -08:00
|
|
|
|
return reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None
|
2022-06-15 21:27:42 +02:00
|
|
|
|
elif monoid_op is bitwise_xor and _is_intlike:
|
2025-02-11 16:00:03 -08:00
|
|
|
|
return reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None
|
2020-10-17 14:33:26 -04:00
|
|
|
|
elif monoid_op is max:
|
2025-02-11 16:00:03 -08:00
|
|
|
|
return reduce_max if np.equal(val, _get_max_identity(dtype)) else None
|
2020-10-17 14:33:26 -04:00
|
|
|
|
elif monoid_op is min:
|
2025-02-11 16:00:03 -08:00
|
|
|
|
return reduce_min if np.equal(val, _get_min_identity(dtype)) else None
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return None
|
|
|
|
|
|
2022-09-14 15:03:55 -07:00
|
|
|
|
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
|
2022-12-08 19:40:56 +00:00
|
|
|
|
return np.array(-1).astype(dtype)
|
2022-06-15 21:27:42 +02:00
|
|
|
|
|
2022-09-14 15:03:55 -07:00
|
|
|
|
def _get_bitwise_or_identity(dtype: DTypeLike) -> np.ndarray:
|
2022-06-15 21:27:42 +02:00
|
|
|
|
return np.array(0, dtype)
|
|
|
|
|
|
2022-10-10 18:51:04 -07:00
|
|
|
|
def _get_sum_identity(dtype: DTypeLike) -> np.ndarray:
|
|
|
|
|
return np.array(0, dtype)
|
|
|
|
|
|
|
|
|
|
def _get_prod_identity(dtype: DTypeLike) -> np.ndarray:
|
|
|
|
|
return np.array(1, dtype)
|
|
|
|
|
|
2022-09-14 15:03:55 -07:00
|
|
|
|
def _get_max_identity(dtype: DTypeLike) -> np.ndarray:
|
2020-10-17 14:33:26 -04:00
|
|
|
|
if dtypes.issubdtype(dtype, np.inexact):
|
2023-12-18 13:37:45 -08:00
|
|
|
|
return np.array(-np.inf if dtypes.supports_inf(dtype) else dtypes.finfo(dtype).min,
|
|
|
|
|
dtype=dtype)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
elif dtypes.issubdtype(dtype, np.integer):
|
|
|
|
|
return np.array(dtypes.iinfo(dtype).min, dtype)
|
|
|
|
|
elif dtypes.issubdtype(dtype, np.bool_):
|
|
|
|
|
return np.array(False, np.bool_)
|
2022-09-12 09:08:13 -07:00
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unsupported dtype for max: {dtype}")
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-14 15:03:55 -07:00
|
|
|
|
def _get_min_identity(dtype: DTypeLike) -> np.ndarray:
|
2020-10-17 14:33:26 -04:00
|
|
|
|
if dtypes.issubdtype(dtype, np.inexact):
|
2023-12-18 13:37:45 -08:00
|
|
|
|
return np.array(np.inf if dtypes.supports_inf(dtype) else dtypes.finfo(dtype).max,
|
|
|
|
|
dtype=dtype)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
elif dtypes.issubdtype(dtype, np.integer):
|
|
|
|
|
return np.array(dtypes.iinfo(dtype).max, dtype)
|
|
|
|
|
elif dtypes.issubdtype(dtype, np.bool_):
|
|
|
|
|
return np.array(True, np.bool_)
|
2022-09-12 09:08:13 -07:00
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unsupported dtype for min: {dtype}")
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-11 16:00:03 -08:00
|
|
|
|
def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array:
  """Compute the sum of elements over one or more array axes.

  Args:
    operand: array over which to sum. Must have numerical dtype.
    axes: sequence of zero or more unique integers specifying the axes over
      which to sum. Each entry must satisfy ``0 <= axis < operand.ndim``.

  Returns:
    An array of the same dtype as ``operand``, with shape corresponding
    to the dimensions of ``operand.shape`` with ``axes`` removed.

  Notes:
    Unlike :func:`jax.numpy.sum`, :func:`jax.lax.reduce_sum` does not upcast
    narrow-width types for accumulation, so sums of 8-bit or 16-bit types
    may be subject to rounding errors.

  See also:
    - :func:`jax.numpy.sum`: more flexible NumPy-style summation API, built
      around :func:`jax.lax.reduce_sum`.
    - Other low-level :mod:`jax.lax` reduction operators:
      :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`,
      :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
  """
  reduction_axes = tuple(axes)
  return reduce_sum_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
2025-02-11 16:00:03 -08:00
|
|
|
|
def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array:
  """Compute the product of elements over one or more array axes.

  Args:
    operand: array over which to compute the product. Must have numerical
      dtype.
    axes: sequence of zero or more unique integers specifying the axes over
      which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.

  Returns:
    An array of the same dtype as ``operand``, with shape corresponding
    to the dimensions of ``operand.shape`` with ``axes`` removed.

  Notes:
    Unlike :func:`jax.numpy.prod`, :func:`jax.lax.reduce_prod` does not upcast
    narrow-width types for accumulation, so products of 8-bit or 16-bit types
    may be subject to rounding errors.

  See also:
    - :func:`jax.numpy.prod`: more flexible NumPy-style product API, built
      around :func:`jax.lax.reduce_prod`.
    - Other low-level :mod:`jax.lax` reduction operators:
      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`,
      :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
  """
  reduction_axes = tuple(axes)
  return reduce_prod_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
2025-02-11 16:00:03 -08:00
|
|
|
|
def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array:
  """Compute the maximum of elements over one or more array axes.

  Args:
    operand: array over which to compute maximum.
    axes: sequence of zero or more unique integers specifying the axes over
      which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.

  Returns:
    An array of the same dtype as ``operand``, with shape corresponding
    to the dimensions of ``operand.shape`` with ``axes`` removed.

  See also:
    - :func:`jax.numpy.max`: more flexible NumPy-style max-reduction API, built
      around :func:`jax.lax.reduce_max`.
    - Other low-level :mod:`jax.lax` reduction operators:
      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_min`,
      :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
  """
  reduction_axes = tuple(axes)
  return reduce_max_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
2025-02-11 16:00:03 -08:00
|
|
|
|
def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array:
  """Reduce an array to its minimum along the given axes.

  Args:
    operand: the array whose elements are reduced.
    axes: the axes to reduce over, given as a sequence of zero or more
      distinct integers, each satisfying ``0 <= axis < operand.ndim``.

  Returns:
    An array with the dtype of ``operand`` whose shape is ``operand.shape``
    with the entries listed in ``axes`` removed.

  See also:
    - :func:`jax.numpy.min`: the more flexible NumPy-style min-reduction API,
      which is built around :func:`jax.lax.reduce_min`.
    - The other low-level :mod:`jax.lax` reduction operators:
      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
      :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
  """
  # The primitive takes its axes as a tuple parameter.
  reduction_axes = tuple(axes)
  return reduce_min_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
2025-02-11 16:00:03 -08:00
|
|
|
|
def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array:
  """Reduce an array with bitwise OR along the given axes.

  Args:
    operand: the array whose elements are reduced. Its dtype must be boolean
      or integer.
    axes: the axes to reduce over, given as a sequence of zero or more
      distinct integers, each satisfying ``0 <= axis < operand.ndim``.

  Returns:
    An array with the dtype of ``operand`` whose shape is ``operand.shape``
    with the entries listed in ``axes`` removed.

  See also:
    - :func:`jax.numpy.bitwise_or.reduce`: the more flexible NumPy-style
      logical reduction API, which is built around :func:`jax.lax.reduce_or`.
    - The other low-level :mod:`jax.lax` reduction operators:
      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
      :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_xor`.
  """
  # The primitive takes its axes as a tuple parameter.
  reduction_axes = tuple(axes)
  return reduce_or_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
2025-02-11 16:00:03 -08:00
|
|
|
|
def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array:
  """Reduce an array with bitwise AND along the given axes.

  Args:
    operand: the array whose elements are reduced. Its dtype must be boolean
      or integer.
    axes: the axes to reduce over, given as a sequence of zero or more
      distinct integers, each satisfying ``0 <= axis < operand.ndim``.

  Returns:
    An array with the dtype of ``operand`` whose shape is ``operand.shape``
    with the entries listed in ``axes`` removed.

  See also:
    - :func:`jax.numpy.bitwise_and.reduce`: the more flexible NumPy-style
      logical reduction API, which is built around :func:`jax.lax.reduce_and`.
    - The other low-level :mod:`jax.lax` reduction operators:
      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
      :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
  """
  # The primitive takes its axes as a tuple parameter.
  reduction_axes = tuple(axes)
  return reduce_and_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
2025-02-11 16:00:03 -08:00
|
|
|
|
def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:
  """Reduce an array with bitwise XOR along the given axes.

  Args:
    operand: the array whose elements are reduced. Its dtype must be boolean
      or integer.
    axes: the axes to reduce over, given as a sequence of zero or more
      distinct integers, each satisfying ``0 <= axis < operand.ndim``.

  Returns:
    An array with the dtype of ``operand`` whose shape is ``operand.shape``
    with the entries listed in ``axes`` removed.

  See also:
    - :func:`jax.numpy.bitwise_xor.reduce`: the more flexible NumPy-style
      logical reduction API, which is built around :func:`jax.lax.reduce_xor`.
    - The other low-level :mod:`jax.lax` reduction operators:
      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
      :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`.
  """
  # The primitive takes its axes as a tuple parameter.
  reduction_axes = tuple(axes)
  return reduce_xor_p.bind(operand, axes=reduction_axes)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-14 15:11:53 -07:00
|
|
|
|
@overload
def sort(operand: Array, dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> Array: ...


@overload
def sort(operand: Sequence[Array], dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> tuple[Array, ...]: ...


def sort(operand: Array | Sequence[Array], dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> Array | tuple[Array, ...]:
  """Sort one array, or several arrays jointly, along a dimension.

  This wraps XLA's `Sort
  <https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.

  Ordering conventions: for floating point inputs, -0.0 and 0.0 compare as
  equivalent and NaN values are placed at the end of the array. Complex inputs
  are ordered lexicographically by (real part, imaginary part), with the real
  part as the primary key.

  Args:
    operand : Array or sequence of arrays
    dimension : integer dimension along which to sort. Default: -1.
    is_stable : boolean specifying whether to use a stable sort. Default: True.
    num_keys : number of operands to treat as sort keys. Default: 1.
      When num_keys > 1 the first `num_keys` arrays determine the order
      lexicographically, the first of them being the primary key. Any
      remaining operands are permuted in the same way and returned.

  Returns:
    operand : sorted version of the input or inputs. A single array input
    yields a single array; a sequence input yields a tuple of arrays.

  Raises:
    TypeError: if ``operand`` is an empty sequence.
    ValueError: if ``num_keys`` is outside ``[1, len(operand)]`` for a sequence
      input, or is not 1 for a single-array input.
  """
  # Single-array case: exactly one key is allowed, and the primitive's
  # one-element result is unwrapped.
  if not isinstance(operand, Sequence):
    if num_keys != 1:
      raise ValueError(f"{num_keys=} must equal 1 for a single operand.")
    dimension = canonicalize_axis(dimension, len(operand.shape))
    return sort_p.bind(operand, dimension=dimension, is_stable=is_stable, num_keys=1)[0]

  # Multi-operand case.
  if len(operand) == 0:
    raise TypeError("Sort requires at least one operand")
  if not (1 <= num_keys <= len(operand)):
    raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}")
  # The sort dimension is resolved against the rank of the first operand.
  dimension = canonicalize_axis(dimension, len(operand[0].shape))
  operand = core.standard_insert_pbroadcast(*operand)
  sorted_operands = sort_p.bind(*operand, dimension=dimension,
                                is_stable=is_stable,
                                num_keys=num_keys)
  return tuple(sorted_operands)
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def sort_key_val(keys: Array, values: ArrayLike, dimension: int = -1,
                 is_stable: bool = True) -> tuple[Array, Array]:
  """Sort ``keys`` along ``dimension``, permuting ``values`` identically.

  Args:
    keys: array that determines the sort order.
    values: array that is rearranged with the same permutation as ``keys``.
    dimension: dimension along which to sort. Default: -1.
    is_stable: whether to use a stable sort. Default: True.

  Returns:
    A pair ``(sorted_keys, permuted_values)``.
  """
  axis = canonicalize_axis(dimension, len(keys.shape))
  # Only ``keys`` takes part in the comparison (num_keys=1).
  sorted_keys, permuted_values = sort_p.bind(
      keys, values, dimension=axis, is_stable=is_stable, num_keys=1)
  return sorted_keys, permuted_values
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
|
def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
  """Find the ``k`` largest entries, and their positions, along the last axis.

  Args:
    operand: N-dimensional array of non-complex type.
    k: integer specifying the number of top entries.

  Returns:
    A tuple ``(values, indices)`` where

    - ``values`` is an array containing the top k values along the last axis.
    - ``indices`` is an array containing the indices corresponding to values.

  Raises:
    ValueError: if ``k`` is negative.

  See also:
    - :func:`jax.lax.approx_max_k`
    - :func:`jax.lax.approx_min_k`

  Examples:
    Find the largest three values, and their indices, within an array:

    >>> x = jnp.array([9., 3., 6., 4., 10.])
    >>> values, indices = jax.lax.top_k(x, 3)
    >>> values
    Array([10.,  9.,  6.], dtype=float32)
    >>> indices
    Array([4, 0, 2], dtype=int32)
  """
  # A constant dimension is converted to a plain Python int; anything else is
  # passed through as given.
  k = int(k) if core.is_constant_dim(k) else k
  if k < 0:
    raise ValueError(f"k argument to top_k must be nonnegative, got {k}")
  return top_k_p.bind(operand, k=k)
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def tie_in(x: Any, y: T) -> T:
  """Return ``y`` unchanged; the ``x`` argument is ignored.

  This function is deprecated.

  Args:
    x: unused.
    y: the value to return.

  Returns:
    ``y``, exactly as passed in.
  """
  return y
|
|
|
|
|
|
2024-01-22 09:27:47 -08:00
|
|
|
|
def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
         sharding: Sharding | None = None) -> Array:
  """Create an array of the given `shape` with every entry set to `fill_value`.

  Args:
    shape: sequence of integers, describing the shape of the output array.
    fill_value: the scalar value to fill the new array with.
    dtype: the type of the output array, or `None`. When it is not `None`,
      `fill_value` is cast to `dtype`.
    sharding: an optional sharding specification for the resulting array.
      Note that sharding is currently ignored in jitted mode; this might
      change in the future.

  Returns:
    An array of shape `shape` filled with `fill_value`.

  Raises:
    TypeError: if `fill_value` is not a scalar.
  """
  shape = canonicalize_shape(shape)
  fill_value_shape = np.shape(fill_value)
  if fill_value_shape:
    msg = "full must be called with scalar fill_value, got fill_value.shape {}."
    raise TypeError(msg.format(fill_value_shape))

  # Extended dtypes supply their own `full` rule.
  if dtypes.issubdtype(dtype, dtypes.extended):
    return dtype._rules.full(shape, fill_value, dtype)  # type: ignore[union-attr]

  # The result is weakly typed only when no dtype was requested and the fill
  # value itself is weakly typed.
  weak_type = dtype is None and dtypes.is_weakly_typed(fill_value)
  dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
  fill_value = _convert_element_type(fill_value, dtype, weak_type)

  if sharding is None:
    return broadcast(fill_value, shape)

  build_from_shards = (
      not isinstance(sharding, PmapSharding)
      and isinstance(fill_value, array.ArrayImpl)
      and sharding._is_concrete)
  if build_from_shards:
    # Materialize one shard-sized block and hand that same block back for
    # every index the callback is asked about.
    shard_shape = sharding.shard_shape(shape)
    shard = broadcast(fill_value, shard_shape).addressable_data(0)
    return array.make_array_from_callback(shape, sharding, lambda _: shard)

  if not sharding._is_concrete:
    return broadcast(fill_value, shape, out_sharding=sharding)
  return broadcast(fill_value, shape)
|
2024-02-13 15:26:22 -08:00
|
|
|
|
|
2022-12-16 16:00:38 -08:00
|
|
|
|
def zeros_like_shaped_array(aval: ShapedArray) -> Array:
  """Build an all-zeros array matching the shape, dtype and sharding of ``aval``.

  Args:
    aval: the abstract value to mirror; must be a ``ShapedArray``.

  Returns:
    A scalar zero of ``aval``'s dtype broadcast to ``aval.shape`` with
    ``aval.sharding`` as the output sharding.
  """
  assert isinstance(aval, ShapedArray)
  zero_dtype = aval.dtype
  if dtypes.issubdtype(zero_dtype, dtypes.extended):
    # Extended dtypes define their own zero through their rules object.
    zero = zero_dtype._rules.zero(zero_dtype)
  elif zero_dtype == dtypes.float0:
    # float0 zeros are created directly as a NumPy scalar array.
    zero = np.zeros((), dtype=zero_dtype)
  else:
    zero = _convert_element_type(0, zero_dtype, aval.weak_type)
  return broadcast(zero, aval.shape, out_sharding=aval.sharding)


ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
|
|
|
|
|
2024-04-04 14:33:06 -04:00
|
|
|
|
def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray:
  """Create a mutable array holding zeros shaped like the ref's inner aval.

  Args:
    aval: abstract ref whose ``inner_aval`` describes the zeros to create.

  Returns:
    The result of ``core.mutable_array`` applied to zeros built from
    ``aval.inner_aval``.
  """
  return core.mutable_array(ad_util.zeros_like_aval(aval.inner_aval))


# TODO(dougalm): this is nonsense but it's here because in places like
# custom_vjp we assume that all arguments have tangent spaces. We could have
# a distinct NotATangentType value instead.
ad_util.aval_zeros_likers[state.AbstractRef] = zeros_like_abstract_ref  # type: ignore
|
|
|
|
|
|
2022-09-14 15:03:55 -07:00
|
|
|
|
def iota(dtype: DTypeLike, size: int) -> Array:
  """Wraps XLA's `Iota
  <https://www.tensorflow.org/xla/operation_semantics#iota>`_
  operator.

  Implemented as a one-dimensional :func:`broadcasted_iota` of shape
  ``(size,)`` along dimension 0.

  Args:
    dtype: dtype of the result.
    size: length of the one-dimensional result.
  """
  shape = (size,)
  return broadcasted_iota(dtype, shape, 0)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-10-17 21:16:18 -07:00
|
|
|
|
def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
                     out_sharding=None) -> Array:
  """Convenience wrapper around ``iota``.

  Args:
    dtype: dtype of the result.
    shape: shape of the result; entries may be tracers.
    dimension: the iota dimension; must be concrete.
    out_sharding: optional sharding for the result.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = canonicalize_shape(shape)

  # Split the shape: traced sizes are passed to the primitive as operands and
  # leave a ``None`` placeholder in the static shape parameter.
  traced_dims = []
  static_dims = []
  for d in shape:
    if isinstance(d, core.Tracer):
      traced_dims.append(d)
      static_dims.append(None)
    else:
      static_dims.append(d)

  dimension = core.concrete_or_error(
      int, dimension, "dimension argument of lax.broadcasted_iota")
  out_sharding = canonicalize_sharding(out_sharding, 'broadcasted_iota')
  return iota_p.bind(*traced_dims, dtype=dtype, shape=tuple(static_dims),
                     dimension=dimension, sharding=out_sharding)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-09-01 07:49:49 -07:00
|
|
|
|
def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
  """Like numpy.eye, create a 2D array with ones on a diagonal.

  Args:
    dtype: dtype of the result.
    shape: shape of the result.
    offset: diagonal offset; it is clipped to the int32 range before use.
  """
  offset = _clip_int_to_valid_range(offset, np.int32,
                                    "argument `offset` of jax.numpy.eye")
  dtype = dtypes.canonicalize_dtype(dtype)
  # An entry is on the selected diagonal when row + offset == column.
  shifted_rows = add(broadcasted_iota(np.int32, shape, 0), np.int32(offset))
  cols = broadcasted_iota(np.int32, shape, 1)
  on_diagonal = eq(shifted_rows, cols)
  return convert_element_type_p.bind(on_diagonal, new_dtype=dtype,
                                     weak_type=False, sharding=None)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-14 15:03:55 -07:00
|
|
|
|
def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
  """This utility function exists for creating Kronecker delta arrays.

  Args:
    dtype: dtype of the result.
    shape: shape of the result.
    axes: the axes of ``shape`` over which the delta is formed.
  """
  # NOTE(review): ``axes`` is used twice below (np.take and broadcast_in_dim),
  # so this presumably relies on the module-level ``map`` returning a reusable
  # sequence rather than a one-shot iterator -- confirm against the file header.
  axes = map(int, axes)
  dtype = dtypes.canonicalize_dtype(dtype)
  delta_shape = tuple(np.take(shape, axes))
  index_grids = [broadcasted_iota(np.uint32, delta_shape, i)
                 for i in range(len(delta_shape))]
  # Compare each index grid with its successor; the delta is one where every
  # adjacent pair agrees.
  pairwise_equal = [eq(a, b) for a, b in zip(index_grids[:-1], index_grids[1:])]
  all_equal = _reduce(operator.and_, pairwise_equal)
  delta = convert_element_type_p.bind(
      all_equal, new_dtype=dtype, weak_type=False, sharding=None)
  return broadcast_in_dim(delta, shape, axes)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-05-29 14:23:03 +03:00
|
|
|
|
def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
  """Like numpy.tri, create a 2D array with ones below a diagonal.

  Args:
    dtype: dtype of the result.
    shape: shape of the result.
    offset: diagonal offset; it is clipped to the int32 range before use.
  """
  offset = _clip_int_to_valid_range(offset, np.int32,
                                    "argument `offset` of jax.numpy.tri")
  dtype = dtypes.canonicalize_dtype(dtype)
  # An entry is set when row + offset >= column.
  rows = broadcasted_iota(np.int32, shape, 0)
  offset_value = asarray(core.dimension_as_value(offset)).astype(np.int32)
  shifted_rows = add(rows, offset_value)
  cols = broadcasted_iota(np.int32, shape, 1)
  in_triangle = ge(shifted_rows, cols)
  return convert_element_type_p.bind(in_triangle, new_dtype=dtype,
                                     weak_type=False, sharding=None)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-12-09 22:42:04 -05:00
|
|
|
|
def stop_gradient(x: T) -> T:
  """Stops gradient computation.

  As an operation, ``stop_gradient`` is the identity: it hands back its
  argument `x` unchanged. Its effect is on differentiation: gradients do not
  flow through it in either forward- or reverse-mode automatic
  differentiation. When several gradient computations are nested,
  ``stop_gradient`` stops gradients for all of them. For some discussion of
  where this is useful, refer to :ref:`stopping-gradients`.

  Args:
    x: array or pytree of arrays

  Returns:
    input value is returned unchanged, but within autodiff will be treated as
    a constant.

  Examples:
    Consider a simple function that returns the square of the input value:

    >>> def f1(x):
    ...   return x ** 2
    >>> x = jnp.float32(3.0)
    >>> f1(x)
    Array(9.0, dtype=float32)
    >>> jax.grad(f1)(x)
    Array(6.0, dtype=float32)

    The same function with ``stop_gradient`` around ``x`` will be equivalent
    under normal evaluation, but return a zero gradient because ``x`` is
    effectively treated as a constant:

    >>> def f2(x):
    ...   return jax.lax.stop_gradient(x) ** 2
    >>> f2(x)
    Array(9.0, dtype=float32)
    >>> jax.grad(f2)(x)
    Array(0.0, dtype=float32)

    This is used in a number of places within the JAX codebase; for example
    :func:`jax.nn.softmax` internally normalizes the input by its maximum
    value, and this maximum value is wrapped in ``stop_gradient`` for
    efficiency. Refer to :ref:`stopping-gradients` for more discussion of
    the applicability of ``stop_gradient``.
  """
  def _stop_leaf(leaf):
    # only bind primitive on inexact dtypes, to avoid some staging
    if dtypes.issubdtype(core.get_aval(leaf).dtype, dtypes.extended):
      return leaf
    leaf_dtype = _dtype(leaf)
    is_inexact = (dtypes.issubdtype(leaf_dtype, np.floating) or
                  dtypes.issubdtype(leaf_dtype, np.complexfloating))
    if not is_inexact:
      return leaf
    # break abstractions to support legacy leaked tracer use cases
    if isinstance(leaf, ad.JVPTracer):
      return _stop_leaf(leaf.primal)
    return ad_util.stop_gradient_p.bind(leaf)

  return tree_map(_stop_leaf, x)
|
|
|
|
|
|
2023-12-08 12:09:04 +00:00
|
|
|
|
def reduce_precision(operand: float | ArrayLike,
                     exponent_bits: int,
                     mantissa_bits: int) -> Array:
  """Wraps XLA's `ReducePrecision
  <https://www.tensorflow.org/xla/operation_semantics#reduceprecision>`_
  operator.

  Args:
    operand: the value whose precision is reduced.
    exponent_bits: number of exponent bits; must be a concrete integer.
    mantissa_bits: number of mantissa bits; must be a concrete integer.
  """
  # Both bit counts are static parameters of the primitive, so they have to be
  # concrete integers rather than traced values.
  n_exponent = core.concrete_or_error(
      operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision")
  n_mantissa = core.concrete_or_error(
      operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision")
  return reduce_precision_p.bind(operand, exponent_bits=n_exponent,
                                 mantissa_bits=n_mantissa)
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array:
  """Squeeze any number of size 1 dimensions from an array.

  Args:
    array: the input array.
    dimensions: the dimensions to remove; each is canonicalized against the
      rank of ``array``.

  Returns:
    The squeezed array. When ``dimensions`` is empty and ``array`` is already
    an ``Array``, the input is returned as is.
  """
  rank = np.ndim(array)
  canonical = sorted(canonicalize_axis(d, rank) for d in dimensions)
  dimensions = tuple(canonical)
  # Nothing to remove: skip binding the primitive for existing Arrays.
  if not dimensions and isinstance(array, Array):
    return array
  return squeeze_p.bind(array, dimensions=dimensions)
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:
  """Insert any number of size 1 dimensions into an array.

  Args:
    array: the input array.
    dimensions: positions, in the output array, of the new size 1 dimensions.

  Returns:
    ``array`` broadcast to a shape that has a size 1 dimension at every
    requested position.

  Raises:
    ValueError: if ``dimensions`` contains a repeated axis, either as given or
      after canonicalization.
  """
  if len(set(dimensions)) != len(dimensions):
    raise ValueError(f'repeated axis in lax.expand_dims: {dimensions}')
  ndim_out = np.ndim(array) + len(dimensions)
  dims = [canonicalize_axis(i, ndim_out) for i in dimensions]
  if len(set(dims)) != len(dims):  # check again after canonicalizing
    raise ValueError(f'repeated axis in lax.expand_dims: {dims}')
  new_axes = frozenset(dims)

  # Walk the output positions once: new axes get size 1, every other position
  # takes the next input dimension in order.
  input_sizes = iter(np.shape(array))
  result_shape = [1 if axis in new_axes else next(input_sizes)
                  for axis in range(ndim_out)]
  kept_axes = [axis for axis in range(ndim_out) if axis not in new_axes]
  return broadcast_in_dim(array, result_shape, kept_axes)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
### convenience wrappers around traceables
|
|
|
|
|
|
2023-12-08 12:09:04 +00:00
|
|
|
|
def full_like(x: ArrayLike | DuckTypedArray,
              fill_value: ArrayLike, dtype: DTypeLike | None = None,
              shape: Shape | None = None, sharding: Sharding | None = None) -> Array:
  """Create a full array like np.full based on the example array `x`.

  Args:
    x: example array-like, used for shape and dtype information.
    fill_value: a scalar value to fill the entries of the output array.
    dtype: optional, a dtype parameter for the output ndarray.
    shape: optional, a shape parameter for the output ndarray.
    sharding: an optional sharding specification for the resulting array.
      When it is not given, the output takes the sharding of the input,
      subject to these exceptions/limitations:

      1. Sharding is not available during tracing, thus this will rely on jit.
      2. If x is weakly typed or uncommitted, will use default sharding.
      3. Shape is not None and is different from x.shape, default will be used.

  Returns:
    An ndarray with the same shape as `x` with its entries set equal to
    `fill_value`, similar to the output of np.full.
  """
  if shape is None:
    fill_shape = np.shape(x)  # type: ignore[arg-type]
  else:
    fill_shape = canonicalize_shape(shape)
  weak_type = dtype is None and dtypes.is_weakly_typed(x)
  dtype = dtype or _dtype(x)
  # Extended dtypes supply their own `full` rule.
  if dtypes.issubdtype(dtype, dtypes.extended):
    return dtype._rules.full(fill_shape, fill_value, dtype)  # type: ignore[union-attr]

  if sharding is None:
    if isinstance(x, core.Tracer):
      # Tracers are handled separately: hasattr(x, 'sharding') on a Tracer
      # returns False but is very slow, so that check is never reached for
      # them. The aval's sharding is used only when no shape was requested.
      if shape is None:
        sharding = x.aval.sharding
    elif (hasattr(x, 'sharding')
          # If `x` has a sharding but no `_committed` attribute
          # (in case of ShapeDtypeStruct), default it to True.
          and getattr(x, '_committed', True)
          and not weak_type
          and fill_shape == np.shape(x)):  # type: ignore[arg-type]
      # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported.
      sharding = x.sharding  # type: ignore

  converted_fill = _convert_element_type(fill_value, dtype, weak_type)
  return full(fill_shape, converted_fill, sharding=sharding)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collapse(operand: Array, start_dimension: int,
             stop_dimension: int | None = None) -> Array:
  """Collapses dimensions of an array into a single dimension.

  As an illustration, given an ``operand`` of shape ``[2, 3, 4]``,
  ``collapse(operand, 0, 2).shape == [6, 4]``. Elements within the collapsed
  dimension follow major-to-minor order, meaning the lowest-numbered
  dimension varies slowest.

  Args:
    operand: an input array.
    start_dimension: the start of the dimensions to collapse (inclusive).
    stop_dimension: the end of the dimensions to collapse (exclusive). Pass None
      to collapse all the dimensions after start.

  Returns:
    An array where dimensions ``[start_dimension, stop_dimension)`` have been
    collapsed (raveled) into a single dimension.

  Raises:
    ValueError: if the resolved stop falls before the resolved start.
  """
  # Resolve the bounds with slice semantics, so negative indices and None are
  # handled the same way as in ``shape[start:stop]``.
  bounds = slice(start_dimension, stop_dimension)
  first, last, _ = bounds.indices(len(operand.shape))
  if last < first:
    raise ValueError(f"Invalid dimension range passed to collapse: {operand.shape}"
                     f"[{start_dimension}:{stop_dimension}]")
  merged_size = math.prod(operand.shape[first:last])
  leading = operand.shape[:first]
  trailing = operand.shape[last:]
  return reshape(operand, leading + (merged_size,) + trailing)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch_matmul(lhs: Array, rhs: Array,
                 precision: PrecisionLike = None) -> Array:
  """Batch matrix multiplication.

  All dimensions except the last two are treated as batch dimensions; the
  last dimension of ``lhs`` is contracted with the second-to-last dimension
  of ``rhs``.

  Args:
    lhs: left operand, at least 2D.
    rhs: right operand, with the same number of dimensions as ``lhs``.
    precision: forwarded to ``dot_general``.

  Raises:
    ValueError: if either argument has fewer than 2 dimensions, or the
      arguments differ in their number of dimensions.
  """
  lhs_rank, rhs_rank = lhs.ndim, rhs.ndim
  if _min(lhs_rank, rhs_rank) < 2:
    raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}'
                     .format(lhs_rank, rhs_rank))
  if lhs_rank != rhs_rank:
    raise ValueError('Arguments to batch_matmul must have same ndim, got {}, {}'
                     .format(lhs_rank, rhs_rank))
  batch_dims = tuple(range(lhs_rank - 2))
  contracting_dims = ((lhs_rank - 1,), (rhs_rank - 2,))
  dimension_numbers = (contracting_dims, (batch_dims, batch_dims))
  return dot_general(lhs, rhs, dimension_numbers, precision=precision)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# These functions also exist in the XLA client library, but we treat them
|
|
|
|
|
# as non-primitive to maintain a smaller set of autodiff primitives.
|
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def square(x: ArrayLike) -> Array:
  r"""Elementwise square: :math:`x^2`.

  Args:
    x: input array.

  Returns:
    The result of binding ``square_p`` to ``x``.
  """
  return square_p.bind(x)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def reciprocal(x: ArrayLike) -> Array:
  r"""Elementwise reciprocal: :math:`1 \over x`.

  Implemented as ``integer_pow(x, -1)``.

  Args:
    x: input array.
  """
  exponent = -1
  return integer_pow(x, exponent)
|
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
def tan(x: ArrayLike, accuracy=None) -> Array:
  r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.

  This function lowers directly to the `stablehlo.tangent`_ operation.

  Args:
    x: input array. Must have floating-point or complex type.
    accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object used to
      choose the implementation of the op according to the accuracy
      requested. The compiler returns an error when the implementation cannot
      meet the requested tolerance. When a mode is specified but only one
      implementation is available, the default implementation is used.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    tangent.

  See also:
    - :func:`jax.lax.cos`: elementwise cosine.
    - :func:`jax.lax.sin`: elementwise sine.
    - :func:`jax.lax.atan`: elementwise arc tangent.
    - :func:`jax.lax.atan2`: elementwise 2-term arc tangent.

  .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
  """
  return tan_p.bind(x, accuracy=accuracy)
|
2024-09-18 07:40:58 -07:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
def asin(x: ArrayLike) -> Array:
  r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.

  This function lowers directly to the ``chlo.asin`` operation.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array with the shape and dtype of ``x`` holding the element-wise arc
    sine.

  See also:
    - :func:`jax.lax.sin`: elementwise sine.
    - :func:`jax.lax.acos`: elementwise arc cosine.
    - :func:`jax.lax.atan`: elementwise arc tangent.
  """
  return asin_p.bind(x)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
def acos(x: ArrayLike) -> Array:
  r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.

  This function lowers directly to the ``chlo.acos`` operation.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array with the shape and dtype of ``x`` holding the element-wise arc
    cosine.

  See also:
    - :func:`jax.lax.cos`: elementwise cosine.
    - :func:`jax.lax.asin`: elementwise arc sine.
    - :func:`jax.lax.atan`: elementwise arc tangent.
  """
  return acos_p.bind(x)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-06 11:09:42 -08:00
|
|
|
|
@export
def atan(x: ArrayLike) -> Array:
  r"""Compute the arc tangent of each element: :math:`\mathrm{atan}(x)`.

  The computation is lowered directly to the ``chlo.atan`` operation.

  Args:
    x: the input array; its dtype must be floating-point or complex.

  Returns:
    An array with the shape and dtype of ``x`` holding the arc tangent of
    every element.

  See also:
    - :func:`jax.lax.tan`: elementwise tangent.
    - :func:`jax.lax.acos`: elementwise arc cosine.
    - :func:`jax.lax.asin`: elementwise arc sine.
    - :func:`jax.lax.atan2`: elementwise 2-term arc tangent.
  """
  return atan_p.bind(x)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
def sinh(x: ArrayLike) -> Array:
  r"""Compute the hyperbolic sine of each element: :math:`\mathrm{sinh}(x)`.

  The computation is lowered directly to the ``chlo.sinh`` operation.

  Args:
    x: the input array; its dtype must be floating-point or complex.

  Returns:
    An array with the shape and dtype of ``x`` holding the hyperbolic sine
    of every element.

  See also:
    - :func:`jax.lax.asinh`: elementwise inverse hyperbolic sine.
    - :func:`jax.lax.cosh`: elementwise hyperbolic cosine.
    - :func:`jax.lax.tanh`: elementwise hyperbolic tangent.
  """
  return sinh_p.bind(x)
|
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
def cosh(x: ArrayLike) -> Array:
  r"""Compute the hyperbolic cosine of each element: :math:`\mathrm{cosh}(x)`.

  The computation is lowered directly to the ``chlo.cosh`` operation.

  Args:
    x: the input array; its dtype must be floating-point or complex.

  Returns:
    An array with the shape and dtype of ``x`` holding the hyperbolic cosine
    of every element.

  See also:
    - :func:`jax.lax.acosh`: elementwise inverse hyperbolic cosine.
    - :func:`jax.lax.sinh`: elementwise hyperbolic sine.
    - :func:`jax.lax.tanh`: elementwise hyperbolic tangent.
  """
  return cosh_p.bind(x)
|
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
def asinh(x: ArrayLike) -> Array:
  r"""Compute the inverse hyperbolic sine of each element: :math:`\mathrm{asinh}(x)`.

  The computation is lowered directly to the ``chlo.asinh`` operation.

  Args:
    x: the input array; its dtype must be floating-point or complex.

  Returns:
    An array with the shape and dtype of ``x`` holding the inverse
    hyperbolic sine of every element.

  See also:
    - :func:`jax.lax.acosh`: elementwise inverse hyperbolic cosine.
    - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent.
    - :func:`jax.lax.sinh`: elementwise hyperbolic sine.
  """
  return asinh_p.bind(x)
|
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
def acosh(x: ArrayLike) -> Array:
  r"""Compute the inverse hyperbolic cosine of each element: :math:`\mathrm{acosh}(x)`.

  The computation is lowered directly to the ``chlo.acosh`` operation.

  Args:
    x: the input array; its dtype must be floating-point or complex.

  Returns:
    An array with the shape and dtype of ``x`` holding the inverse
    hyperbolic cosine of every element.

  See also:
    - :func:`jax.lax.asinh`: elementwise inverse hyperbolic sine.
    - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent.
    - :func:`jax.lax.cosh`: elementwise hyperbolic cosine.
  """
  return acosh_p.bind(x)
|
|
|
|
|
|
2025-02-14 08:40:19 -08:00
|
|
|
|
@export
def atanh(x: ArrayLike) -> Array:
  r"""Compute the inverse hyperbolic tangent of each element: :math:`\mathrm{atanh}(x)`.

  The computation is lowered directly to the ``chlo.atanh`` operation.

  Args:
    x: the input array; its dtype must be floating-point or complex.

  Returns:
    An array with the shape and dtype of ``x`` holding the inverse
    hyperbolic tangent of every element.

  See also:
    - :func:`jax.lax.acosh`: elementwise inverse hyperbolic cosine.
    - :func:`jax.lax.asinh`: elementwise inverse hyperbolic sine.
    - :func:`jax.lax.tanh`: elementwise hyperbolic tangent.
  """
  return atanh_p.bind(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Add some methods to ShapedArray that rely on lax primitives

# Each lax function is wrapped with core.aval_method before being attached
# as a ShapedArray attribute.
ShapedArray.broadcast = core.aval_method(broadcast)
ShapedArray.transpose = core.aval_method(transpose)  # clobbered by lax_numpy
ShapedArray.reshape = core.aval_method(reshape)  # clobbered by lax_numpy
|
|
|
|
|
|
|
|
|
|
def _iter(tracer):
|
|
|
|
|
if tracer.ndim == 0:
|
|
|
|
|
raise TypeError("iteration over a 0-d array") # same as numpy error
|
|
|
|
|
else:
|
|
|
|
|
n = int(tracer.shape[0])
|
2022-09-29 14:00:47 -07:00
|
|
|
|
if any(isinstance(d, core.Tracer) for d in tracer.shape):
|
|
|
|
|
return (slicing.dynamic_index_in_dim(tracer, i, keepdims=False)
|
|
|
|
|
for i in range(n))
|
|
|
|
|
else:
|
|
|
|
|
return (slicing.index_in_dim(tracer, i, keepdims=False) for i in range(n))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# Install _iter as a static method on both abstract array classes.
ShapedArray._iter = staticmethod(_iter)
core.DShapedArray._iter = staticmethod(_iter)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def zeros_like_array(x: ArrayLike) -> Array:
  """Return ``full_like(x, 0)``, i.e. ``x``'s like-array filled with zeros."""
  zeros = full_like(x, 0)
  return zeros
|
|
|
|
|
|
|
|
|
|
|
2023-12-22 16:01:00 -08:00
|
|
|
|
def _add_arrays(x, y):
  """Add two values, deferring to the dtype's own rule for extended dtypes.

  Args:
    x: first addend; its abstract value selects which path is taken.
    y: second addend.

  Returns:
    The result of the extended dtype's ``_rules.add`` when ``x`` has a
    ``ShapedArray`` aval with an extended dtype, otherwise ``add(x, y)``.
  """
  aval = core.get_aval(x)
  if (isinstance(aval, ShapedArray) and
      dtypes.issubdtype(aval.dtype, dtypes.extended)):
    # Dispatch on the dtype carried by x's abstract value: that is the only
    # dtype this function has in hand for the operands.
    return aval.dtype._rules.add(aval.dtype, x, y)  # pytype: disable=attribute-error
  return add(x, y)
|
|
|
|
|
|
2023-12-22 15:53:48 -08:00
|
|
|
|
# Register _add_arrays as the raw adder for every Python scalar type, every
# entry of array_types, and array.ArrayImpl.
for t in itertools.chain(
    dtypes.python_scalar_dtypes.keys(), array_types, [array.ArrayImpl]):
  ad_util.raw_jaxval_adders[t] = _add_arrays
|
2023-12-22 15:53:48 -08:00
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
### primitives
|
|
|
|
|
|
|
|
|
|
|
2023-06-14 18:30:52 -07:00
|
|
|
|
# Builds a rule that ignores all of its arguments and always returns the
# canonicalized form of the fixed `dtype`.
_fixed_dtype = \
    lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
# Rule returning the dtype NumPy produces for the absolute value of a
# zero-dimensional zeros array of `dtype`; extra keyword arguments are ignored.
_complex_basetype = lambda dtype, **kwargs: np.abs(np.zeros((), dtype)).dtype

# Rule that returns False regardless of its arguments.
_strip_weak_type = lambda *args, **_: False
|
2021-03-09 13:48:15 -08:00
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
  """Validate the operand dtype of a unary primitive and compute the result dtype.

  Args:
    result_dtype: callable invoked as ``result_dtype(aval.dtype, **kwargs)`` to
      produce the output dtype.
    accepted_dtypes: iterable of types; ``aval.dtype`` must be a subtype of at
      least one of them.
    name: primitive name, used in error messages.
    aval: abstract value of the operand.
    **kwargs: forwarded unchanged to ``result_dtype``.

  Returns:
    Whatever ``result_dtype`` returns for the operand dtype.

  Raises:
    TypeError: if the operand is a float0 array, or if its dtype is not a
      subtype of any accepted dtype.
  """
  if aval.dtype == dtypes.float0:
    # The suggested cast uses the builtin ``float``: the ``np.float`` alias no
    # longer exists in current NumPy releases.
    raise TypeError(
        f"Called {name} with a float0 array. "
        "float0s do not support any operations by design, because they "
        "are not compatible with non-trivial vector spaces. No implicit dtype "
        "conversion is done. You can use np.zeros_like(arr, dtype=float) "
        "to cast a float0 array to a regular zeros array. \n"
        "If you didn't expect to get a float0 you might have accidentally "
        "taken a gradient with respect to an integer argument.")
  if not any(dtypes.issubdtype(aval.dtype, t) for t in accepted_dtypes):
    msg = '{} does not accept dtype {}. Accepted dtypes are subtypes of {}.'
    typename = dtype_to_string(aval.dtype)
    accepted_typenames = (t.__name__ for t in accepted_dtypes)
    raise TypeError(msg.format(name, typename, ', '.join(accepted_typenames)))
  return result_dtype(aval.dtype, **kwargs)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
2022-04-18 08:28:08 -07:00
|
|
|
|
def unop(result_dtype, accepted_dtypes, name):
  """Create a unary primitive called ``name``.

  The primitive's shape, sharding and vma rules each read the matching
  attribute of the operand; its dtype rule is ``unop_dtype_rule`` specialised
  with ``result_dtype``, ``accepted_dtypes`` and ``name``.

  Returns:
    The new primitive, after it has been passed to ``batching.defvectorized``
    and ``pe.def_trivial_padding``.
  """
  check_dtype = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
  same_shape = _attrgetter('shape')
  same_sharding = _attrgetter('sharding')
  same_vma = _attrgetter('vma')
  primitive = standard_primitive(
      same_shape, check_dtype, name,
      sharding_rule=same_sharding, vma_rule=same_vma)
  batching.defvectorized(primitive)
  pe.def_trivial_padding(primitive)
  return primitive
|
2024-10-17 15:54:42 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# unop with the result-dtype rule fixed to _identity.
standard_unop = partial(unop, _identity)

# Builds a rule that returns attribute `name` of its first argument and
# ignores any keyword arguments.
_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
|
|
|
|
|
|
|
|
|
|
|
2023-04-26 13:12:04 -07:00
|
|
|
|
def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
                      require_same=True, allow_extended_dtype=False, **kwargs):
  """Validate operand dtypes of an n-ary primitive and compute the result dtype.

  Args:
    result_dtype: callable invoked as ``result_dtype(*avals, **kwargs)`` to
      produce the output dtype.
    accepted_dtypes: one collection of accepted types per operand position.
    name: primitive name, used in error messages.
    *avals: abstract values of the operands; must match ``accepted_dtypes`` in
      length.
    require_same: when true, ``check_same_dtypes(name, *avals)`` is also run.
    allow_extended_dtype: when true, operands whose dtype is an
      ``ExtendedDType`` instance skip the per-position check.
    **kwargs: forwarded unchanged to ``result_dtype``.

  Returns:
    Whatever ``result_dtype`` returns for the operands.

  Raises:
    TypeError: if an operand's dtype is not a subtype of any type accepted at
      its position (with a dedicated message for float0 operands).
  """
  assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes)
  for i, aval in enumerate(avals):
    if allow_extended_dtype and isinstance(aval.dtype, dtypes.ExtendedDType):
      continue
    types = accepted_dtypes[i]
    if not any(dtypes.issubdtype(aval.dtype, t) for t in types):
      if aval.dtype == dtypes.float0:
        # The suggested cast uses the builtin ``float``: the ``np.float`` alias
        # no longer exists in current NumPy releases.
        raise TypeError(
            f"Called {name} with a float0 at position {i}. "
            "float0s do not support any operations by design, because they "
            "are not compatible with non-trivial vector spaces. No implicit dtype "
            "conversion is done. You can use np.zeros_like(arr, dtype=float) "
            "to cast a float0 array to a regular zeros array. \n"
            "If you didn't expect to get a float0 you might have accidentally "
            "taken a gradient with respect to an integer argument.")
      else:
        msg = ('{} does not accept dtype {} at position {}. '
               'Accepted dtypes at position {} are subtypes of {}.')
        typename = dtype_to_string(aval.dtype)
        typenames = ', '.join(t.__name__ for t in types)
        raise TypeError(msg.format(name, typename, i, i, typenames))
  if require_same: check_same_dtypes(name, *avals)
  return result_dtype(*avals, **kwargs)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
2022-08-22 13:56:50 -07:00
|
|
|
|
def broadcasting_shape_rule(name, *avals):
  """Compute the broadcast result shape for the given operand avals.

  Operands with an empty (falsy) shape are left out. If nothing remains the
  result is ``()``; otherwise the remaining shapes are handed to
  ``_try_broadcast_shapes`` together with ``name``.
  """
  nonscalar_shapes = []
  for operand in avals:
    if operand.shape:
      nonscalar_shapes.append(operand.shape)
  if nonscalar_shapes:
    return _try_broadcast_shapes(*nonscalar_shapes, name=name)
  return ()
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
|
|
2024-08-29 10:49:30 -07:00
|
|
|
|
def broadcasting_sharding_rule(name, *avals):
  """Compute the output sharding of a broadcasting primitive.

  Args:
    name: primitive name, used in error messages.
    *avals: abstract values of the operands.

  Returns:
    A ``NamedSharding`` on the operands' common non-empty mesh (or on
    ``mesh_lib.get_abstract_mesh()`` when no operand has one), whose spec is
    resolved one dimension at a time from the operands' specs.

  Raises:
    core.ShardingTypeError: if operands carry different non-empty meshes, or
      if their specs disagree on a dimension in a way that cannot be resolved.
    TypeError: if the operands with non-empty shape differ in rank.
  """
  mesh = None
  for operand in avals:
    if operand.sharding is None or operand.sharding.mesh.empty:
      continue
    if mesh is not None and mesh != operand.sharding.mesh:
      raise core.ShardingTypeError(
          f'Mesh for all inputs should be equal. Got one mesh: {mesh} and'
          f' another mesh: {operand.sharding.mesh}')
    mesh = operand.sharding.mesh
  if mesh is None:
    mesh = mesh_lib.get_abstract_mesh()

  shapes = [operand.shape for operand in avals if operand.shape]
  if not shapes:
    return NamedSharding(mesh, P())
  ranks = {len(shape) for shape in shapes}
  if len(ranks) != 1:
    msg = '{}: arrays must have same number of dimensions, got {}.'
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))

  specs = [operand.sharding.spec for operand in avals if operand.shape]

  result_specs = [None] * len(shapes[0])
  per_dim = zip(zip(*specs), zip(*shapes))
  for i, (dim_specs, dim_sizes) in enumerate(per_dim):
    first = dim_specs[0]
    if all(s == first for s in dim_specs[1:]):
      # Every operand agrees on this dimension, so the result uses that
      # sharding as well.
      result_specs[i] = first
      continue
    # Drop the operands that are unsharded and of size 1 along this dimension.
    non_trivial = [s for s, d in zip(dim_specs, dim_sizes)
                   if not core.definitely_equal(d, 1) or s is not None]
    if not non_trivial:
      result_specs[i] = None
    elif all(non_trivial[0] == s for s in non_trivial[1:]):
      result_specs[i] = non_trivial[0]
    else:
      # Take the first non-None sharding; any later non-None sharding that
      # differs from it is an error.
      chosen = None
      for s in dim_specs:
        if s is None:
          continue
        if chosen is None:
          chosen = s
        elif chosen != s:
          raise core.ShardingTypeError(
              f'{name} got incompatible shardings for broadcasting: '
              f'{", ".join(map(str, map(tuple, specs)))}.')
      result_specs[i] = chosen
  return NamedSharding(mesh, P(*result_specs))
|
2024-08-29 10:49:30 -07:00
|
|
|
|
|
2023-06-14 18:30:52 -07:00
|
|
|
|
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
           require_same_dtypes=True):
  """Create an n-ary broadcasting primitive called ``name``.

  The dtype, shape and sharding rules are ``naryop_dtype_rule``,
  ``broadcasting_shape_rule`` and ``broadcasting_sharding_rule`` specialised
  for this primitive; the vma rule is ``core.standard_vma_rule`` bound to
  ``name``.

  Returns:
    The new primitive, after it has been passed to ``batching.defbroadcasting``
    and ``pe.def_trivial_padding``.
  """
  check_dtypes = partial(
      naryop_dtype_rule, result_dtype, accepted_dtypes, name,
      allow_extended_dtype=allow_extended_dtype,
      require_same=require_same_dtypes)
  out_shape = partial(broadcasting_shape_rule, name)
  out_sharding = partial(broadcasting_sharding_rule, name)
  out_vma = partial(core.standard_vma_rule, name)
  primitive = standard_primitive(
      out_shape, check_dtypes, name,
      sharding_rule=out_sharding, vma_rule=out_vma)
  batching.defbroadcasting(primitive)
  pe.def_trivial_padding(primitive)
  return primitive
|
|
|
|
|
# naryop with the result-dtype rule fixed to _input_dtype.
standard_naryop = partial(naryop, _input_dtype)
|
|
|
|
|
|
|
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
|
# Like autograd.numpy.numpy_vjps.unbroadcast, this utility handles transposition
# involving linear primitives with implicit broadcasting.
def _unbroadcast(aval, x):
  """Sum ``x`` down to the shape recorded in ``aval``.

  Args:
    aval: a ``ShapedArray`` or ``core.DShapedArray`` giving the target shape.
    x: the value to reduce.

  Returns:
    ``x`` unchanged when its shape already matches; otherwise ``x`` summed over
    all axes (scalar target) or over the axes whose sizes differ, then
    reshaped to ``aval.shape``.

  Raises:
    TypeError: if ``aval`` is not one of the shaped abstract value types.
  """
  if not isinstance(aval, (core.DShapedArray, ShapedArray)):
    raise TypeError("transpose with implicit broadcasting of unshaped values")
  x_shape = np.shape(x)
  target_shape = aval.shape
  if core.definitely_equal_shape(target_shape, x_shape):
    return x
  assert not target_shape or len(x_shape) == len(target_shape)
  if not target_shape:
    all_axes = list(range(len(x_shape)))
    return reduce_sum(x, all_axes)
  dims = []
  for axis, (x_dim, target_dim) in enumerate(zip(x_shape, target_shape)):
    if not core.definitely_equal(x_dim, target_dim):
      dims.append(axis)
  if config.enable_checks.value:
    assert all(target_shape[axis] == 1 for axis in dims)
  summed = reduce_sum(x, dims)
  return reshape(summed, target_shape)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
|
def _maybe_broadcast(target_shape, x):
  """Broadcast ``x`` to ``target_shape`` unless it already has that shape.

  A value with an empty shape is broadcast with no mapped dimensions.
  Otherwise ``x`` is first reshaped to keep only the dimensions whose size
  already equals the target's, and those dimensions are then mapped onto the
  target by ``broadcast_in_dim``.
  """
  x_shape = np.shape(x)
  if core.definitely_equal_shape(x_shape, target_shape):
    return x
  if not x_shape:
    return broadcast_in_dim(x, target_shape, ())
  dims = []
  squeeze_shape = []
  for axis, (x_dim, target_dim) in enumerate(zip(x_shape, target_shape)):
    if core.definitely_equal(x_dim, target_dim):
      dims.append(axis)
      squeeze_shape.append(x_shape[axis])
  squeezed = reshape(x, squeeze_shape)
  return broadcast_in_dim(squeezed, target_shape, dims)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
|
def broadcast_hlo(
    aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray],
    args: Sequence[ir.Value]) -> Sequence[ir.Value]:
  """Broadcasts HLO values with broadcast-compatible shapes to the same shape.

  Args:
    aval_out: abstract value carrying the common output shape.
    avals: abstract values of the inputs, parallel to ``args``.
    args: the IR values to broadcast.

  Returns:
    A list with one value per input: the input itself when its shape already
    equals ``aval_out.shape``, otherwise the input broadcast along the
    trailing dimensions of the output.
  """
  out_shape = aval_out.shape
  out_rank = len(out_shape)
  results = []
  for aval, arg in zip(avals, args):
    if aval.shape == out_shape:
      results.append(arg)
      continue
    in_rank = len(aval.shape)
    assert in_rank <= out_rank, (aval, aval_out)
    # The input's dimensions line up with the trailing output dimensions.
    dims = mlir.dense_int_array(list(range(out_rank - in_rank, out_rank)))
    if any(isinstance(d, ir.Value) for d in out_shape):
      out_type = mlir.aval_to_ir_type(aval_out)
      shape_operand = mlir.shape_tensor(out_shape)
      arg = hlo.dynamic_broadcast_in_dim(out_type, arg, shape_operand, dims)
    else:
      out_type = mlir.aval_to_ir_type(aval.update(shape=out_shape))
      arg = hlo.broadcast_in_dim(out_type, arg, dims)
    results.append(arg)
  return results
|
|
|
|
|
|
2024-08-29 10:49:30 -07:00
|
|
|
|
def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
  """Bring each op to the output sharding where its input sharding differs.

  An op is passed through untouched when its input aval's sharding equals
  ``out_aval.sharding`` or is ``None``; every other op goes through
  ``mlir.lower_with_sharding_in_types`` with ``out_aval``.

  Returns:
    A list with one entry per element of ``ops``, in the same order.
  """
  return [
      op
      if (in_aval.sharding == out_aval.sharding or in_aval.sharding is None)
      else mlir.lower_with_sharding_in_types(ctx, op, out_aval)
      for op, in_aval in zip(ops, in_avals)
  ]
|
|
|
|
|
|
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def _nary_lower_hlo(
    op: Callable, ctx, *args: ir.Value, accuracy=None, **params
) -> Sequence[ir.Value]:
  """Lowers an elementwise operator to its MLIR equivalent.

  Args:
    op: callable that builds the elementwise operation from the operand
      values.
    ctx: lowering rule context; ``ctx.avals_in`` and ``ctx.avals_out`` are
      read, and ``ctx.avals_out`` must hold exactly one aval.
    *args: operand IR values.
    accuracy: when truthy, it is converted with ``accuracy_attr`` and passed
      to ``op`` as ``result_accuracy``.
    **params: remaining primitive parameters; ignored.

  Returns:
    A single-element list holding the result after
    ``mlir.lower_with_sharding_in_types`` has been applied to it.
  """
  del params
  avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
  args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape)
  args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)

  # Build the operation exactly once: calling ``op`` a second time would
  # construct a second operation and leave the first one unused.
  if accuracy:
    out = op(*args, result_accuracy=accuracy_attr(accuracy))
  else:
    out = op(*args)
  return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# Sets of NumPy scalar types, passed as the `accepted_dtypes` argument when
# primitives are created with unop / naryop.
_float = {np.floating}
_complex = {np.complexfloating}
_complex_elem_types = {np.float32, np.float64}
_int = {np.integer}
_bool = {np.bool_}
_signedint = {np.signedinteger}

# Unions of the basic sets above.
_num = _int | _float | _complex
_any = _int | _float | _complex | _bool
_bool_or_int = _int | _bool
_ordered = _int | _float | _bool
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# neg: accepts any numeric dtype; registered as linear, with the transpose
# negating the cotangent, and lowered through hlo.negate.
neg_p = standard_unop(_num, 'neg')
ad.deflinear2(neg_p, lambda t, operand: [neg(t)])
mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.negate))

# sign: accepts any numeric dtype; its JVP is registered as zero.
sign_p = standard_unop(_num, 'sign')
ad.defjvp_zero(sign_p)
|
|
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
|
def _sign_lower_hlo(ctx, x):
  """Lower ``sign`` to HLO.

  Unsigned integer inputs are lowered to a select that yields 0 where ``x``
  equals 0 and 1 elsewhere; every other dtype uses ``hlo.sign`` directly.

  Returns:
    A single-element list holding the lowered value.
  """
  x_aval, = ctx.avals_in
  if not dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
    return [hlo.sign(x)]
  # The constants and the comparison are built in this fixed order.
  zero_to_compare = mlir.full_like_aval(ctx, 0, x_aval)
  is_zero = mlir.compare_hlo(x, zero_to_compare, 'EQ', 'UNSIGNED')
  zero = mlir.full_like_aval(ctx, 0, x_aval)
  one = mlir.full_like_aval(ctx, 1, x_aval)
  return [hlo.select(is_zero, zero, one)]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
|
mlir.register_lowering(sign_p, _sign_lower_hlo)

# nextafter: binary primitive over floating-point operands, lowered through
# chlo.next_after.
nextafter_p = standard_naryop([_float, _float], 'nextafter')
mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.next_after))

# floor: floating-point only; JVP registered as zero.
floor_p = standard_unop(_float, 'floor')
ad.defjvp_zero(floor_p)
mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.floor))

# ceil: floating-point only; JVP registered as zero.
ceil_p = standard_unop(_float, 'ceil')
ad.defjvp_zero(ceil_p)
mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.ceil))

# round: floating-point only; JVP registered as zero.
round_p = standard_unop(_float, 'round')
ad.defjvp_zero(round_p)
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _round_lower(ctx, x, *, rounding_method):
  """Lowers ``round_p`` to the HLO op matching ``rounding_method``.

  Args:
    ctx: the lowering rule context (unused here).
    x: the operand value to round.
    rounding_method: a ``RoundingMethod`` member selecting the HLO op.

  Returns:
    A one-element list holding the lowered result.
  """
  if rounding_method is not RoundingMethod.AWAY_FROM_ZERO:
    # Only one other rounding mode is handled by this rule.
    assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
    return [hlo.round_nearest_even(x)]
  return [hlo.round_nearest_afz(x)]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
mlir.register_lowering(round_p, _round_lower)

# is_finite: float operand, result dtype fixed to np.bool_; zero JVP.
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
ad.defjvp_zero(is_finite_p)
mlir.register_lowering(
    is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))

# exp: the tangent is g * ans, reusing the primal output.
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans))
mlir.register_lowering(
    exp_p, partial(_nary_lower_hlo, hlo.exponential))
batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule

# exp2: the tangent is log(2) * (g * ans).
exp2_p = standard_unop(_float | _complex, 'exp2')
ad.defjvp2(
    exp2_p,
    lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans)),
)
|
|
|
|
|
|
|
|
|
|
def _exp2_lower(ctx, x, accuracy):
  """Lowers ``exp2_p`` as ``exponential(log(2) * x)``.

  Args:
    ctx: the lowering rule context; its single input aval supplies the dtype
      and the broadcast target for the ``log(2)`` constant.
    x: the operand value.
    accuracy: forwarded through ``accuracy_attr`` as ``result_accuracy``.

  Returns:
    A one-element list holding the lowered result.
  """
  (operand_aval,) = ctx.avals_in
  # Scalar log(2) in the operand dtype, broadcast to the operand's aval.
  ln2_scalar = mlir.ir_constant(np.array(np.log(2), operand_aval.dtype))
  ln2 = mlir.broadcast_in_dim(
      ctx, ln2_scalar, operand_aval, broadcast_dimensions=())
  scaled = hlo.multiply(ln2, x)
  return [hlo.exponential(scaled, result_accuracy=accuracy_attr(accuracy))]
|
|
|
|
|
|
2023-07-28 09:58:25 -07:00
|
|
|
|
mlir.register_lowering(exp2_p, _exp2_lower)

# log: tangent is g / x.
log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x))
mlir.register_lowering(
    log_p, partial(_nary_lower_hlo, hlo.log))

# expm1: tangent is g * (ans + 1).
expm1_p = standard_unop(_float | _complex, 'expm1')
ad.defjvp2(
    expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans))))
mlir.register_lowering(
    expm1_p, partial(_nary_lower_hlo, hlo.exponential_minus_one))

# log1p: tangent is g / (x + 1).
log1p_p = standard_unop(_float | _complex, 'log1p')
ad.defjvp(
    log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x))))
mlir.register_lowering(
    log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one))

# tanh: tangent is (g + g * ans) * (1 - ans).
tanh_p = standard_unop(_float | _complex, 'tanh')
ad.defjvp2(
    tanh_p,
    lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)),
)
mlir.register_lowering(
    tanh_p, partial(_nary_lower_hlo, hlo.tanh))

# logistic: tangent is g * (ans * (1 - ans)).
logistic_p = standard_unop(_float | _complex, 'logistic')
ad.defjvp2(
    logistic_p,
    lambda g, ans, x, **kwargs: mul(g, mul(ans, sub(_one(ans), ans))),
)
# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic))
|
2022-09-07 06:06:22 -07:00
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
|
|
|
|
|
def logistic_impl(x, accuracy):
  """Computes the logistic function as ``1 / (1 + exp(-x))``.

  ``accuracy`` is accepted but not used by this implementation.
  """
  unit = _const(x, 1)
  denominator = add(unit, exp(neg(x)))
  return div(unit, denominator)


mlir.register_lowering(
    logistic_p, mlir.lower_fun(logistic_impl, multiple_results=False))
|
|
|
|
|
|
2024-02-15 13:29:35 +02:00
|
|
|
|
def _sin_complex(x):
  """Computes sin of a complex ``x`` from real-valued sin, cos and expm1.

  Uses the identities
    sin(x) = complex(sin(real(x)) * cosh(imag(x)), cos(real(x)) * sinh(imag(x)))
    2 * sinh(x) = exp(x) - 1 - (exp(-x) - 1) = expm1(x) - expm1(-x)
    2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
  expm1 is used instead of exp to avoid cancellation when abs(x) is small;
  the result relies on the quality of real-valued expm1, sin and cos.
  """
  re_x, im_x = real(x), imag(x)
  re_x_is_zero = eq(re_x, _const(re_x, 0))
  two = _const(re_x, 2)
  sin_re, cos_re = sin(re_x), cos(re_x)
  em1_pos, em1_neg = expm1(im_x), expm1(neg(im_x))
  sinh_im = div(sub(em1_pos, em1_neg), two)
  cosh_im = div(add(add(em1_pos, em1_neg), two), two)
  real_part = mul(sin_re, cosh_im)
  imag_part = mul(cos_re, sinh_im)
  # avoid nan value when real(x) is zero and abs(x) is so large that
  # abs(expm1(x)) is inf
  return select(re_x_is_zero,
                complex(_const(re_x, 0), imag_part),
                complex(real_part, imag_part))
|
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def _sin_lowering(ctx, x, accuracy):
  """Lowers ``sin_p``, using ``_sin_complex`` for complex inputs.

  For non-complex inputs the op is lowered to ``hlo.sine`` with ``accuracy``
  forwarded; on the complex path ``accuracy`` is not passed along.
  """
  operand_dtype = ctx.avals_in[0].dtype
  if not dtypes.issubdtype(operand_dtype, np.complexfloating):
    return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy)
  complex_rule = mlir.lower_fun(_sin_complex, multiple_results=False)
  return complex_rule(ctx, x)
|
2024-02-15 13:29:35 +02:00
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
|
|
|
|
|
def _sin_p_lin(nzs, x, accuracy):
  """Linearization rule registered for ``sin_p``.

  Returns a 4-tuple: the primal output ``sin_p.bind(x, accuracy=accuracy)``,
  the single entry of ``nzs``, ``cos(x)``, and a function of ``(cos_x, t)``
  that returns ``t * cos_x``.
  """
  (nz,) = nzs
  # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
  cos_x = cos(x)
  primal_out = sin_p.bind(x, accuracy=accuracy)

  def scale_tangent(cos_x_, t):
    return mul(t, cos_x_)

  return primal_out, nz, cos_x, scale_tangent
|
2024-11-21 17:46:21 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# sin: tangent is g * cos(x), with the accuracy parameter forwarded to cos.
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(
    sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy)))
ad.primitive_linearizations[sin_p] = _sin_p_lin
mlir.register_lowering(sin_p, _sin_lowering)
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
|
|
|
|
|
|
2024-02-15 13:29:35 +02:00
|
|
|
|
def _cos_complex(x):
  """Computes cos of a complex ``x`` from real-valued sin, cos and expm1.

  cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
  See also _sin_complex.
  """
  re_x, im_x = real(x), imag(x)
  re_x_is_zero = eq(re_x, _const(re_x, 0))
  two = _const(re_x, 2)
  sin_re, cos_re = sin(re_x), cos(re_x)
  em1_pos, em1_neg = expm1(im_x), expm1(neg(im_x))
  sinh_im = div(sub(em1_pos, em1_neg), two)
  cosh_im = div(add(add(em1_pos, em1_neg), two), two)
  real_part = mul(cos_re, cosh_im)
  imag_part = mul(neg(sin_re), sinh_im)
  # When real(x) is zero the imaginary part is replaced by an explicit zero.
  return select(re_x_is_zero,
                complex(real_part, _const(re_x, 0)),
                complex(real_part, imag_part))
|
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def _cos_lowering(ctx, x, accuracy):
  """Lowers ``cos_p``, using ``_cos_complex`` for complex inputs.

  For non-complex inputs the op is lowered to ``hlo.cosine`` with ``accuracy``
  forwarded; on the complex path ``accuracy`` is not passed along.
  """
  operand_dtype = ctx.avals_in[0].dtype
  if not dtypes.issubdtype(operand_dtype, np.complexfloating):
    return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy)
  complex_rule = mlir.lower_fun(_cos_complex, multiple_results=False)
  return complex_rule(ctx, x)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# cos: tangent is -(g * sin(x)), with the accuracy parameter forwarded to sin.
cos_p = standard_unop(_float | _complex, 'cos')
ad.defjvp(
    cos_p,
    lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))),
)
mlir.register_lowering(cos_p, _cos_lowering)

# tan: tangent is g * (1 + ans**2).
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(
    tan_p,
    lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans))),
)
mlir.register_lowering(
    tan_p, partial(_nary_lower_hlo, hlo.tan))

# asin: tangent is g * rsqrt(1 - x**2).
asin_p = standard_unop(_float | _complex, 'asin')
ad.defjvp(
    asin_p,
    lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x)))),
)
mlir.register_lowering(
    asin_p, partial(_nary_lower_hlo, chlo.asin))

# acos: tangent is g * -rsqrt(1 - x**2).
acos_p = standard_unop(_float | _complex, 'acos')
ad.defjvp(
    acos_p,
    lambda g, x: mul(g, neg(rsqrt(sub(_const(x, 1), square(x))))),
)
mlir.register_lowering(
    acos_p, partial(_nary_lower_hlo, chlo.acos))
|
2020-12-07 17:34:27 -05:00
|
|
|
|
|
2022-04-06 12:53:19 -07:00
|
|
|
|
def atan_impl(x):
  """Computes the arctangent of ``x`` as ``atan2(x, 1)``."""
  unit = _const(x, 1)
  return atan2(x, unit)
|
2020-12-07 17:34:27 -05:00
|
|
|
|
|
2022-04-18 08:28:08 -07:00
|
|
|
|
# atan: tangent is g / (1 + x**2).
atan_p = standard_unop(_float | _complex, 'atan')
ad.defjvp(
    atan_p, lambda g, x: div(g, add(_const(x, 1), square(x))))
mlir.register_lowering(
    atan_p, partial(_nary_lower_hlo, chlo.atan))

# atan2: one JVP rule per operand, both dividing by x**2 + y**2.
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(
    atan2_p,
    lambda g, x, y: mul(g, div(y, add(square(x), square(y)))),
    lambda g, x, y: mul(g, div(neg(x), add(square(x), square(y)))),
)
mlir.register_lowering(
    atan2_p, partial(_nary_lower_hlo, hlo.atan2))

# sinh: tangent is g * cosh(x).
sinh_p = standard_unop(_float | _complex, 'sinh')
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
mlir.register_lowering(
    sinh_p, partial(_nary_lower_hlo, chlo.sinh))

# cosh: tangent is g * sinh(x).
cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
mlir.register_lowering(
    cosh_p, partial(_nary_lower_hlo, chlo.cosh))

# asinh: tangent is g * rsqrt(x**2 + 1).
asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(
    asinh_p, lambda g, x: mul(g, rsqrt(add(square(x), _one(x)))))
mlir.register_lowering(
    asinh_p, partial(_nary_lower_hlo, chlo.asinh))

# acosh: tangent is g * rsqrt((x - 1) * (x + 1)).
acosh_p = standard_unop(_float | _complex, 'acosh')
ad.defjvp(
    acosh_p,
    lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x))))),
)
mlir.register_lowering(
    acosh_p, partial(_nary_lower_hlo, chlo.acosh))

# atanh: tangent is (1 / (1 + x)) * (g / (1 - x)).
atanh_p = standard_unop(_float | _complex, 'atanh')
ad.defjvp(
    atanh_p,
    lambda g, x: mul(reciprocal(add(_one(x), x)), div(g, sub(_one(x), x))),
)
mlir.register_lowering(
    atanh_p, partial(_nary_lower_hlo, chlo.atanh))

# real: linear; the transpose builds complex(t, 0).
real_p = unop(_complex_basetype, _complex, 'real')
ad.deflinear2(
    real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
mlir.register_lowering(
    real_p, partial(_nary_lower_hlo, hlo.real))

# imag: linear; the transpose builds complex(0, -t).
imag_p = unop(_complex_basetype, _complex, 'imag')
ad.deflinear2(
    imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])
mlir.register_lowering(
    imag_p, partial(_nary_lower_hlo, hlo.imag))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
|
|
|
|
|
|
def _complex_transpose_rule(t, x, y):
  """Transpose rule for ``complex_p``.

  At least one of ``x`` and ``y`` must be an undefined primal. Each undefined
  operand receives a cotangent: a symbolic ``Zero`` when ``t`` is a ``Zero``,
  otherwise ``real(t)`` for ``x`` and ``imag(neg(t))`` for ``y``, unbroadcast
  to the operand's aval. A defined operand gets ``None``.

  Returns:
    A two-element list ``[ct_x, ct_y]``.
  """
  x_is_linear = ad.is_undefined_primal(x)
  y_is_linear = ad.is_undefined_primal(y)
  assert x_is_linear or y_is_linear
  t_is_zero = type(t) is ad_util.Zero

  if not x_is_linear:
    ct_x = None
  elif t_is_zero:
    ct_x = ad_util.Zero(x.aval)
  else:
    ct_x = _unbroadcast(x.aval, real(t))

  if not y_is_linear:
    ct_y = None
  elif t_is_zero:
    ct_y = ad_util.Zero(y.aval)
  else:
    ct_y = _unbroadcast(y.aval, imag(neg(t)))

  return [ct_x, ct_y]
|
|
|
|
|
|
2025-02-21 04:50:35 -08:00
|
|
|
|
def _complex_dtype(dtype, *args, **kwargs):
  """Returns the dtype of adding zeros of ``dtype`` and ``np.complex64``.

  Extra positional and keyword arguments are accepted and ignored.
  """
  return (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype


complex_p = naryop(
    _complex_dtype, [_complex_elem_types, _complex_elem_types], 'complex')
ad.deflinear2(complex_p, _complex_transpose_rule)
mlir.register_lowering(
    complex_p, partial(_nary_lower_hlo, hlo.complex))

conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
|
|
|
|
|
|
2021-11-30 06:08:26 -08:00
|
|
|
|
def _conj_impl(x, **kw):
  """Implements conj in terms of ``complex``, ``real`` and ``imag``.

  A complex ``x`` has its imaginary part negated; any other ``x`` is paired
  with a zero imaginary part. Keyword arguments are accepted and ignored.
  """
  if not dtypes.issubdtype(x.dtype, np.complexfloating):
    return complex(x, _zeros(x))
  return complex(real(x), -imag(x))


mlir.register_lowering(
    conj_p, mlir.lower_fun(_conj_impl, multiple_results=False))
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _conj_transpose_rule(t, x, *, input_dtype):
  """Transpose rule for ``conj_p``.

  Returns a one-element list: a symbolic ``Zero`` when ``t`` is a ``Zero``,
  ``conj(t)`` when ``input_dtype`` is complex, and ``real(t)`` otherwise.
  """
  assert ad.is_undefined_primal(x)
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(x.aval)]
  if dtypes.issubdtype(input_dtype, np.complexfloating):
    return [conj(t)]
  return [real(t)]
|
|
|
|
|
|
|
|
|
|
# conj uses the generic linear JVP together with its own transpose rule.
ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
ad.primitive_transposes[conj_p] = _conj_transpose_rule

abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs')
mlir.register_lowering(
    abs_p, partial(_nary_lower_hlo, hlo.abs))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _abs_jvp_rule(g, ans, x):
  """JVP rule for ``abs_p``.

  For complex ``x`` the tangent is ``real(g * conj(x) / ans)``, where ``ans``
  is converted to the dtype of ``x`` and zeros in it are replaced by ones
  before dividing. Otherwise the tangent is ``g`` where ``x >= 0`` and
  ``-g`` elsewhere.
  """
  if not _iscomplex(x):
    return select(ge(x, _zero(x)), g, neg(g))
  conj_x = _maybe_conj(x)
  safe_abs = _replace_zero(convert_element_type(ans, _dtype(x)))
  return _maybe_real(mul(g, div(conj_x, safe_abs)))
|
|
|
|
|
ad.defjvp2(abs_p, _abs_jvp_rule)


def _maybe_conj(x):
  """Returns ``conj(x)`` when ``x`` is complex, otherwise ``x`` unchanged."""
  return conj(x) if _iscomplex(x) else x


def _maybe_real(x):
  """Returns ``real(x)`` when ``x`` is complex, otherwise ``x`` unchanged."""
  return real(x) if _iscomplex(x) else x


# sqrt: tangent is g * (0.5 / ans).
sqrt_p = standard_unop(_float | _complex, 'sqrt')
ad.defjvp2(
    sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans)))
mlir.register_lowering(
    sqrt_p, partial(_nary_lower_hlo, hlo.sqrt))

# rsqrt: tangent is g * (-0.5 * (ans / x)).
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
ad.defjvp2(
    rsqrt_p,
    lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))),
)
mlir.register_lowering(
    rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt))

# cbrt: tangent is g * ((1 / 3) * ans**-2).
cbrt_p = standard_unop(_float, 'cbrt')
ad.defjvp2(
    cbrt_p,
    lambda g, ans, x, **kwargs: mul(
        g, mul(_const(x, 1 / 3), integer_pow(ans, -2))
    ),
)
mlir.register_lowering(
    cbrt_p, partial(_nary_lower_hlo, hlo.cbrt))

square_p = standard_unop(_int | _float | _complex, 'square')
|
|
|
|
|
|
|
|
|
|
def _square_complex(x):
  """Squares a complex ``x`` from its real and imaginary parts.

  The real part is computed as ``(a - b) * (a + b)``, which is equivalent to
  a**2 - b**2 but avoids overflow errors for large a and large b cases.
  """
  re_x, im_x = real(x), imag(x)
  # zero square(x).real is handled explicitly for abs(a)==abs(b) cases
  # where for finite a, 2 * a is non-finite:
  real_is_zero = is_finite(re_x) & (eq(re_x, im_x) | eq(re_x, neg(im_x)))
  real_part = mul(sub(re_x, im_x), add(re_x, im_x))
  imag_part = mul(mul(re_x, im_x), _const(re_x, 2))
  return select(real_is_zero,
                complex(_const(re_x, 0), imag_part),
                complex(real_part, imag_part))
|
|
|
|
|
|
|
|
|
|
def _square_lower_hlo(ctx, x):
  """Lowers ``square_p``.

  Complex inputs are lowered through ``_square_complex``; every other dtype
  becomes a single ``hlo.multiply(x, x)``.
  """
  operand_dtype = ctx.avals_in[0].dtype
  if not dtypes.issubdtype(operand_dtype, np.complexfloating):
    return [hlo.multiply(x, x)]
  complex_rule = mlir.lower_fun(_square_complex, multiple_results=False)
  return complex_rule(ctx, x)


# square: tangent is g * (2 * x).
ad.defjvp2(square_p, lambda g, ans, x: mul(g, mul(_const(x, 2), x)))
mlir.register_lowering(square_p, _square_lower_hlo)  # TODO(pearu): use chlo.square
|
|
|
|
|
|
2023-06-14 18:30:52 -07:00
|
|
|
|
def _pow_dtype_rule(x, y):
  """Returns the result dtype of ``pow(x, y)``.

  The dtype of ``x`` is returned when ``x`` is inexact and ``y`` is an
  integer, or when both operands share a dtype.

  Raises:
    TypeError: for any other combination of operand dtypes.
  """
  inexact_base_int_exponent = (
      dtypes.issubdtype(x.dtype, np.inexact)
      and dtypes.issubdtype(y.dtype, np.integer))
  if inexact_base_int_exponent or x.dtype == y.dtype:
    return x.dtype
  raise TypeError("the first argument to pow must have an inexact dtype (float "
                  "or complex), and the second argument must have an inexact or"
                  " integer dtype, and two inexact dtypes must match, but got "
                  f"{x.dtype} and {y.dtype} respectively.")


pow_p = naryop(
    _pow_dtype_rule,
    [_float | _complex, _int | _float | _complex],
    'pow',
    require_same_dtypes=False,
)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _pow_jvp_lhs(g, ans, x, y):
  """JVP of ``pow(x, y)`` with respect to the base ``x``.

  Both operands are first converted to their common result dtype (``'int32'``
  is substituted when that dtype is bool). The Jacobian is
  ``y * x ** (y - 1)``; when the original exponent dtype is an integer, the
  operands are broadcast to a common shape and the Jacobian is set to zero
  wherever ``y == 0``.
  """
  exponent_dtype = dtypes.dtype(y)
  common_dtype = dtypes.result_type(x, y)
  if common_dtype == bool:
    common_dtype = 'int32'
  x = convert_element_type(x, common_dtype)
  y = convert_element_type(y, common_dtype)

  if not dtypes.issubdtype(exponent_dtype, np.integer):
    return mul(g, mul(y, pow(x, sub(y, _ones(y)))))

  if x.shape != y.shape:
    common_shape = broadcast_shapes(x.shape, y.shape)
    x = _maybe_broadcast(common_shape, x)
    y = _maybe_broadcast(common_shape, y)
  jacobian = select(
      eq(y, _const(y, 0)),
      _zeros(y),
      mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
  return mul(g, jacobian)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _pow_jvp_rhs(g, ans, x, y):
  """JVP of ``pow(x, y)`` with respect to the exponent ``y``.

  Computes ``g * (log(x) * ans)``, with zeros in ``x`` replaced by ones before
  taking the log, and converts the result to the dtype of ``y``. The exponent
  dtype must be inexact.
  """
  exponent_dtype = dtypes.dtype(y)
  assert dtypes.issubdtype(exponent_dtype, np.inexact)
  log_base = log(_replace_zero(x))
  tangent = mul(g, mul(log_base, ans))
  return convert_element_type(tangent, exponent_dtype)


ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
2023-06-14 18:30:52 -07:00
|
|
|
|
|
|
|
|
|
def _pow_lower(ctx, x, y):
  """Lowers ``pow_p`` to ``hlo.power``.

  When the operand dtypes differ, ``y`` is first converted to the output
  dtype and the context's input avals are updated to match.
  """
  base_aval, exponent_aval = ctx.avals_in
  if base_aval.dtype == exponent_aval.dtype:
    return _nary_lower_hlo(hlo.power, ctx, x, y)

  (result_aval,) = ctx.avals_out
  exponent_aval = exponent_aval.update(dtype=result_aval.dtype)
  y = hlo.convert(mlir.aval_to_ir_type(exponent_aval), y)
  ctx = ctx.replace(avals_in=[base_aval, exponent_aval])
  return _nary_lower_hlo(hlo.power, ctx, x, y)


mlir.register_lowering(pow_p, _pow_lower)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _integer_pow_dtype_rule(x, *, y):
  """Returns the result dtype of ``integer_pow(x, y)``.

  Raises:
    TypeError: if ``y`` is negative and the resulting dtype is an integer.
  """
  dtype = unop_dtype_rule(_identity, _int | _float | _complex, 'integer_pow', x)
  negative_power_of_int = y < 0 and dtypes.issubdtype(dtype, np.integer)
  if not negative_power_of_int:
    return dtype
  raise TypeError("Integers cannot be raised to negative powers, got "
                  f"integer_pow({x}, {y})")
|
|
|
|
|
|
|
|
|
|
def _integer_pow_jvp(g, x, *, y):
  """JVP rule for ``integer_pow_p``: zeros for ``y == 0``, else ``g * y * x**(y-1)``."""
  if y == 0:
    return _zeros(g)
  return mul(g, mul(_const(x, y), integer_pow(x, y - 1)))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# integer_pow: shape, sharding and vma rules each read the attribute of the
# same name via _attrgetter.
integer_pow_p = standard_primitive(
    _attrgetter('shape'),
    _integer_pow_dtype_rule,
    'integer_pow',
    sharding_rule=_attrgetter('sharding'),
    vma_rule=_attrgetter('vma'),
)
batching.defvectorized(integer_pow_p)
ad.defjvp(integer_pow_p, _integer_pow_jvp)
pe.def_trivial_padding(integer_pow_p)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
|
def _integer_pow(x, *, y):
  """Raises ``x`` to the integer power ``y`` by repeated squaring.

  ``y == 0`` yields ones shaped like ``x``. For negative ``y`` the positive
  power is computed first and then inverted by dividing ones by it.
  """
  # This should be kept in sync with the jax2tf translation rule.
  if y == 0:
    return full_like(x, 1)

  invert = y < 0
  remaining = -y if invert else y
  base = x
  result = None
  while remaining > 0:
    if remaining & 1:
      result = base if result is None else mul(result, base)
    remaining >>= 1
    if remaining > 0:
      # We don't call square because it calls integer_pow.
      base = mul(base, base)

  if invert:
    return div(full_like(result, 1), result)
  return result
|
|
|
|
|
|
2022-02-02 10:59:56 -05:00
|
|
|
|
|
|
|
|
|
def _integer_pow_lowering(ctx, x, *, y):
  """Lowers ``integer_pow_p``.

  The exponents 1, 2, 3 and -1 are emitted directly; every other exponent is
  lowered through ``_integer_pow``, with the lowering cached when
  ``abs(y) >= 3``. The result is passed through
  ``mlir.lower_with_sharding_in_types`` with the output aval.
  """
  # These cases are subsumed by the general case, but it's faster to emit these
  # common cases directly.
  if y == 1:
    out = x
  elif y == 2:
    out = hlo.multiply(x, x)
  elif y == 3:
    x_squared = hlo.multiply(x, x)
    out = hlo.multiply(x_squared, x)
  elif y == -1:
    ones = mlir.full_like_aval(ctx, 1, ctx.avals_in[0])
    out = hlo.divide(ones, x)
  else:
    general_rule = mlir.lower_fun(_integer_pow, multiple_results=False)
    if builtins.abs(y) >= 3:
      general_rule = mlir.cache_lowering(general_rule)
    (out,) = general_rule(ctx, x, y=y)

  (aval_out,) = ctx.avals_out
  return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
|
2022-02-02 10:59:56 -05:00
|
|
|
|
|
|
|
|
|
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)


def _replace_zero(x):
  """Returns ``x`` with every element equal to zero replaced by one."""
  return select(eq(x, _const(x, 0)), _ones(x), x)


# Bitwise/logical primitives over bool-or-int operands; each has a zero JVP.
not_p = standard_unop(_bool_or_int, 'not')
ad.defjvp_zero(not_p)
mlir.register_lowering(
    not_p, partial(_nary_lower_hlo, hlo.not_))

and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
ad.defjvp_zero(and_p)
mlir.register_lowering(
    and_p, partial(_nary_lower_hlo, hlo.and_))

or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
ad.defjvp_zero(or_p)
mlir.register_lowering(
    or_p, partial(_nary_lower_hlo, hlo.or_))

xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
ad.defjvp_zero(xor_p)
mlir.register_lowering(
    xor_p, partial(_nary_lower_hlo, hlo.xor))

# Integer bit-counting primitives.
population_count_p = standard_unop(_int, 'population_count')
mlir.register_lowering(
    population_count_p, partial(_nary_lower_hlo, hlo.popcnt))

clz_p = standard_unop(_int, 'clz')
mlir.register_lowering(
    clz_p, partial(_nary_lower_hlo, hlo.count_leading_zeros))
|
2021-03-19 22:35:31 -07:00
|
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
|
def _add_jvp(primals, tangents):
  """JVP rule for ``add_p``.

  Returns ``(x + y, tangent)``. The tangent is a symbolic ``Zero`` when both
  input tangents are ``Zero``, the other tangent broadcast to the output
  shape when only one is ``Zero``, and ``xdot + ydot`` otherwise.
  """
  x, y = primals
  xdot, ydot = tangents
  primal_out = add(x, y)
  xdot_is_zero = type(xdot) is ad_util.Zero
  ydot_is_zero = type(ydot) is ad_util.Zero

  if xdot_is_zero and ydot_is_zero:
    return primal_out, ad_util.Zero.from_primal_value(primal_out)
  if xdot_is_zero:
    return primal_out, _maybe_broadcast(primal_out.shape, ydot)
  if ydot_is_zero:
    return primal_out, _maybe_broadcast(primal_out.shape, xdot)
  return primal_out, add(xdot, ydot)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _add_transpose(t, x, y):
  """Transpose rule for ``add_p``.

  Returns a cotangent for each operand: symbolic ``Zero``s when ``t`` is a
  ``Zero``, otherwise ``t`` unbroadcast to each operand's aval.
  """
  # Morally the following assertion is true, but because we instantiate zeros in
  # some places (e.g. in custom_jvp) it may not always hold. For example, see
  # api_test.py's CustomJVPTest.test_jaxpr_zeros.
  # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
  x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
  y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
  if type(t) is not ad_util.Zero:
    return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)]
  return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-07-26 13:44:57 +01:00
|
|
|
|
# TODO(slebedev): Why does mypy fail to infer the type here?
add_p: Primitive = standard_naryop([_num, _num], 'add')
# add registers its JVP and transpose rules directly in the ad tables.
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(
    add_p, partial(_nary_lower_hlo, hlo.add))
batching.ragged_prop_rules[add_p] = batching.ragged_mask_elementwise_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
|
def _sub_jvp(primals, tangents):
  """JVP rule for ``sub_p``.

  Returns ``(x - y, tangent)``. The tangent is a symbolic ``Zero`` when both
  input tangents are ``Zero``, ``-ydot`` or ``xdot`` broadcast to the output
  shape when only one is ``Zero``, and ``xdot - ydot`` otherwise.
  """
  x, y = primals
  xdot, ydot = tangents
  primal_out = sub(x, y)
  xdot_is_zero = type(xdot) is ad_util.Zero
  ydot_is_zero = type(ydot) is ad_util.Zero

  if xdot_is_zero and ydot_is_zero:
    return primal_out, ad_util.Zero.from_primal_value(primal_out)
  if xdot_is_zero:
    return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot))
  if ydot_is_zero:
    return primal_out, _maybe_broadcast(primal_out.shape, xdot)
  return primal_out, sub(xdot, ydot)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _sub_transpose(t, x, y):
  """Transpose rule for ``sub_p``.

  Returns a cotangent for each operand: symbolic ``Zero``s when ``t`` is a
  ``Zero``, otherwise ``t`` for ``x`` and ``-t`` for ``y``, each unbroadcast
  to the operand's aval.
  """
  # Morally the following assertion is true, but see the comment in add_p's
  # transpose rule.
  # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
  x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
  y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
  if type(t) is not ad_util.Zero:
    return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, neg(t))]
  return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# sub registers its JVP and transpose rules directly in the ad tables.
sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(
    sub_p, partial(_nary_lower_hlo, hlo.subtract))
batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
|
|
|
|
|
|
def _mul_transpose(ct, x, y):
  """Transpose rule for ``mul_p``.

  Exactly one of ``x`` and ``y`` must be an undefined primal. That operand
  receives ``ct`` multiplied by the other operand, unbroadcast to its aval
  (or a symbolic ``Zero`` when ``ct`` is a ``Zero``); the other slot is
  ``None``.
  """
  x_is_linear = ad.is_undefined_primal(x)
  assert x_is_linear ^ ad.is_undefined_primal(y)
  ct_is_zero = type(ct) is ad_util.Zero

  if x_is_linear:
    if ct_is_zero:
      return [ad_util.Zero(x.aval), None]
    return [_unbroadcast(x.aval, mul(ct, y)), None]

  if ct_is_zero:
    return [None, ad_util.Zero(y.aval)]
  return [None, _unbroadcast(y.aval, mul(x, ct))]
|
|
|
|
|
|
2024-06-17 05:00:18 -07:00
|
|
|
|
# mul: one JVP rule per operand, each multiplying the tangent by the other
# operand.
mul_p = standard_naryop([_num, _num], 'mul')
ad.defjvp(
    mul_p,
    lambda xdot, x, y: mul(xdot, y),
    lambda ydot, x, y: mul(x, ydot),
)
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(
    mul_p, partial(_nary_lower_hlo, hlo.multiply))
batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _div_transpose_rule(cotangent, x, y):
  """Transpose rule for ``div_p`` with respect to the numerator ``x``.

  ``x`` must be an undefined primal and ``y`` a defined one. Returns
  ``[ct_x, None]`` where ``ct_x`` is a symbolic ``Zero`` when ``cotangent``
  is a ``Zero``, and ``cotangent / y`` unbroadcast to ``x``'s aval otherwise.
  """
  assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
  if type(cotangent) is not ad_util.Zero:
    return [_unbroadcast(x.aval, div(cotangent, y)), None]
  return [ad_util.Zero(x.aval), None]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# div: tangents are g / y for the numerator and (-g * x) * y**-2 for the
# denominator.
div_p = standard_naryop([_num, _num], 'div')
ad.defjvp(
    div_p,
    lambda g, x, y: div(g, y),
    lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)),
)
ad.primitive_transposes[div_p] = _div_transpose_rule
mlir.register_lowering(
    div_p, partial(_nary_lower_hlo, hlo.divide))
batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule

# rem: the first rule broadcasts g to the common operand shape; the second is
# -g * (sign(x / y) * floor(abs(x / y))).
rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
ad.defjvp(
    rem_p,
    lambda g, x, y: _maybe_broadcast(
        broadcast_shapes(np.shape(x), np.shape(y)), g),
    lambda g, x, y: mul(
        neg(g), mul(sign(div(x, y)), floor(abs(div(x, y))))),
)
mlir.register_lowering(
    rem_p, partial(_nary_lower_hlo, hlo.remainder))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-06-04 11:02:50 +03:00
|
|
|
|
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
  """Select elementwise between complex `x` and `y` by (real, imag) ordering.

  Both operands are first broadcast to their common shape. The real parts are
  compared with `lax_cmp_pick_x`; where the real parts are equal, the
  imaginary parts are compared instead.

  Args:
    x: first operand.
    y: second operand.
    lax_cmp_pick_x: binary comparison returning a boolean mask that is true
      where the element of `x` should be chosen.

  Returns:
    `select(pick_x, x, y)` on the broadcast operands.
  """
  result_shape = broadcast_shapes(np.shape(x), np.shape(y))
  x = _maybe_broadcast(result_shape, x)
  y = _maybe_broadcast(result_shape, y)
  rx = real(x)
  ry = real(y)
  # Ties on the real part are broken by the imaginary part.
  pick_x = select(eq(rx, ry), lax_cmp_pick_x(imag(x), imag(y)),
                  lax_cmp_pick_x(rx, ry))
  return select(pick_x, x, y)
|
|
|
|
|
|
2022-04-18 08:28:08 -07:00
|
|
|
|
# Elementwise maximum. The JVP multiplies each operand's tangent by
# `_balanced_eq` of that operand, the answer, and the other operand.
max_p: core.Primitive = standard_naryop([_any, _any], 'max')
ad.defjvp2(max_p,
           lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
           lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo))
batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule

# Elementwise minimum; same JVP structure as `max_p`.
min_p: core.Primitive = standard_naryop([_any, _any], 'min')
ad.defjvp2(min_p,
           lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
           lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule

# Integer shift primitives. Each is registered with a zero JVP.
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)
mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.shift_left))

shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
ad.defjvp_zero(shift_right_arithmetic_p)
mlir.register_lowering(shift_right_arithmetic_p,
                       partial(_nary_lower_hlo, hlo.shift_right_arithmetic))

shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
ad.defjvp_zero(shift_right_logical_p)
mlir.register_lowering(shift_right_logical_p,
                       partial(_nary_lower_hlo, hlo.shift_right_logical))
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2023-05-12 16:48:26 -07:00
|
|
|
|
def _opaque_comparison_hlo(direction, reduction_op, identity, ctx,
                           avals_in, aval_out, x, y):
  """Lower a comparison of extended-dtype operands via their physical avals.

  The operands are compared elementwise using their physical avals, and the
  resulting array is then reduced with `reduction_op` over the trailing axes
  that the physical aval has beyond `aval_out`.

  Args:
    direction: comparison direction forwarded to `_compare_lower_hlo`.
    reduction_op: op used to combine the per-element comparison results.
    identity: identity argument forwarded to `_unary_reduce_lower`.
    ctx: lowering context passed to `mlir.delegate_lowering`.
    avals_in: pair of input avals `(aval_x, aval_y)`.
    aval_out: aval of the comparison result.
    x: first lowered operand.
    y: second lowered operand.

  Returns:
    The result of the delegated reduction lowering.
  """
  aval_x, aval_y = avals_in
  base_aval_x = core.physical_aval(aval_x)
  base_aval_y = core.physical_aval(aval_y)
  base_aval_out = core.ShapedArray(base_aval_x.shape, aval_out.dtype)
  # Axes present only in the physical representation are reduced away.
  reduce_axes = tuple(range(aval_out.ndim, base_aval_out.ndim))
  res, = mlir.delegate_lowering(
      ctx, partial(_compare_lower_hlo, direction, False),
      x, y, avals_in=[base_aval_x, base_aval_y], avals_out=[base_aval_out])
  return mlir.delegate_lowering(
      ctx, partial(_unary_reduce_lower, reduction_op, identity,
                   axes=reduce_axes),
      res, avals_in=[base_aval_out], avals_out=[aval_out])
|
|
|
|
|
|
|
|
|
|
# Equality on extended dtypes: per-element 'EQ' results combined with AND.
_opaque_eq_hlo = partial(
    _opaque_comparison_hlo, 'EQ', hlo.AndOp, _get_bitwise_and_identity)
# Inequality on extended dtypes: per-element 'NE' results combined with OR.
_opaque_ne_hlo = partial(
    _opaque_comparison_hlo, 'NE', hlo.OrOp, _get_bitwise_or_identity)
|
|
|
|
|
|
|
|
|
|
def _compare_lower_hlo_opaque(direction: str, ctx, avals_in, aval_out, x, y):
  """Dispatch an extended-dtype comparison to the EQ or NE lowering.

  The input avals are rebuilt with `aval_out.shape` and their original dtypes
  before being handed to the opaque lowering.

  Args:
    direction: comparison direction; only 'EQ' and 'NE' are supported.
    ctx: lowering context.
    avals_in: input avals.
    aval_out: output aval.
    x: first lowered operand.
    y: second lowered operand.

  Raises:
    NotImplementedError: if `direction` is neither 'EQ' nor 'NE'.
  """
  broadcast_avals_in = tuple(
      core.ShapedArray(aval_out.shape, aval.dtype) for aval in avals_in)
  if direction == 'EQ':
    return _opaque_eq_hlo(ctx, broadcast_avals_in, aval_out, x, y)
  elif direction == 'NE':
    return _opaque_ne_hlo(ctx, broadcast_avals_in, aval_out, x, y)
  else:
    raise NotImplementedError(
        f"HLO comparison {direction} for extended dtype {avals_in[0].dtype}")
|
2023-05-12 16:48:26 -07:00
|
|
|
|
|
2023-10-13 12:20:22 -07:00
|
|
|
|
|
|
|
|
|
def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
  """Lowering rule shared by the comparison primitives.

  Args:
    direction: comparison direction string passed to `mlir.compare_hlo`
      (callers in this file use "EQ", "NE", "GE", "GT", "LE", "LT").
    total_order: if true, inexact dtypes are compared with the "TOTALORDER"
      compare type instead of "FLOAT". Must be false for extended dtypes.
    ctx: per-rule lowering context providing `avals_in` and `avals_out`.
    x: first lowered operand.
    y: second lowered operand.

  Returns:
    A one-element list holding the comparison result, or whatever
    `_compare_lower_hlo_opaque` returns for extended dtypes.
  """
  avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
  # The compare type is chosen from the first operand's dtype only.
  x_dtype = avals_in[0].dtype
  x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape)
  if dtypes.issubdtype(x_dtype, dtypes.extended):
    assert not total_order
    return _compare_lower_hlo_opaque(direction, ctx, avals_in, aval_out, x, y)
  if dtypes.issubdtype(x_dtype, np.inexact):
    compare_type = "TOTALORDER" if total_order else "FLOAT"
  elif dtypes.issubdtype(x_dtype, np.signedinteger):
    compare_type = "SIGNED"
  else:
    compare_type = "UNSIGNED"
  return [mlir.compare_hlo(x, y, direction, compare_type)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2023-07-24 14:29:37 -07:00
|
|
|
|
# Comparison primitives. Each produces a boolean result, has a zero JVP, and
# lowers through `_compare_lower_hlo` with the matching direction string.
# `eq` and `ne` additionally accept extended dtypes.
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True)
ad.defjvp_zero(eq_p)
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False))
batching.ragged_prop_rules[eq_p] = batching.ragged_mask_elementwise_rule

ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True)
ad.defjvp_zero(ne_p)
mlir.register_lowering(ne_p, partial(_compare_lower_hlo, "NE", False))

ge_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'ge')
ad.defjvp_zero(ge_p)
mlir.register_lowering(ge_p, partial(_compare_lower_hlo, "GE", False))

gt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'gt')
ad.defjvp_zero(gt_p)
mlir.register_lowering(gt_p, partial(_compare_lower_hlo, "GT", False))

le_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le')
ad.defjvp_zero(le_p)
mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE", False))

lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt')
ad.defjvp_zero(lt_p)
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False))
batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule

# Variants lowered with total_order=True.
eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to')
ad.defjvp_zero(eq_to_p)
mlir.register_lowering(eq_to_p, partial(_compare_lower_hlo, "EQ", True))

le_to_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le_to')
ad.defjvp_zero(le_to_p)
mlir.register_lowering(le_to_p, partial(_compare_lower_hlo, "LE", True))

lt_to_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt_to')
ad.defjvp_zero(lt_to_p)
mlir.register_lowering(lt_to_p, partial(_compare_lower_hlo, "LT", True))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
2024-07-09 07:32:38 -07:00
|
|
|
|
def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type,
                                     sharding):
  """Shape rule: the output shape is the operand's shape, unchanged."""
  return operand.shape
|
|
|
|
|
|
2024-08-29 10:49:30 -07:00
|
|
|
|
def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
                                        sharding):
  """Sharding rule for `convert_element_type_p`.

  Returns:
    * `operand.sharding` when no `sharding` was requested;
    * for a concrete `NamedSharding`, a new `NamedSharding` built from its
      mesh's `abstract_mesh` and its spec;
    * for any other concrete sharding, `core.get_cur_mesh_sharding()`;
    * otherwise `sharding` itself.
  """
  if sharding is None:
    return operand.sharding
  if sharding._is_concrete:
    if isinstance(sharding, NamedSharding):
      return NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)
    else:
      return core.get_cur_mesh_sharding()
  return sharding
|
|
|
|
|
|
2024-07-09 07:32:38 -07:00
|
|
|
|
def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,
                                     sharding):
  """Dtype rule: the output dtype is exactly the requested `new_dtype`."""
  return new_dtype
|
|
|
|
|
|
2024-07-09 07:32:38 -07:00
|
|
|
|
def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type,
                                         sharding):
  """Weak-type rule: the output weak-type flag is the `weak_type` parameter."""
  return weak_type
|
|
|
|
|
|
2024-07-09 07:32:38 -07:00
|
|
|
|
def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type,
                                         sharding):
  """Transpose rule: convert the cotangent back to the operand's dtype.

  Args:
    ct: cotangent of the converted output, possibly an `ad_util.Zero`.
    operand: the undefined primal input.
    new_dtype: dtype the forward conversion produced (unused here).
    weak_type: weak-type flag of the forward conversion (unused here).
    sharding: forwarded unchanged to the backward conversion.

  Returns:
    A one-element list with the operand's cotangent. It is a symbolic zero
    when `ct` is one, and a float0 symbolic zero when the operand's tangent
    dtype is float0.
  """
  assert ad.is_undefined_primal(operand)
  old_dtype = operand.aval.dtype
  old_weak_type = dtypes.is_weakly_typed(operand)
  if type(ct) is ad_util.Zero:
    return [ad_util.Zero(operand.aval)]
  elif core.primal_dtype_to_tangent_dtype(old_dtype) == dtypes.float0:
    return [ad_util.Zero(operand.aval.update(dtype=dtypes.float0, weak_type=False))]
  else:
    return [convert_element_type_p.bind(
        ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).
This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-20 22:58:01 +00:00
|
|
|
|
def _convert_element_type_jvp_rule(tangent, primal_result, operand, *,
                                   new_dtype, weak_type, sharding):
  """JVP rule: convert the tangent to the tangent dtype of `new_dtype`.

  Returns a symbolic zero derived from `primal_result` when the tangent dtype
  for `new_dtype` is float0; otherwise binds `convert_element_type_p` on the
  tangent with the same `weak_type` and `sharding`.
  """
  new_tangent_dtype = core.primal_dtype_to_tangent_dtype(new_dtype)
  if new_tangent_dtype == dtypes.float0:
    return ad_util.Zero.from_primal_value(primal_result)
  else:
    return convert_element_type_p.bind(tangent, new_dtype=new_tangent_dtype,
                                       weak_type=weak_type, sharding=sharding)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-11-15 21:21:29 -08:00
|
|
|
|
def _convert_elt_type_folding_rule(consts, eqn):
  """Constant-folding rule for `convert_element_type_p`.

  Args:
    consts: one-element sequence holding the constant input.
    eqn: the equation being folded; its single outvar and `new_dtype` param
      are read.

  Returns:
    `([folded], None)` when the conversion was folded, else `([None], eqn)`.
  """
  # We constant-fold convert_element_types applied to constants if those
  # constants are Python builtin numeric types or numpy.ndarrays (so as not
  # to perform any device operations when constant-folding) and if the output
  # type can be faithfully represented by a Python builtin numeric type or
  # numpy.ndarray. If those conditions are met, we output a numpy.ndarray
  # constant if the output type is not weak, and if the output type is weak then
  # we output a Python builtin numeric type.
  # TODO(mattjj): allow constant-folding CPU-backed JAX arrays
  c, = consts
  o, = eqn.outvars
  new_dtype = eqn.params['new_dtype']
  # Only scalar constants (empty shape) with a non-extended target dtype fold.
  if (type(c) in {np.ndarray, *dtypes.python_scalar_dtypes} and
      isinstance(o.aval, core.UnshapedArray) and not np.shape(c) and
      not dtypes.issubdtype(new_dtype, dtypes.extended)):
    out = np.array(c)
    # Complex -> non-complex conversions keep only the real part.
    if (dtypes.issubdtype(out.dtype, np.complexfloating) and
        not dtypes.issubdtype(new_dtype, np.complexfloating)):
      out = out.real
    out = out.astype(new_dtype)
    if not o.aval.weak_type:
      return [out], None
    # Weakly typed outputs fold to a Python scalar, but only when that scalar's
    # aval dtype is the expected output dtype.
    out = out.item()
    if core.get_aval(out).dtype is o.aval.dtype:
      return [out], None
  return [None], eqn
|
2021-11-15 21:21:29 -08:00
|
|
|
|
|
|
|
|
|
def _convert_elt_type_fwd_rule(eqn):
  """Forwarding rule: skip conversions that change neither dtype nor weak type.

  Returns `([v], None)` (forward the input variable) when neither the input
  dtype nor `new_dtype` is an extended dtype and both the dtype and the
  weak-type flag already match the equation's params; otherwise
  `([None], eqn)`.
  """
  v, = eqn.invars
  if (not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended) and
      not dtypes.issubdtype(v.aval.dtype, dtypes.extended) and
      v.aval.dtype == eqn.params['new_dtype'] and
      v.aval.weak_type == eqn.params['weak_type']):
    return [v], None
  else:
    return [None], eqn
|
|
|
|
|
|
2022-03-09 12:20:28 -08:00
|
|
|
|
def _convert_elt_type_pp_rule(eqn, context, settings):
  """Pretty-print rule that omits the `sharding` param when it is None."""
  # Work on a copy so the equation's own params are left untouched.
  params = dict(eqn.params)
  if params['sharding'] is None:
    del params['sharding']  # don't show trivial case
  return core._pp_eqn(eqn.replace(params=params), context, settings)
|
2022-02-04 19:12:57 -08:00
|
|
|
|
|
2021-11-15 21:21:29 -08:00
|
|
|
|
convert_element_type_p = Primitive('convert_element_type')

# TODO(dougalm): I'm overriding bind_with_trace here because that's the closest thing to
# the old "custom bind" but it might not be the best way to do this.
def _convert_element_type_bind_with_trace(trace, args, params):
  """Bind `convert_element_type_p`, then constrain to a concrete sharding.

  The primitive is bound normally. If `params['sharding']` is set and
  concrete, the result is additionally passed through
  `pjit.with_sharding_constraint` with `trace` installed as the current trace.
  """
  sharding = params['sharding']
  operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
  if sharding is not None and sharding._is_concrete:
    with core.set_current_trace(trace):
      operand = pjit.with_sharding_constraint(operand, sharding)
  return operand
convert_element_type_p.def_bind_with_trace(_convert_element_type_bind_with_trace)
|
|
|
|
|
|
2023-03-27 13:29:59 -07:00
|
|
|
|
# Rule registrations for `convert_element_type_p`: implementation, abstract
# evaluation (shape / dtype / weak-type / sharding / vma rules), autodiff,
# batching, partial-eval and pretty-printing.
convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval(
    partial(standard_abstract_eval, convert_element_type_p,
            _convert_element_type_shape_rule, _convert_element_type_dtype_rule,
            _convert_element_type_weak_type_rule,
            _convert_element_type_sharding_rule,
            partial(core.standard_vma_rule, convert_element_type_p.name)))
ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule)
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
batching.defvectorized(convert_element_type_p)
pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
pe.def_trivial_padding(convert_element_type_p)
core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule
batching.ragged_prop_rules[convert_element_type_p] = (
    batching.ragged_mask_elementwise_rule
)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-12-07 07:12:08 -08:00
|
|
|
|
def _real_dtype(dtype):
  """Return the dtype that `np.finfo` reports for `dtype`."""
  float_info = np.finfo(dtype)
  return float_info.dtype
|
|
|
|
|
|
2024-07-09 07:32:38 -07:00
|
|
|
|
def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
                                sharding):
  """MLIR lowering rule for `convert_element_type_p`.

  Args:
    ctx: per-rule lowering context; `avals_in` / `avals_out` each hold one aval.
    operand: the lowered input value.
    new_dtype: target dtype; used here only to detect complex -> non-complex.
    weak_type: unused by the lowering.
    sharding: unused by the lowering.

  Returns:
    A one-element list with the converted value, passed through
    `mlir.lower_with_sharding_in_types`.
  """
  aval_in, = ctx.avals_in
  aval_out, = ctx.avals_out
  # Complex -> non-complex: take the real part first and treat the input as
  # having the corresponding real dtype for the conversion below.
  if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
      not dtypes.issubdtype(new_dtype, np.complexfloating)):
    operand = hlo.real(operand)
    aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
  out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
  return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]

mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).
This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-20 22:58:01 +00:00
|
|
|
|
def _to_edtype_abstract_eval(x, *, edtype):
  """Abstract evaluation rule for ``to_edtype_p``.

  Checks that ``x`` is an array of the physical representation of the
  extended dtype ``edtype`` and returns ``x``'s aval with the trailing
  representation dimensions removed and the dtype replaced by ``edtype``.

  Args:
    x: abstract value of the operand; must not have an extended dtype.
    edtype: the target extended dtype.

  Raises:
    ValueError: if the extended dtype's rules disallow the conversion, if
      ``x.dtype`` differs from the representation dtype, if ``x`` has lower
      rank than the representation, or if the trailing dimensions of
      ``x.shape`` differ from the representation shape.
  """
  assert isinstance(edtype, dtypes.ExtendedDType)
  assert not isinstance(x.dtype, dtypes.ExtendedDType)

  rules = edtype._rules
  # For backward compatibility, a `convert_to` method on the rules takes
  # precedence over the boolean `allow_conversion` attribute.
  legacy_convert_to = getattr(rules, 'convert_to', None)
  if legacy_convert_to:
    permitted = legacy_convert_to(x.dtype, edtype)
  else:
    permitted = rules.allow_conversion
  if not permitted:
    raise ValueError(
        f"Cannot convert_element_type from {dtype_to_string(x.dtype)} "
        f"to {dtype_to_string(edtype)}")

  physical = core.physical_element_aval(edtype)
  if x.dtype != physical.dtype:
    raise ValueError(
        "can only convert to extended dtype from its representation dtype, "
        f"but tried to convert from {dtype_to_string(x.dtype)} to "
        f"{dtype_to_string(edtype)} which doesn't match the representation type "
        f"{dtype_to_string(physical.dtype)}.")
  if x.ndim < physical.ndim:
    raise ValueError(
        "can only convert to extended dtype from an array of its "
        f"representation type, but the extended dtype {dtype_to_string(edtype)}"
        f" has a representation shape {physical.shape} (rank {physical.ndim}) "
        f"while the given representation array has shape {x.shape} (rank "
        f"{x.ndim} < {physical.ndim}).")

  # Split the operand shape into the logical (leading) part and the trailing
  # part that must coincide with the representation shape.
  split = x.ndim - physical.ndim
  leading, trailing = x.shape[:split], x.shape[split:]
  if trailing != physical.shape:
    raise ValueError(
        "can only convert to extended dtype from an array of its "
        f"representation type, but the extended dtype {dtype_to_string(edtype)}"
        f" has a representation shape {physical.shape} while the given "
        f"representation array has shape {x.shape}, so the shape suffix "
        f"does not match: given {trailing} but required {physical.shape}.")
  return x.update(shape=leading, dtype=edtype)
|
simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).
This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-20 22:58:01 +00:00
|
|
|
|
|
|
|
|
|
def _to_edtype_jvp(t, x, edtype):
  """JVP rule for ``to_edtype_p``: convert the tangent to the tangent dtype of ``edtype``."""
  return convert_element_type(t, core.primal_dtype_to_tangent_dtype(edtype))


def _to_edtype_transpose(ct, x, edtype):
  """Transpose rule for ``to_edtype_p``: bind ``from_edtype_p`` on the cotangent."""
  return [from_edtype_p.bind(ct, dtype=x.aval.dtype)]  # type: ignore


def _to_edtype_lower(_, x, **__):
  """Lowering rule for ``to_edtype_p``: return the operand unchanged."""
  return [x]


# Primitive converting an array of a representation dtype to an extended dtype.
to_edtype_p = Primitive('to_edtype')
to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p))
to_edtype_p.def_abstract_eval(_to_edtype_abstract_eval)
ad.defjvp(to_edtype_p, _to_edtype_jvp)
ad.primitive_transposes[to_edtype_p] = _to_edtype_transpose
batching.defvectorized(to_edtype_p)
mlir.register_lowering(to_edtype_p, _to_edtype_lower)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _from_edtype_abstract_eval(x, *, dtype):
  """Abstract evaluation rule for ``from_edtype_p``.

  Checks that ``dtype`` is the representation dtype of the extended dtype of
  ``x`` and returns a ``ShapedArray`` whose shape is ``x.shape`` followed by
  the representation shape.

  Args:
    x: abstract value of the operand; must have an extended dtype.
    dtype: the target (non-extended) dtype.

  Raises:
    ValueError: if the extended dtype's rules disallow the conversion or if
      ``dtype`` differs from the representation dtype.
    NotImplementedError: if any dimension of ``x.shape`` is not an ``int``.
  """
  assert isinstance(x.dtype, dtypes.ExtendedDType)
  assert not isinstance(dtype, dtypes.ExtendedDType)

  rules = x.dtype._rules
  # A `convert_from` method on the rules, when present, takes precedence over
  # the boolean `allow_conversion` attribute.
  legacy_convert_from = getattr(rules, 'convert_from', None)
  if legacy_convert_from:
    permitted = legacy_convert_from(x.dtype, dtype)
  else:
    permitted = rules.allow_conversion
  if not permitted:
    raise ValueError(
        f"Cannot convert_element_type from {dtype_to_string(x.dtype)} "
        f"to {dtype_to_string(dtype)}")

  physical = core.physical_element_aval(x.dtype)
  if physical.dtype != dtype:
    raise ValueError(
        "can only convert from extended dtype to its representation dtype, "
        f"but tried to convert from {dtype_to_string(x.dtype)} to "
        f"{dtype_to_string(dtype)} which doesn't match the representation type "
        f"{dtype_to_string(physical.dtype)}.")

  # Only shapes made entirely of Python ints are handled here.
  if not all(isinstance(d, int) for d in x.shape):
    raise NotImplementedError
  return core.ShapedArray(shape=(*x.shape, *physical.shape), dtype=dtype)
|
|
|
|
|
|
|
|
|
|
def _from_edtype_jvp(t, x, dtype):
  """JVP rule for ``from_edtype_p``: convert the tangent to the tangent dtype of ``dtype``."""
  return convert_element_type(t, core.primal_dtype_to_tangent_dtype(dtype))


def _from_edtype_transpose(ct, x, dtype):
  """Transpose rule for ``from_edtype_p``: bind ``to_edtype_p`` on the cotangent.

  The operand's dtype is read through ``x.aval``, the same way the transpose
  rule of ``to_edtype_p`` does it.
  NOTE(review): this assumes the operand handed to a transpose rule exposes
  ``.aval`` but not necessarily ``.dtype`` -- confirm against the autodiff
  interpreter.
  """
  return [to_edtype_p.bind(ct, edtype=x.aval.dtype)]


def _from_edtype_lower(_, x, **__):
  """Lowering rule for ``from_edtype_p``: return the operand unchanged."""
  return [x]


# Primitive converting an array of an extended dtype to its representation
# dtype.
from_edtype_p = Primitive('from_edtype')
from_edtype_p.def_impl(partial(dispatch.apply_primitive, from_edtype_p))
from_edtype_p.def_abstract_eval(_from_edtype_abstract_eval)
ad.defjvp(from_edtype_p, _from_edtype_jvp)
ad.primitive_transposes[from_edtype_p] = _from_edtype_transpose
batching.defvectorized(from_edtype_p)
mlir.register_lowering(from_edtype_p, _from_edtype_lower)
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
|
2023-02-16 08:21:18 -08:00
|
|
|
|
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
|
|
|
|
|
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
|
|
|
|
|
2024-12-20 12:45:24 -08:00
|
|
|
|
old_nbits = dtypes.bit_width(old_dtype)
|
|
|
|
|
new_nbits = dtypes.bit_width(new_dtype)
|
|
|
|
|
|
|
|
|
|
if old_nbits == new_nbits:
|
2023-02-16 08:21:18 -08:00
|
|
|
|
return operand.shape
|
2024-12-20 12:45:24 -08:00
|
|
|
|
elif old_nbits > new_nbits:
|
|
|
|
|
return (*operand.shape, old_nbits // new_nbits)
|
2023-02-16 10:56:06 -08:00
|
|
|
|
else:
|
|
|
|
|
dim_size = operand.shape[-1] if operand.shape else 1
|
2024-12-20 12:45:24 -08:00
|
|
|
|
if dim_size * old_nbits != new_nbits:
|
2023-02-16 10:56:06 -08:00
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Attempting to convert array of shape {operand.shape} "
|
2024-12-20 12:45:24 -08:00
|
|
|
|
f"from {old_dtype} of size {old_nbits} bits "
|
|
|
|
|
f"to {new_dtype} of size {new_nbits}, bits "
|
|
|
|
|
f"but {dim_size} * {old_nbits} != {new_nbits}")
|
2023-02-16 10:56:06 -08:00
|
|
|
|
return operand.shape[:-1]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-02-22 10:45:18 -08:00
|
|
|
|
def _bitcast_convert_type_sharding_rule(operand, *, new_dtype):
|
|
|
|
|
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
|
|
|
|
|
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
|
|
|
|
|
|
|
|
|
old_nbits = dtypes.bit_width(old_dtype)
|
|
|
|
|
new_nbits = dtypes.bit_width(new_dtype)
|
|
|
|
|
|
|
|
|
|
if old_nbits == new_nbits:
|
|
|
|
|
return operand.sharding
|
|
|
|
|
elif old_nbits > new_nbits:
|
|
|
|
|
return operand.sharding.with_spec((*operand.sharding.spec, None))
|
|
|
|
|
else:
|
|
|
|
|
return operand.sharding.with_spec(operand.sharding.spec[:-1])
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
  """Dtype rule for ``bitcast_convert_type``.

  Returns the canonicalized ``new_dtype``.

  Raises:
    TypeError: if either the canonicalized operand dtype or the canonicalized
      target dtype is bool or complex and the two dtypes differ.
  """
  old_dtype = dtypes.canonicalize_dtype(operand.dtype)
  new_dtype = dtypes.canonicalize_dtype(new_dtype)

  restricted_kinds = (np.bool_, np.complexfloating)
  touches_restricted = any(
      dtypes.issubdtype(dt, kind)
      for dt in (old_dtype, new_dtype)
      for kind in restricted_kinds)
  if touches_restricted and old_dtype != new_dtype:
    raise TypeError("lax.bitcast_convert_type does not support bool or complex values "
                    "unless the operand and destination types match. "
                    f"Got operand dtype={old_dtype}, {new_dtype=}. "
                    "Consider using the arr.view() method instead.")
  return new_dtype
|
|
|
|
|
|
|
|
|
|
# The `bitcast_convert_type` primitive, built from the shape, dtype and
# sharding rules for this operation; its weak-type rule is `_strip_weak_type`.
bitcast_convert_type_p = standard_primitive(
    _bitcast_convert_type_shape_rule,
    _bitcast_convert_type_dtype_rule,
    'bitcast_convert_type',
    weak_type_rule=_strip_weak_type,
    sharding_rule=_bitcast_convert_type_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'bitcast_convert_type'),
)
# Differentiation (zero JVP) and batching (vectorized) registrations.
ad.defjvp_zero(bitcast_convert_type_p)
batching.defvectorized(bitcast_convert_type_p)
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
  """MLIR lowering rule for ``bitcast_convert_type``.

  Emits ``hlo.bitcast_convert`` to the IR type of the single output aval and
  passes the result through ``mlir.lower_with_sharding_in_types``.
  ``new_dtype`` is not read here; the result type comes from ``ctx.avals_out``.
  """
  (result_aval,) = ctx.avals_out
  result_type = mlir.aval_to_ir_type(result_aval)
  converted = hlo.bitcast_convert(result_type, operand)
  return [mlir.lower_with_sharding_in_types(ctx, converted, result_aval)]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
|
|
|
|
# Register `_bitcast_convert_type_lower` as the MLIR lowering rule for the
# `bitcast_convert_type_p` primitive.
mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-04-21 23:58:34 +02:00
|
|
|
|
def _validate_preferred_element_type(input_dtype, preferred_element_type):
  """Check that ``preferred_element_type`` is compatible with ``input_dtype``.

  Accepted (input, preferred) combinations are (integral, integral),
  (integral, floating), (floating, floating) and (complex, complex). Outside
  the integral -> floating case, a signed integer input requires a signed
  integer preferred type. In every case the preferred type must not have a
  smaller ``itemsize`` than the input type.

  Args:
    input_dtype: dtype of the operation's input.
    preferred_element_type: requested result dtype.

  Raises:
    TypeError: if any of the conditions above is violated.
  """
  int_to_float = (dtypes.issubdtype(input_dtype, np.integer) and
                  dtypes.issubdtype(preferred_element_type, np.floating))
  # Integer -> float is allowed as a special case, and it also permits
  # different signedness between input and output, so both category checks
  # are skipped for it.
  if not int_to_float:
    allowed_types = (np.integer, np.floating, np.complexfloating)
    if any(dtypes.issubdtype(input_dtype, t) and not
           dtypes.issubdtype(preferred_element_type, t) for t in allowed_types):
      raise TypeError("Input type is incompatible with "
                      "`preferred_element_type`. The compatible combinations "
                      "of (input_type, preferred_element_type) are "
                      "(integral, integral), (integral, floating), "
                      "(floating, floating), (complex, complex).")
    if (dtypes.issubdtype(input_dtype, np.signedinteger) and
        not dtypes.issubdtype(preferred_element_type, np.signedinteger)):
      raise TypeError("`preferred_element_type` must have the same signedness "
                      "as the original type.")

  # NOTE: `itemsize` is measured in bytes, so widths are compared at byte
  # granularity despite the variable names.
  input_bitwidth = np.dtype(input_dtype).itemsize
  preferred_bitwidth = np.dtype(preferred_element_type).itemsize
  if preferred_bitwidth < input_bitwidth:
    raise TypeError("`preferred_element_type` must not be narrower than the "
                    "original type.")
|
|
|
|
|
|
2021-04-21 23:58:34 +02:00
|
|
|
|
|
2020-12-10 02:29:40 +00:00
|
|
|
|
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: DTypeLike | None,
                            out_sharding):
  """Shape rule for ``dot_general``.

  Validates the contracting and batch dimension numbers against both operands
  and returns the output shape from ``_dot_general_shape_computation``.

  Raises:
    NotImplementedError: if ``out_sharding`` is neither ``None`` nor a
      ``NamedSharding``.
    TypeError: if dimension numbers are out of range, repeated, overlapping
      between batch and contracting, unequal in count between the two batch
      lists, or if the batch / contracting dimension sizes of the two operands
      are not definitely equal.
  """
  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
    raise NotImplementedError
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)

  def dims_in_range(dims, rank):
    # True when every entry of `dims` lies in [0, rank).
    return np.all(np.greater_equal(dims, 0)) and np.all(np.less(dims, rank))

  # Range checks: every dimension number must index an existing axis.
  if not (dims_in_range(lhs_contracting, lhs.ndim) and
          dims_in_range(lhs_batch, lhs.ndim)):
    raise TypeError(
        "dot_general requires lhs dimension numbers to be nonnegative and "
        "less than the number of axes of the lhs value, got "
        f"lhs_batch of {lhs_batch} and lhs_contracting of {lhs_contracting} "
        f"for lhs of rank {lhs.ndim}")
  if not (dims_in_range(rhs_contracting, rhs.ndim) and
          dims_in_range(rhs_batch, rhs.ndim)):
    raise TypeError(
        "dot_general requires rhs dimension numbers to be nonnegative and "
        "less than the number of axes of the rhs value, got "
        f"rhs_batch of {rhs_batch} and rhs_contracting of {rhs_contracting} "
        f"for rhs of rank {rhs.ndim}")

  if len(lhs_batch) != len(rhs_batch):
    raise TypeError(
        "dot_general requires equal numbers of lhs_batch and rhs_batch "
        f"dimensions, got lhs_batch {lhs_batch} and rhs_batch {rhs_batch}.")

  # Uniqueness checks: a set smaller than its source means a repeated entry.
  lhs_contracting_set, lhs_batch_set = set(lhs_contracting), set(lhs_batch)
  rhs_contracting_set, rhs_batch_set = set(rhs_contracting), set(rhs_batch)
  if len(lhs_batch_set) != len(lhs_batch):
    raise TypeError(
        "dot_general requires lhs batch dimensions to be distinct, got "
        f"lhs_batch {lhs_batch}.")
  if len(rhs_batch_set) != len(rhs_batch):
    raise TypeError(
        "dot_general requires rhs batch dimensions to be distinct, got "
        f"rhs_batch {rhs_batch}.")
  if len(lhs_contracting_set) != len(lhs_contracting):
    raise TypeError(
        "dot_general requires lhs contracting dimensions to be distinct, "
        f"got lhs_contracting {lhs_contracting}.")
  if len(rhs_contracting_set) != len(rhs_contracting):
    raise TypeError(
        "dot_general requires rhs contracting dimensions to be distinct, "
        f"got rhs_contracting {rhs_contracting}.")

  # Disjointness checks: an axis cannot be both batch and contracting.
  if lhs_contracting_set & lhs_batch_set:
    raise TypeError(
        "dot_general requires lhs batch dimensions to be disjoint from "
        f"contracting dimensions, got lhs_batch {lhs_batch} and "
        f"lhs_contracting {lhs_contracting}.")
  if rhs_contracting_set & rhs_batch_set:
    raise TypeError(
        "dot_general requires rhs batch dimensions to be disjoint from "
        f"contracting dimensions, got rhs_batch {rhs_batch} and "
        f"rhs_contracting {rhs_contracting}.")

  # Size checks: paired batch and contracting axes must agree in extent.
  lhs_batch_shape = tuple(lhs.shape[i] for i in lhs_batch)
  rhs_batch_shape = tuple(rhs.shape[i] for i in rhs_batch)
  if not core.definitely_equal_shape(lhs_batch_shape, rhs_batch_shape):
    raise TypeError(
        "dot_general requires lhs batch dimensions and rhs batch dimensions "
        f"to have the same shape, got {lhs_batch_shape} and {rhs_batch_shape}.")
  lhs_contracting_shape = tuple(lhs.shape[i] for i in lhs_contracting)
  rhs_contracting_shape = tuple(rhs.shape[i] for i in rhs_contracting)
  if not core.definitely_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
    raise TypeError(
        "dot_general requires contracting dimensions to have the same "
        f"shape, got {lhs_contracting_shape} and {rhs_contracting_shape}.")

  return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)
|
|
|
|
|
|
|
|
|
|
def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
  """Compute the ``dot_general`` output shape without validation.

  The result is the lhs batch dimensions, followed by the lhs dimensions that
  are neither contracting nor batch, followed by the rhs dimensions that are
  neither contracting nor batch (nor group dimensions, when
  ``dimension_numbers`` is a ``RaggedDotDimensionNumbers``).
  """
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)

  leading = tuple(lhs_shape[i] for i in lhs_batch)

  lhs_consumed = tuple(sorted((*lhs_contracting, *lhs_batch)))
  lhs_free = tuple_delete(lhs_shape, lhs_consumed)

  # Ragged dimension numbers additionally remove the rhs group dimensions.
  if isinstance(dimension_numbers, RaggedDotDimensionNumbers):
    rhs_group = tuple(dimension_numbers.rhs_group_dimensions)
  else:
    rhs_group = ()
  rhs_consumed = tuple(sorted((*rhs_contracting, *rhs_batch, *rhs_group)))
  rhs_free = tuple_delete(rhs_shape, rhs_consumed)

  return leading + lhs_free + rhs_free
|
|
|
|
|
|
2024-10-11 16:04:35 -07:00
|
|
|
|
|
|
|
|
|
def _check_specs_match(lhs_spec, rhs_spec, msg):
  """Raise ``core.ShardingTypeError(msg)`` on conflicting spec entries.

  Entries are compared pairwise; a pair conflicts only when both entries are
  not ``None`` and they differ.
  """
  conflict = any(
      left is not None and right is not None and left != right
      for left, right in zip(lhs_spec, rhs_spec))
  if conflict:
    raise core.ShardingTypeError(msg)
|
2024-10-11 16:04:35 -07:00
|
|
|
|
|
|
|
|
|
def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
                               preferred_element_type: DTypeLike | None,
                               out_sharding):
  """Sharding rule for ``dot_general``.

  Returns ``out_sharding`` when it is given. Otherwise checks that the batch
  and contracting spec entries of both operands are consistent and returns
  the sharding from ``_dot_general_sharding_computation``.

  Raises:
    core.ShardingTypeError: if the operand meshes differ, if paired batch or
      contracting spec entries conflict, or if a contracting dimension is
      sharded on both operands while no ``out_sharding`` is given.
  """
  if lhs.sharding.mesh != rhs.sharding.mesh:
    raise core.ShardingTypeError(
        'Mesh of both lhs and rhs should match. Got lhs:'
        f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')

  # An explicit output sharding short-circuits the remaining checks.
  if out_sharding is not None:
    assert isinstance(out_sharding, NamedSharding)
    return out_sharding

  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_entries, rhs_entries = lhs.sharding.spec, rhs.sharding.spec

  lhs_batch_spec = tuple(lhs_entries[i] for i in lhs_batch)
  rhs_batch_spec = tuple(rhs_entries[i] for i in rhs_batch)
  _check_specs_match(
      lhs_batch_spec, rhs_batch_spec,
      "dot_general requires lhs batch dimensions and rhs batch dimensions "
      f"to have the consistent sharding, got {lhs_batch_spec} and "
      f"{rhs_batch_spec}.")

  lhs_contracting_spec = tuple(lhs_entries[i] for i in lhs_contracting)
  rhs_contracting_spec = tuple(rhs_entries[i] for i in rhs_contracting)
  _check_specs_match(
      lhs_contracting_spec, rhs_contracting_spec,
      "dot_general requires contracting dimensions to have consistent "
      f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")

  # A contracting dimension sharded on both sides leaves the output sharding
  # ambiguous, so it is rejected unless the caller supplied `out_sharding`.
  both_sharded = any(
      left is not None and right is not None
      for left, right in zip(lhs_contracting_spec, rhs_contracting_spec))
  if both_sharded:
    raise core.ShardingTypeError(
        'Contracting dimensions are sharded and it is ambiguous how the'
        ' output should be sharded. Please specify the output sharding via'
        ' the `out_sharding` parameter of einsum. Or reshard your input via'
        ' `jax.experimental.shard.reshard` so that the dot is conflict free.'
        f' Got {lhs_contracting_spec=} and {rhs_contracting_spec=}')

  return _dot_general_sharding_computation(
      lhs.sharding.spec, rhs.sharding.spec, dimension_numbers, lhs.sharding.mesh)
|
|
|
|
|
|
|
|
|
|
def _dot_general_sharding_computation(lhs_spec, rhs_spec,
                                      dimension_numbers, mesh):
  """Build the ``dot_general`` output ``NamedSharding`` on ``mesh``.

  The output spec is the lhs batch entries, followed by the lhs entries that
  are neither contracting nor batch, followed by the rhs entries that are
  neither contracting nor batch.
  """
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

  leading = tuple(lhs_spec[i] for i in lhs_batch)

  lhs_consumed = tuple(sorted((*lhs_contracting, *lhs_batch)))
  lhs_free = tuple_delete(lhs_spec, lhs_consumed)

  rhs_consumed = tuple(sorted((*rhs_contracting, *rhs_batch)))
  rhs_free = tuple_delete(rhs_spec, rhs_consumed)

  out_entries = leading + lhs_free + rhs_free
  return NamedSharding(mesh, P(*out_entries))
|
|
|
|
|
|
2022-01-20 22:58:09 -08:00
|
|
|
|
def tuple_delete(tup, idx):
  """Return the items of ``tup`` as a tuple, omitting the positions in ``idx``.

  Positions in ``idx`` that are repeated or beyond the end of ``tup`` have no
  additional effect.
  """
  dropped = frozenset(idx)
  kept = []
  for position in range(len(tup)):
    if position in dropped:
      continue
    kept.append(tup[position])
  return tuple(kept)
|
|
|
|
|
|
|
|
|
|
|
2020-12-10 02:29:40 +00:00
|
|
|
|
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: DTypeLike | None,
                            out_sharding, name: str = 'lax.dot_general'):
  """Dtype rule for ``dot_general``.

  Picks the operand dtype that ranks higher under ``type_properties`` and
  passes it with ``preferred_element_type`` to ``_maybe_upcast``. The
  bit-width check there is disabled when ``precision`` is a ``DotAlgorithm``
  or ``DotAlgorithmPreset``.

  Raises:
    NotImplementedError: if ``out_sharding`` is neither ``None`` nor a
      ``NamedSharding``.
    TypeError: if the operand dtypes rank equally but are not the same dtype.
  """
  if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
    raise NotImplementedError
  del dimension_numbers  # unused

  # We're mostly matching XLA's logic here, namely in shape_inference.cc and
  # primitive_util.h's HigherPrecisionType, e.g.
  # https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329
  def type_properties(dt):
    # Ranking key, compared lexicographically: complexness, max exponent,
    # mantissa bits, bit width, and "not unsigned".
    is_complex = dtypes.issubdtype(dt, np.complexfloating)
    component = _real_dtype(dt) if is_complex else dt
    if dtypes.issubdtype(component, np.floating):
      info = dtypes.finfo(component)
      max_exponent, mantissa_bits = info.maxexp, info.nmant
    else:
      max_exponent, mantissa_bits = -1, -1
    return (is_complex,
            max_exponent,
            mantissa_bits,
            _bit_width(component),
            not dtypes.issubdtype(component, np.unsignedinteger))

  lhs_rank, rhs_rank = type_properties(lhs.dtype), type_properties(rhs.dtype)
  if lhs_rank != rhs_rank:
    result_dtype = lhs.dtype if lhs_rank > rhs_rank else rhs.dtype
  elif lhs.dtype != rhs.dtype:
    # Equal ranking but different dtypes: no way to pick a winner.
    raise TypeError(f'{name} argument type error: {lhs.dtype}, {rhs.dtype}')
  else:
    result_dtype = lhs.dtype

  uses_algorithm = isinstance(precision, (DotAlgorithm, DotAlgorithmPreset))
  return _maybe_upcast(result_dtype, preferred_element_type,
                       check_bit_width=not uses_algorithm)
|
2023-05-12 19:56:59 -07:00
|
|
|
|
|
|
|
|
|
def _bit_width(d):
  """Return the bit width of dtype ``d``.

  Inexact dtypes report ``dtypes.finfo(d).bits``, integer dtypes report
  ``dtypes.iinfo(d).bits``, and bool reports 1.

  Raises:
    AssertionError: for any other dtype. This branch is expected to be
      unreachable.
  """
  if dtypes.issubdtype(d, np.inexact):
    return dtypes.finfo(d).bits
  if dtypes.issubdtype(d, np.integer):
    return dtypes.iinfo(d).bits
  if d == np.dtype('bool'):
    return 1
  # Should be unreachable, open an issue! Raised explicitly instead of via
  # `assert False` so the failure is not stripped (silently returning None)
  # when Python runs with -O.
  raise AssertionError(d)
|
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
  """Choose the output dtype given an optional ``preferred_element_type``.

  Replicates the logic in shape_inference.cc's MaybeUpcast.

  Args:
    result_dtype: dtype the operation would produce without any preference.
    preferred_element_type: requested output dtype, or ``None``.
    check_bit_width: when true, a preferred type whose ``_bit_width`` is
      smaller than that of a non-floating ``result_dtype`` is rejected.

  Returns:
    ``result_dtype`` if no preference is given (or it equals the preference),
    otherwise ``preferred_element_type``.

  Raises:
    TypeError: if the bit-width check described above fails.
  """
  if (preferred_element_type is None or
      result_dtype == preferred_element_type):
    return result_dtype
  # Only non-floating results are subject to the "no narrowing" restriction.
  if (check_bit_width and not dtypes.issubdtype(result_dtype, np.floating) and
      _bit_width(preferred_element_type) < _bit_width(result_dtype)):
    raise TypeError("`preferred_element_type` must not be narrower than the "
                    "original type, got preferred_element_type of "
                    f"{preferred_element_type} for result type of "
                    f"{result_dtype}.")
  return preferred_element_type
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2023-01-31 09:07:39 -08:00
|
|
|
|
def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                               preferred_element_type: DTypeLike | None,
                               out_sharding, swap_ans=False):
  """Transpose rule of ``dot_general`` with respect to its first operand.

  Contracts ``g`` with ``y`` and transposes the result into the axis order of
  ``x``, converting to ``x``'s dtype if the product came out in another dtype.

  Args:
    g: value contracted against ``y`` (the output cotangent when this is used
      as a transpose rule).
    x: operand whose ``aval`` supplies the ndim, dtype, weak_type and sharding
      of the result.
    y: the other operand of the original ``dot_general``.
    dimension_numbers: ``((x_contract, y_contract), (x_batch, y_batch))``.
    precision: forwarded unchanged to ``dot_general``.
    preferred_element_type: forwarded unchanged to ``dot_general``.
    out_sharding: accepted for signature compatibility; not read here. The
      sharding passed to ``dot_general`` is derived from ``x.aval.sharding``.
    swap_ans: selects which of the two kept-dimension groups of ``g`` is
      treated as belonging to ``y`` (set by ``_dot_general_transpose_rhs``).

  Returns:
    An array laid out like ``x`` with dtype ``x.aval.dtype``.
  """
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  x_ndim = x.aval.ndim
  x_kept = remaining(range(x_ndim), x_contract, x_batch)
  y_kept = remaining(range(np.ndim(y)), y_contract, y_batch)
  if swap_ans:
    ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept)
  else:
    ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
  dims = ((ans_y, y_kept), (ans_batch, y_batch))
  # Order x's contracting dims to line up with y's contracting dims sorted
  # ascending, so they can be matched against the product's trailing axes.
  x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
  unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
  out_axes = np.argsort(unsorted_axes)
  # Permute x's sharding spec into the pre-transpose axis order of the product.
  xs = x.aval.sharding
  inverse_spec = tuple(xs.spec[o] for o in unsorted_axes)
  ds = xs.with_spec(inverse_spec)
  dot_general_out = dot_general(g, y, dims, precision=precision,
                                preferred_element_type=preferred_element_type,
                                out_sharding=ds)
  x_bar = transpose(dot_general_out, tuple(out_axes))
  if x_bar.dtype != x.aval.dtype:
    x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  return x_bar
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2023-01-31 09:07:39 -08:00
|
|
|
|
def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
                               preferred_element_type: DTypeLike | None,
                               out_sharding):
  """Transpose rule of ``dot_general`` with respect to its second operand.

  Implemented by swapping the roles of ``x`` and ``y`` (and the matching
  entries of ``dimension_numbers``) and delegating to
  ``_dot_general_transpose_lhs`` with ``swap_ans=True``.

  Returns:
    An array with dtype ``y.aval.dtype``.
  """
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
  y_bar = _dot_general_transpose_lhs(
      g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
      preferred_element_type=preferred_element_type, out_sharding=out_sharding,
      swap_ans=True)
  if y_bar.dtype != y.aval.dtype:
    y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
  return y_bar
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-10-03 16:41:31 -07:00
|
|
|
|
|
|
|
|
|
def _dot_batch_rule(
    unpack_args,
    unpack_dims,
    invoke_prim,
    axis_data,
    batched_args,
    batch_dims,
    *,
    dimension_numbers,
    out_sharding,
    precision,
    preferred_element_type: DTypeLike | None,
    **_,
):
  """Shared batching rule for dot-like primitives.

  Args:
    unpack_args: callable extracting ``(lhs, rhs)`` from ``batched_args``.
    unpack_dims: callable extracting ``(lbd, rbd)`` from ``batch_dims``.
    invoke_prim: callable invoked as ``invoke_prim(lhs, rhs, dimension_numbers,
      precision=..., preferred_element_type=..., out_sharding=...)``.
    axis_data: passed to ``batching.get_sharding_for_vmap`` when an
      ``out_sharding`` is given.
    batched_args: batched operands.
    batch_dims: batch dimension of each operand; an entry may be a
      ``RaggedAxis``.
    dimension_numbers: dimension numbers of the unbatched operation.
    out_sharding: requested output sharding, or ``None``.
    precision: forwarded to ``invoke_prim``.
    preferred_element_type: forwarded to ``invoke_prim``.
    **_: any further parameters are ignored.

  Returns:
    ``(batched_out, result_batch_dim)``.
  """
  lhs, rhs = unpack_args(batched_args)
  lbd, rbd = unpack_dims(batch_dims)

  left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
  right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
  new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
      (np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim),
      dimension_numbers)
  # TODO Should probably check that any ragged dimensions have corresponding
  # sizes, because otherwise the dot product is technically undefined.
  #
  # This masking is not strictly necessary for non-contraction dimensions;
  # we could micro-optimize here by avoiding computing that mask.
  if type(lbd) is RaggedAxis:
    lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd)
    lhs_shape = batching.bdim_as_shape(lbd, lhs.shape)
  else:
    lhs_shape = np.shape(lhs)
  if type(rbd) is RaggedAxis:
    rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd)
    rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
  else:
    rhs_shape = np.shape(rhs)

  result_batch_dim = batching.shape_as_bdim(
      result_stack_dim,
      _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))

  if out_sharding is not None:
    out_sharding = batching.get_sharding_for_vmap(
        axis_data, out_sharding, result_batch_dim)

  batched_out = invoke_prim(
      lhs,
      rhs,
      new_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      out_sharding=out_sharding,
  )
  return batched_out, result_batch_dim
|
|
|
|
|
|
2024-10-03 16:41:31 -07:00
|
|
|
|
|
2020-11-25 15:23:00 -08:00
|
|
|
|
def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
  """Compute dimension numbers for a ``dot_general`` with vmapped operands.

  Args:
    ndims: ``(lhs_ndim, rhs_ndim)`` of the batched operands.
    batch_dims: ``(lbd, rbd)``, the vmapped dimension of each operand; at
      least one must not be ``None``.
    dimension_numbers: dimension numbers of the unbatched operation, either a
      plain nested tuple or a ``RaggedDotDimensionNumbers``.

  Returns:
    ``(new_dimension_numbers, result_batch_dim)`` where the first element has
    the same kind as ``dimension_numbers`` and the second is the position of
    the vmapped dimension in the result.
  """
  # There are three kinds of dimensions in a dot_general:
  # - contraction dimensions appear in lhs and rhs but not the result
  # - batch dimensions appear in lhs, rhs, and result
  # - tensor product dimensions appear in the result and one of lhs or rhs
  # The dimensions of the result are ordered as
  # - Batch dimensions
  #   - Q: In what order?  The order of appearance in lhs, rhs, or
  #     dimension_numbers?
  # - Tensor dimensions from the LHS
  # - Tensor dimensions from the RHS
  lhs_ndim, rhs_ndim = ndims
  # lbd and rbd are "batch" dimensions in the sense of dimensions being
  # vmapped, not to be confused with "batch" dimensions in the sense of
  # explicitly present dimensions that this dot_general is zipping together.
  lbd, rbd = batch_dims
  assert lbd is not None or rbd is not None
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = _from_maybe_ragged(dimension_numbers)

  is_ragged_dot = isinstance(dimension_numbers, RaggedDotDimensionNumbers)

  def bump_dims(dims, b):
    # Shift every dimension index at or after `b` up by one.
    return tuple(np.add(dims, np.greater_equal(dims, b)))

  if type(lbd) is type(rbd) is int:
    # The vmapped dimensions become an additional batch dimension in the
    # batched dot_general, which we arbitrarily put first.
    lhs_batch = (lbd,) + bump_dims(lhs_batch, lbd)
    rhs_batch = (rbd,) + bump_dims(rhs_batch, rbd)
    lhs_contract = bump_dims(lhs_contract, lbd)
    rhs_contract = bump_dims(rhs_contract, rbd)
    result_batch_dim = 0
  elif (type(lbd) is int and rbd is None):
    # The left vmapped dimension becomes an additional tensor dimension in the
    # batched dot_general.
    lhs_tensor = [d for d in range(lhs_ndim)
                  if d not in lhs_batch and d not in lhs_contract]
    result_batch_dim = len(lhs_batch) + int(sum(np.less(lhs_tensor, lbd)))
    lhs_batch = bump_dims(lhs_batch, lbd)
    lhs_contract = bump_dims(lhs_contract, lbd)
  elif (type(rbd) is int and lbd is None):
    # The right vmapped dimension becomes an additional tensor dimension in the
    # batched dot_general.
    rhs_tensor = list(
        remaining(
            range(rhs_ndim),
            rhs_batch,
            rhs_contract,
            dimension_numbers.rhs_group_dimensions if is_ragged_dot else [],
        )
    )
    result_batch_dim = (lhs_ndim - len(lhs_contract) +
                        int(sum(np.less(rhs_tensor, rbd))))
    rhs_batch = bump_dims(rhs_batch, rbd)
    rhs_contract = bump_dims(rhs_contract, rbd)
  else:
    # We wouldn't be here if we didn't have at least one vmapped dimension.
    # Raised explicitly (rather than `assert False`) so that this branch
    # cannot silently fall through when assertions are disabled.
    raise AssertionError

  new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
  if is_ragged_dot:
    # NOTE(review): bump_dims is called with lbd/rbd as-is here; when one of
    # them is None this relies on callers never reaching this path — confirm.
    new_dimension_numbers = RaggedDotDimensionNumbers(
        dot_dimension_numbers=new_dimension_numbers,
        lhs_ragged_dimensions=bump_dims(
            dimension_numbers.lhs_ragged_dimensions, lbd
        ),
        rhs_group_dimensions=bump_dims(
            dimension_numbers.rhs_group_dimensions, rbd
        ),
    )
  return new_dimension_numbers, result_batch_dim
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-03-30 17:52:55 -07:00
|
|
|
|
def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
                              dimension_numbers, **params):
  """Padding rule for ``dot_general``.

  For every lhs contracting dimension whose size in ``lhs_aval.shape`` is a
  ``pe.BoundedAxisSize``, the masked entries of ``lhs`` are replaced with 0
  via ``_replace_masked_values`` before the product is taken. ``rhs`` is
  passed through untouched.

  Returns:
    A one-element list holding the ``dot_general`` result.
  """
  lhs_aval, _ = in_avals
  (lhs_contract, _), _ = dimension_numbers
  padded_axes = [(i, lhs_aval.shape[i].val) for i in lhs_contract
                 if isinstance(lhs_aval.shape[i], pe.BoundedAxisSize)]
  lhs_ = _replace_masked_values(lhs, 0, padded_axes)
  return [dot_general(lhs_, rhs, dimension_numbers=dimension_numbers, **params)]
|
|
|
|
|
|
2023-02-09 11:02:24 -08:00
|
|
|
|
def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
  """Pretty-print a ``dot_general`` equation in a compact form.

  * suppress printing any parameter (e.g. precision or
    preferred_element_type) whose value is None.
  * print dimension_numbers as list-of-lists to be shorter.
  """
  printed_params = {k: v for k, v in eqn.params.items() if v is not None}
  (lhs_cont, rhs_cont), (lhs_batch, rhs_batch) = eqn.params['dimension_numbers']
  printed_params['dimension_numbers'] = (
      (list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch)))
  return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
|
2022-03-30 17:52:55 -07:00
|
|
|
|
|
2024-10-14 14:00:58 -07:00
|
|
|
|
|
2024-10-25 12:06:59 -07:00
|
|
|
|
def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
|
2024-10-14 14:00:58 -07:00
|
|
|
|
assert len(invar_raggedness) == 2
|
|
|
|
|
assert len(outvars) == 1
|
|
|
|
|
invar_raggedness_lhs = invar_raggedness[0]
|
|
|
|
|
invar_raggedness_rhs = invar_raggedness[1]
|
|
|
|
|
|
2024-10-25 12:06:59 -07:00
|
|
|
|
dimension_numbers = eqn_params['dimension_numbers']
|
|
|
|
|
(lhs_contracting, rhs_contracting), (_, _) = dimension_numbers
|
|
|
|
|
|
|
|
|
|
if not invar_raggedness_lhs and not invar_raggedness_rhs:
|
|
|
|
|
# Both are dense - it is valid to reach here, because dense operations
|
|
|
|
|
# are legal in code running under ragged prop.
|
|
|
|
|
return invar_raggedness, [None]
|
|
|
|
|
|
|
|
|
|
if not invar_raggedness_lhs or not invar_raggedness_rhs:
|
|
|
|
|
# One ragged, one dense
|
|
|
|
|
if not invar_raggedness_lhs:
|
|
|
|
|
# left is dense, right is ragged
|
|
|
|
|
_, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs
|
|
|
|
|
if rhs_contracting != ragged_axis_dim_rhs:
|
|
|
|
|
# Contraction is on a dense dimension, this is valid!
|
|
|
|
|
return invar_raggedness, [None]
|
|
|
|
|
if not invar_raggedness_rhs:
|
|
|
|
|
# left is ragged, right is dense
|
|
|
|
|
_, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
|
|
|
|
|
if lhs_contracting != ragged_axis_dim_lhs:
|
|
|
|
|
# Contraction is on a dense dimension, this is valid!
|
|
|
|
|
return invar_raggedness, [None]
|
|
|
|
|
|
|
|
|
|
raise NotImplementedError('NYI - dense and ragged dim contraction')
|
|
|
|
|
|
2024-10-14 14:00:58 -07:00
|
|
|
|
stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
|
|
|
|
|
stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs
|
|
|
|
|
|
|
|
|
|
if stacked_axis_rhs != 0 or stacked_axis_lhs != 0:
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
'Dot general ragged prop for non 0 stacked axis, NYI'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# We only support ragged k atm, that is, lhs is (m, ragged_k) and rhs is
|
|
|
|
|
# (ragged_k, n), meaning the output is dense.
|
|
|
|
|
if ragged_axis_dim_lhs != 2 or ragged_axis_dim_rhs != 1:
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
'Dot general ragged prop for non contraction raggedness, NYI'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert len(outvars) == 1
|
|
|
|
|
|
|
|
|
|
# TODO(mvoz): A constant on batching.* ?
|
|
|
|
|
# Dense (m, n) - no jumble only atm
|
2024-10-25 12:06:59 -07:00
|
|
|
|
return invar_raggedness, [None]
|
2024-10-14 14:00:58 -07:00
|
|
|
|
|
|
|
|
|
|
2024-10-11 16:04:35 -07:00
|
|
|
|
# The dot_general primitive, built from its shape, dtype, sharding and vma
# rules.
dot_general_p = standard_primitive(
    _dot_general_shape_rule,
    _dot_general_dtype_rule,
    'dot_general',
    sharding_rule=_dot_general_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'dot_general')
)
|
2024-10-03 16:41:31 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dot_general_batch_unpack_args(batch_args):
|
|
|
|
|
lhs, rhs = batch_args
|
|
|
|
|
return (lhs, rhs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dot_general_batch_unpack_dims(batch_dims):
|
|
|
|
|
lbd, rbd = batch_dims
|
|
|
|
|
return (lbd, rbd)
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# Register the differentiation, batching, padding, pretty-printing and
# ragged-propagation rules for dot_general_p.
ad.defbilinear(dot_general_p,
               _dot_general_transpose_lhs, _dot_general_transpose_rhs)
_dot_general_batch_rule = functools.partial(
    _dot_batch_rule,
    _dot_general_batch_unpack_args,
    _dot_general_batch_unpack_dims,
    dot_general,
)
batching.fancy_primitive_batchers[dot_general_p] = _dot_general_batch_rule
batching.skippable_batchers[dot_general_p] = lambda _: ()
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule
|
2021-08-03 00:02:59 -04:00
|
|
|
|
|
2025-03-21 17:35:37 -07:00
|
|
|
|
|
|
|
|
|
def _full_precision(precision: Precision) -> tuple[Precision, Precision]:
  """Expand ``precision`` into a pair of ``Precision`` values.

  * ``None`` or a ``DotAlgorithm``/``DotAlgorithmPreset`` maps to
    ``(Precision.DEFAULT, Precision.DEFAULT)``.
  * A single non-tuple value is duplicated.
  * A tuple is returned unchanged.
  """
  if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
    return (Precision.DEFAULT, Precision.DEFAULT)
  if not isinstance(precision, tuple):
    return (precision, precision)
  return precision
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def precision_attr(precision: Precision) -> ir.ArrayAttr:
  """Build an ``ir.ArrayAttr`` of ``hlo.PrecisionAttr`` for ``precision``.

  ``precision`` is first expanded with ``_full_precision``; each element is
  converted with ``str`` before being wrapped.
  """
  return ir.ArrayAttr.get(
      [hlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)]
  )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def chlo_precision_attr(precision: Precision) -> ir.ArrayAttr:
  """Build an ``ir.ArrayAttr`` of ``chlo.PrecisionAttr`` for ``precision``.

  Same as ``precision_attr`` but using the ``chlo`` attribute class.
  """
  return ir.ArrayAttr.get(
      [chlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)]
  )
|
2022-03-17 23:10:46 -07:00
|
|
|
|
|
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
                       rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
  """Return the HLO dot-algorithm attribute for ``precision``, if any.

  Returns ``None`` unless ``precision`` is a ``DotAlgorithm`` or
  ``DotAlgorithmPreset``; otherwise delegates to its
  ``_convert_to_hlo_attr(lhs_dtype, rhs_dtype)``.
  """
  if not isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
    return None
  return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype)
|
2024-09-25 06:16:22 -07:00
|
|
|
|
|
|
|
|
|
|
2024-11-12 05:29:40 -08:00
|
|
|
|
def get_algorithm_compute_types(
    algorithm: DotAlgorithm | DotAlgorithmPreset,
    lhs_dtype: DTypeLike,
    rhs_dtype: DTypeLike,
    out_dtype: DTypeLike | None = None,
) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]:
  """Return the ``(lhs, rhs, output)`` dtypes to use with ``algorithm``.

  Args:
    algorithm: an explicit ``DotAlgorithm`` or a ``DotAlgorithmPreset``.
    lhs_dtype: dtype of the left operand.
    rhs_dtype: dtype of the right operand.
    out_dtype: requested output dtype, or ``None``.

  Returns:
    For a ``DotAlgorithm``: its ``lhs_precision_type``, ``rhs_precision_type``
    and ``accumulation_type``, ignoring the dtype arguments. For a preset:
    each requested dtype if the preset lists it as supported (or lists no
    constraint at all), otherwise the first dtype the preset supports.
  """
  if isinstance(algorithm, DotAlgorithm):
    # A fully specified algorithm already pins all three types.
    return (
        algorithm.lhs_precision_type,
        algorithm.rhs_precision_type,
        algorithm.accumulation_type,
    )

  def _select(requested, allowed):
    # ``allowed is None`` means the preset places no constraint on this dtype.
    if allowed is None:
      return requested
    # NOTE(review): np.dtype(None) is float64, so a ``None`` request is
    # treated as matching whenever float64 is among ``allowed``.
    wanted = np.dtype(requested)
    if any(np.dtype(candidate) == wanted for candidate in allowed):
      return requested
    # Not supported as-is: fall back to the first supported dtype.
    return allowed[0]

  lhs_dtype = _select(lhs_dtype, algorithm.supported_lhs_types)
  rhs_dtype = _select(rhs_dtype, algorithm.supported_rhs_types)
  # The supported output types are queried with the already-adjusted inputs.
  allowed_outputs = algorithm.supported_output_types(lhs_dtype, rhs_dtype)
  out_type = _select(out_dtype, allowed_outputs)
  return lhs_dtype, rhs_dtype, out_type
|
2024-11-12 05:29:40 -08:00
|
|
|
|
|
|
|
|
|
|
2025-03-27 17:12:08 -07:00
|
|
|
|
def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr:
  """Build an ``hlo.ResultAccuracyAttr`` from an accuracy specification.

  Args:
    accuracy: an ``AccuracyMode`` member or a ``Tolerance`` instance.

  Returns:
    For an ``AccuracyMode``: an attribute with zero tolerances whose mode is
    the member's name. For a ``Tolerance``: an attribute in ``'TOLERANCE'``
    mode carrying its ``atol``, ``rtol`` and ``ulps``. Any other value
    (including ``None``) yields ``None``.
  """
  if isinstance(accuracy, AccuracyMode):
    mode_name = str(accuracy.name)
    return hlo.ResultAccuracyAttr.get(0.0, 0.0, 0, mode_name)
  if isinstance(accuracy, Tolerance):
    return hlo.ResultAccuracyAttr.get(
        atol=accuracy.atol,
        rtol=accuracy.rtol,
        ulps=accuracy.ulps,
        mode='TOLERANCE',
    )
  # Neither kind of specification: no attribute is produced.
  return None
|
|
|
|
|
|
2025-03-21 17:35:37 -07:00
|
|
|
|
def _handle_dot_precision(ctx, lhs, rhs, precision, platform):
  """Cast dot operands as required by ``precision`` and ``platform``.

  Args:
    ctx: lowering rule context; ``ctx.avals_in`` and ``ctx.avals_out`` are
      read.
    lhs: lowered left operand.
    rhs: lowered right operand.
    precision: the precision setting; a ``DotAlgorithm`` or
      ``DotAlgorithmPreset`` selects the explicit-algorithm path.
    platform: platform name; ``"cpu"`` and ``"tpu"`` are special-cased.

  Returns:
    ``(lhs, rhs, accumulation_aval, algorithm_kwarg)``: the (possibly
    converted) operands, the aval describing the type the dot op should
    produce, and a dict that is either empty or holds an ``"algorithm"``
    attribute.

  Raises:
    ValueError: on CPU, if an algorithm outside the supported presets is
      requested.
  """

  def _convert(operand, operand_aval, dtype):
    # Emit a conversion of `operand` to `dtype`, keeping its shape.
    target_aval = core.ShapedArray(operand_aval.shape, dtype)
    return mlir.convert_hlo(ctx, operand, operand_aval, target_aval)

  def _both_fp8(first_dtype, second_dtype):
    fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                  dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
    # These attributes may be None, in which case they are skipped.
    optional_fp8 = (dtypes.float8_e3m4, dtypes.float8_e4m3,
                    dtypes.float8_e8m0fnu)
    fp8_dtypes += tuple(dt for dt in optional_fp8 if dt is not None)
    return first_dtype in fp8_dtypes and second_dtype in fp8_dtypes

  def _is_float_or_int(dt):
    return (dtypes.issubdtype(dt, np.floating) or
            dtypes.issubdtype(dt, np.integer))

  # The *_ lets us reuse this for ragged_dot_general, which has group_sizes.
  lhs_aval, rhs_aval, *_ = ctx.avals_in
  lhs_dtype = lhs_aval.dtype
  rhs_dtype = rhs_aval.dtype
  (aval_out,) = ctx.avals_out
  accumulation_aval = aval_out
  algorithm_kwarg = {}

  if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
    # The CPU backend silently ignores the algorithm spec, so we check here to
    # make sure that the selected algorithm is supported. We could be a little
    # bit more liberal here (any algorithm where the input and output types
    # match and all the other parameters have default values should work), but
    # it's probably sufficient to just check the presets here.
    if platform == "cpu":
      cpu_supported = {
          DotAlgorithmPreset.DEFAULT,
          DotAlgorithmPreset.F16_F16_F16,
          DotAlgorithmPreset.F32_F32_F32,
          DotAlgorithmPreset.F64_F64_F64,
          DotAlgorithmPreset.BF16_BF16_F32,
          DotAlgorithmPreset.BF16_BF16_F32_X3,
          DotAlgorithmPreset.BF16_BF16_F32_X6,
      }
      if precision not in cpu_supported:
        raise ValueError(
            f"The precision '{precision}' is not supported by dot_general on CPU")

    # If an explicit algorithm was specified, we always cast the input types to
    # the correct types.
    lhs_dtype, rhs_dtype, accumulation_dtype = get_algorithm_compute_types(
        precision, lhs_dtype, rhs_dtype, aval_out.dtype)
    if not (lhs_dtype is None or lhs_aval.dtype == lhs_dtype):
      lhs = _convert(lhs, lhs_aval, lhs_dtype)
    if not (rhs_dtype is None or rhs_aval.dtype == rhs_dtype):
      rhs = _convert(rhs, rhs_aval, rhs_dtype)
    if accumulation_dtype is not None:
      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_dtype)

    # The DEFAULT preset carries no explicit algorithm attribute.
    if precision != DotAlgorithmPreset.DEFAULT:
      algorithm_kwarg = {
          "algorithm": dot_algorithm_attr(precision, lhs_dtype, rhs_dtype)
      }
  elif lhs_dtype != rhs_dtype:
    # TODO(b/...): JAX's dot_general primitive accepts the same input dtype
    # combinations that are accepted in XLA's shape_inference.cc (the canonical
    # reference for the HLO type system), but actually different XLA platforms
    # fail on codegen for different accepted cases. To handle those cases, we
    # insert ConvertOps on the input, in a platform-dependent way.
    if platform == "tpu":
      cast_to_output = not (_is_float_or_int(lhs_dtype) and
                            _is_float_or_int(rhs_dtype))
    else:  # cpu and gpu
      # Do not convert mixed fp8 types to output type.
      cast_to_output = not _both_fp8(lhs_dtype, rhs_dtype)
    if cast_to_output:
      lhs = _convert(lhs, lhs_aval, aval_out.dtype)
      rhs = _convert(rhs, rhs_aval, aval_out.dtype)

  return lhs, rhs, accumulation_aval, algorithm_kwarg
|
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
|
2025-03-21 17:35:37 -07:00
|
|
|
|
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                       precision, preferred_element_type: np.dtype | None,
                       out_sharding, platform: str = "default"):
  """Lowering rule that emits an ``hlo.dot_general`` op.

  Args:
    ctx: lowering rule context; ``ctx.avals_out`` supplies the output aval.
    lhs: lowered left operand.
    rhs: lowered right operand.
    dimension_numbers: ``((lhs_contracting, rhs_contracting),
      (lhs_batch, rhs_batch))``.
    precision: precision setting, forwarded to ``_handle_dot_precision`` and
      ``precision_attr``.
    preferred_element_type: discarded; the output aval already implies it.
    out_sharding: accepted for the primitive's parameter set; not read in
      this function body.
    platform: platform name forwarded to ``_handle_dot_precision``.

  Returns:
    A one-element list holding the lowered result.
  """
  del preferred_element_type  # Implied by the output aval
  lhs, rhs, accumulation_aval, algorithm_kwarg = _handle_dot_precision(
      ctx, lhs, rhs, precision, platform)

  contracting_dims, batch_dims = dimension_numbers
  lhs_contracting, rhs_contracting = contracting_dims
  lhs_batch, rhs_batch = batch_dims
  dot_dnums = hlo.DotDimensionNumbers.get(
      lhs_batching_dimensions=list(lhs_batch),
      rhs_batching_dimensions=list(rhs_batch),
      lhs_contracting_dimensions=list(lhs_contracting),
      rhs_contracting_dimensions=list(rhs_contracting),
  )

  # The op is emitted with the accumulation type, which may differ from the
  # output aval's dtype when an explicit algorithm was requested.
  accumulation_ir_type = mlir.aval_to_ir_type(accumulation_aval)
  precision_config = precision_attr(precision)
  result = hlo.dot_general(
      accumulation_ir_type,
      lhs,
      rhs,
      dot_dnums,
      precision_config=precision_config,
      **algorithm_kwarg,
  )

  (aval_out,) = ctx.avals_out
  result = mlir.lower_with_sharding_in_types(ctx, result, aval_out)
  if accumulation_aval.dtype != aval_out.dtype:
    # Cast the accumulated result back to the dtype the output aval expects.
    result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
  return [result]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
|
|
|
|
# Registered without a platform argument, so `_dot_general_lower` runs with
# its default `platform="default"`.
mlir.register_lowering(dot_general_p, _dot_general_lower)

# "cpu" and "tpu" are the platform names `_handle_dot_precision` branches on,
# so each gets a rule with its own name bound in.
for platform in ("cpu", "tpu"):
  mlir.register_lowering(
      dot_general_p,
      functools.partial(_dot_general_lower, platform=platform),
      platform=platform,
  )
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
class RaggedDotMode(enum.Enum):
  """Classifies a ragged dot by the role of the lhs ragged dimension.

  The trailing comment on each member sketches example operand shapes as
  `lhs, rhs, group_sizes -> out`.
  """

  # The lhs ragged dimension is a non-contracting dimension.
  RAGGED_NONCONTRACTING = 1  # [b,m,k], [g,b,k,n], [b,g] -> [b,m,n]
  # The lhs ragged dimension is a contracting dimension.
  RAGGED_CONTRACTING = 2  # [b,m,k], [b,k,n], [b,g] -> [g,b,m,n]
  # The lhs ragged dimension is a batch dimension.
  RAGGED_BATCH = 3  # [b,m,k], [b,k,n], [g] -> [b,m,n]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ragged_dot_mode_and_dim(
    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> tuple[RaggedDotMode, int]:
  """Returns the ragged-dot mode together with the single lhs ragged dim.

  The mode is decided by which group of lhs dimensions (non-contracting,
  contracting or batch) contains the ragged dimension.

  Args:
    lhs_rank: number of dimensions of the lhs operand.
    ragged_dot_dimension_numbers: dimension numbers; exactly one lhs ragged
      dimension is expected (checked with an `assert`).

  Returns:
    A `(mode, lhs_ragged_dim)` pair.

  Raises:
    TypeError: if the ragged dimension is in none of the three groups.
  """
  ragged_dims = ragged_dot_dimension_numbers.lhs_ragged_dimensions
  assert len(ragged_dims) == 1
  lhs_ragged_dim = ragged_dims[0]

  dot_dnums = ragged_dot_dimension_numbers.dot_dimension_numbers
  (lhs_contracting, _), (lhs_batch, _) = dot_dnums
  lhs_noncontracting = remaining(range(lhs_rank), lhs_contracting, lhs_batch)

  # Membership is tested in the order: non-contracting, contracting, batch.
  if lhs_ragged_dim in lhs_noncontracting:
    return RaggedDotMode.RAGGED_NONCONTRACTING, lhs_ragged_dim
  if lhs_ragged_dim in lhs_contracting:
    return RaggedDotMode.RAGGED_CONTRACTING, lhs_ragged_dim
  if lhs_ragged_dim in lhs_batch:
    return RaggedDotMode.RAGGED_BATCH, lhs_ragged_dim
  raise TypeError(
      f'lhs_ragged_dim {lhs_ragged_dim} not found in '
      f'lhs_noncontracting {lhs_noncontracting}, '
      f'lhs_contracting {lhs_contracting}, or '
      f'lhs_batch {lhs_batch}.'
  )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ragged_dot_mode(
    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> RaggedDotMode:
  """Returns only the mode computed by `_ragged_dot_mode_and_dim`."""
  mode, _ = _ragged_dot_mode_and_dim(lhs_rank, ragged_dot_dimension_numbers)
  return mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_ragged_contracting(
    lhs_rank: int, ragged_dot_dimension_numbers: RaggedDotDimensionNumbers
) -> bool:
  """Returns whether the ragged dot is in `RAGGED_CONTRACTING` mode."""
  mode = _ragged_dot_mode(lhs_rank, ragged_dot_dimension_numbers)
  return mode == RaggedDotMode.RAGGED_CONTRACTING
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ragged_dot_prefix_dims(mode, rank, ragged_dim, batch, contract):
  """Returns the dimensions that come before `ragged_dim` for `mode`.

  Args:
    mode: a `RaggedDotMode` member.
    rank: number of dimensions of the operand.
    ragged_dim: the ragged dimension of the operand.
    batch: the operand's batch dimensions.
    contract: the operand's contracting dimensions.

  Returns:
    A list of dimension numbers:
      * RAGGED_NONCONTRACTING: all batch dims, then the non-contracting dims
        listed before `ragged_dim`.
      * RAGGED_CONTRACTING: all batch dims, then the contracting dims listed
        before `ragged_dim`.
      * RAGGED_BATCH: the batch dims listed before `ragged_dim`.
    `None` if `mode` equals none of the three members.
  """
  batch_dims = list(batch)
  contract_dims = list(contract)
  noncontract_dims = remaining(range(rank), contract_dims, batch_dims)

  if mode == RaggedDotMode.RAGGED_NONCONTRACTING:
    cutoff = noncontract_dims.index(ragged_dim)
    return batch_dims + noncontract_dims[:cutoff]
  if mode == RaggedDotMode.RAGGED_CONTRACTING:
    cutoff = contract_dims.index(ragged_dim)
    return batch_dims + contract_dims[:cutoff]
  if mode == RaggedDotMode.RAGGED_BATCH:
    cutoff = batch_dims.index(ragged_dim)
    return batch_dims[:cutoff]
  return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ragged_dot_general_shape_rule(
    lhs,
    rhs,
    group_sizes,
    *,
    ragged_dot_dimension_numbers,
    precision,
    preferred_element_type: DTypeLike | None,
    **_,
):
  """Validates the operands of `ragged_dot_general` and computes its shape.

  Checks, in order: the lhs ragged dimension, the shape of `group_sizes`,
  and the rhs group dimension(s). The output shape is then obtained from
  `_dot_general_shape_rule`; in `RAGGED_CONTRACTING` mode the number of
  groups is prepended as a leading dimension.

  Raises:
    TypeError: if any of the validations fails.
  """

  def _require_in_range(dim, rank, dim_name, arg_name):
    # A valid dimension number satisfies 0 <= dim < rank.
    if 0 <= dim < rank:
      return
    raise TypeError(
        f'ragged_dot_general requires {dim_name} numbers to be nonnegative '
        f'and less than the number of axes of the {arg_name} value, '
        f'got {dim} for {arg_name} of rank {rank}.'
    )

  # Validate the lhs ragged dimension, and find out which mode we're in.
  lhs_ragged_dims = ragged_dot_dimension_numbers.lhs_ragged_dimensions
  if len(lhs_ragged_dims) != 1:
    raise TypeError(
        'ragged_dot_general expects exactly one lhs ragged dimension.'
    )
  lhs_ragged_dim = ragged_dot_dimension_numbers.lhs_ragged_dimensions[0]
  _require_in_range(lhs_ragged_dim, lhs.ndim, 'lhs ragged dimension', 'lhs')
  mode = _ragged_dot_mode(lhs.ndim, ragged_dot_dimension_numbers)

  dot_dnums = ragged_dot_dimension_numbers.dot_dimension_numbers
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dot_dnums

  # Validate the shape of group_sizes, if it is something other than [g].
  if group_sizes.ndim == 0:
    raise TypeError('expected rank of group_sizes to be >=1.')
  if group_sizes.ndim != 1:
    # Construct the expected shape [b...,x...,g] of group_sizes.
    prefix_dims = _ragged_dot_prefix_dims(
        mode, lhs.ndim, lhs_ragged_dim, lhs_batch, lhs_contracting
    )
    expected_gs_shape = tuple(lhs.shape[i] for i in prefix_dims)
    expected_gs_shape += (group_sizes.shape[-1],)
    # TODO(pravnar): Permit other broadcastable shapes.
    shapes_agree = core.definitely_equal_shape(
        group_sizes.shape, expected_gs_shape
    )
    if not shapes_agree:
      raise TypeError(
          'expected group_sizes to have shape '
          f'{expected_gs_shape}, got {group_sizes.shape}.'
      )
  num_groups = group_sizes.shape[-1]

  # Validate properties of the rhs group dimension(s).
  rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
  if mode == RaggedDotMode.RAGGED_NONCONTRACTING:
    if len(rhs_group_dims) != 1:
      raise TypeError(
          'ragged_dot_general requires exactly one rhs group dimension '
          'when lhs ragged dimension is noncontracting.'
      )
    rhs_group_dim = rhs_group_dims[0]
    _require_in_range(rhs_group_dim, rhs.ndim, 'rhs group dimension', 'rhs')
    if rhs_group_dim in rhs_batch or rhs_group_dim in rhs_contracting:
      raise TypeError(
          'ragged_dot_general requires rhs group dimension numbers to be '
          'distinct from contracting and batch dimensions.'
      )
    if rhs.shape[rhs_group_dim] != num_groups:
      raise TypeError(
          'expected rhs group dimension size to be '
          f'{num_groups}, got {rhs.shape[rhs_group_dim]}.'
      )
  elif (mode == RaggedDotMode.RAGGED_CONTRACTING
        or mode == RaggedDotMode.RAGGED_BATCH):
    if len(rhs_group_dims) != 0:
      raise TypeError(
          'ragged_dot_general requires zero group dimensions in the rhs '
          'when lhs ragged dimension is contracting or batch.'
      )

  # Note that the full ragged dimension numbers object is forwarded here.
  out_shape = _dot_general_shape_rule(
      lhs,
      rhs,
      dimension_numbers=ragged_dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      out_sharding=None,
  )
  if mode == RaggedDotMode.RAGGED_CONTRACTING:
    # The group axis becomes the leading output dimension.
    out_shape = (num_groups,) + out_shape
  return out_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ragged_dot_general_dtype_rule(
    lhs: Array,
    rhs: Array,
    group_sizes: Array,
    ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
    precision,
    preferred_element_type: DTypeLike | None,
    **_,
) -> np.dtype:
  """Dtype rule for `ragged_dot_general`.

  Requires `group_sizes` to have an integer dtype and otherwise delegates to
  `_dot_general_dtype_rule` using the underlying dot dimension numbers.

  Raises:
    TypeError: if `group_sizes.dtype` is not a subtype of `np.integer`.
  """
  gs_is_integer = dtypes.issubdtype(group_sizes.dtype, np.integer)
  if not gs_is_integer:
    raise TypeError(
        'ragged_dot_general requires that '
        'group_sizes.dtype is subtype of np.integer.'
    )
  # defer the output dtype to dot_general, which is part of the _ragged_dot_general_impl.
  dot_dnums = ragged_dot_dimension_numbers.dot_dimension_numbers
  return _dot_general_dtype_rule(
      lhs,
      rhs,
      dimension_numbers=dot_dnums,
      precision=precision,
      preferred_element_type=preferred_element_type,
      out_sharding=None,
      name='lax.ragged_dot_general',
  )
|
2024-05-11 06:40:18 -07:00
|
|
|
|
|
2024-07-20 17:56:21 -07:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
def _ragged_dot_general_jvp_rule(
    primals, tangents, ragged_dot_dimension_numbers,
    precision, preferred_element_type, group_offset
):
  """JVP rule for `ragged_dot_general`.

  The output tangent is the sum of the ragged dot of `dx` with `y` and the
  ragged dot of `x` with `dy`; a term whose input tangent is a symbolic
  `Zero` is replaced by zeros shaped like the primal output. `group_sizes`
  carries no tangent.

  Returns:
    A `(primal_out, tangent_out)` pair.

  Raises:
    NotImplementedError: if `group_offset` is not None.
  """
  # note - we could ostensibly just get this by passing on the
  # value to ragged_dot below, but, this feels cleaner.
  if group_offset is not None:
    raise NotImplementedError('Unimplemented group_offset support.')
  x, y, gs = primals
  dx, dy, _ = tangents  # no tan on the gs

  # All three products share the same configuration.
  ragged_dot = partial(
      ragged_dot_general,
      ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )

  # primal
  primal_out = ragged_dot(x, y, gs)

  # tangent
  if type(dx) is not ad_util.Zero:
    dx_out = ragged_dot(dx, y, gs)
  else:
    dx_out = _zeros(primal_out)

  if type(dy) is not ad_util.Zero:
    dy_out = ragged_dot(x, dy, gs)
  else:
    dy_out = _zeros(primal_out)

  tangent_out = add(dx_out, dy_out)

  return primal_out, tangent_out
|
|
|
|
|
|
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
def _ragged_dot_general_transpose_rule(
|
|
|
|
|
ct,
|
|
|
|
|
x,
|
|
|
|
|
y,
|
|
|
|
|
group_sizes,
|
|
|
|
|
*,
|
|
|
|
|
ragged_dot_dimension_numbers,
|
|
|
|
|
precision,
|
|
|
|
|
preferred_element_type: DTypeLike | None,
|
|
|
|
|
group_offset: Array | None,
|
2024-07-20 17:56:21 -07:00
|
|
|
|
):
|
|
|
|
|
if group_offset is not None:
|
|
|
|
|
raise NotImplementedError('Unimplemented group_offset support.')
|
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
(x_contract, y_contract), (x_batch, y_batch) = ragged_dot_dimension_numbers.dot_dimension_numbers
|
|
|
|
|
x_ndim = x.aval.ndim if ad.is_undefined_primal(x) else np.ndim(x)
|
|
|
|
|
y_ndim = y.aval.ndim if ad.is_undefined_primal(y) else np.ndim(y)
|
|
|
|
|
x_kept = remaining(range(x_ndim), x_contract, x_batch)
|
|
|
|
|
y_group = ragged_dot_dimension_numbers.rhs_group_dimensions
|
|
|
|
|
y_kept = remaining(range(y_ndim), y_contract, y_batch, y_group)
|
|
|
|
|
mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
|
|
|
|
|
x_ndim, ragged_dot_dimension_numbers
|
|
|
|
|
)
|
2024-07-20 17:56:21 -07:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
unimplemented = lambda fn_name, ragged_dot_mode: NotImplementedError(
|
|
|
|
|
f'Unimplemented {fn_name} for ragged dot general in mode '
|
|
|
|
|
f'{ragged_dot_mode.name}.'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# This is a hack to ensure we continue to emit the `_matrix_transpose` for the
|
|
|
|
|
# grad_x case. This isn't strictly necessary since we have dot_dim_nums.
|
|
|
|
|
# TODO(pravnar): Remove this once we no longer care to emit the transpose.
|
|
|
|
|
_is_basic_ragged_dot = (
|
|
|
|
|
x_ndim == 2
|
|
|
|
|
and y_ndim == 3
|
|
|
|
|
and ragged_dot_dimension_numbers == _BASIC_RAGGED_DOT_DIMENSION_NUMBERS
|
|
|
|
|
)
|
2024-07-20 17:56:21 -07:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
def grad_x_dims():
|
|
|
|
|
match mode:
|
|
|
|
|
case RaggedDotMode.RAGGED_NONCONTRACTING:
|
|
|
|
|
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
|
|
|
|
|
dims = (
|
|
|
|
|
ragged_dot_dimension_numbers
|
|
|
|
|
if _is_basic_ragged_dot
|
|
|
|
|
else RaggedDotDimensionNumbers(
|
|
|
|
|
dot_dimension_numbers=((ans_y, y_kept), (ans_batch, y_batch)),
|
|
|
|
|
lhs_ragged_dimensions=[
|
|
|
|
|
len(x_batch) + x_kept.index(lhs_ragged_dim)
|
|
|
|
|
],
|
|
|
|
|
rhs_group_dimensions=y_group,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
x_contract_sorted_by_y = list(
|
|
|
|
|
np.take(x_contract, np.argsort(y_contract))
|
|
|
|
|
)
|
|
|
|
|
unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
|
|
|
|
|
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
|
|
|
|
|
raise unimplemented('grad_x_dims', mode)
|
|
|
|
|
return dims, unsorted_axes
|
|
|
|
|
|
|
|
|
|
def grad_y_dims():
|
|
|
|
|
match mode:
|
|
|
|
|
case RaggedDotMode.RAGGED_NONCONTRACTING:
|
|
|
|
|
ans_batch, ans_x, _ = ranges_like(x_batch, x_kept, y_kept)
|
|
|
|
|
dims = RaggedDotDimensionNumbers(
|
|
|
|
|
dot_dimension_numbers=((x_kept, ans_x), (x_batch, ans_batch)),
|
|
|
|
|
lhs_ragged_dimensions=[lhs_ragged_dim],
|
|
|
|
|
rhs_group_dimensions=[],
|
|
|
|
|
)
|
|
|
|
|
y_contract_sorted_by_x = list(
|
|
|
|
|
np.take(y_contract, np.argsort(x_contract))
|
|
|
|
|
)
|
|
|
|
|
unsorted_axes = (
|
|
|
|
|
list(y_group) + list(y_batch) + y_contract_sorted_by_x + y_kept
|
|
|
|
|
)
|
|
|
|
|
case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH:
|
|
|
|
|
raise unimplemented('grad_y_dims', mode)
|
|
|
|
|
return dims, unsorted_axes
|
|
|
|
|
|
|
|
|
|
def _ragged_dot_grad(lhs, rhs, dims_fn, aval):
|
|
|
|
|
dims, unsorted_axes = dims_fn()
|
|
|
|
|
ragged_dot_general_out = ragged_dot_general(
|
|
|
|
|
lhs, rhs, group_sizes, dims, precision=precision,
|
|
|
|
|
preferred_element_type=preferred_element_type,
|
|
|
|
|
group_offset=group_offset)
|
|
|
|
|
result = transpose(ragged_dot_general_out, tuple(np.argsort(unsorted_axes)))
|
|
|
|
|
if result.dtype != aval.dtype:
|
|
|
|
|
result = _convert_element_type(result, aval.dtype, aval.weak_type)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
x_bar = (
|
|
|
|
|
None
|
|
|
|
|
if ad.is_undefined_primal(y)
|
|
|
|
|
else _ragged_dot_grad(ct,
|
|
|
|
|
_matrix_transpose(y) if _is_basic_ragged_dot else y,
|
|
|
|
|
grad_x_dims,
|
|
|
|
|
x.aval)
|
|
|
|
|
)
|
|
|
|
|
y_bar = (
|
|
|
|
|
None
|
|
|
|
|
if ad.is_undefined_primal(x)
|
|
|
|
|
else _ragged_dot_grad(x, ct, grad_y_dims, y.aval)
|
|
|
|
|
)
|
|
|
|
|
return x_bar, y_bar, None
|
2024-07-20 17:56:21 -07:00
|
|
|
|
|
|
|
|
|
|
2024-10-03 16:41:31 -07:00
|
|
|
|
def _ragged_dot_batch_unpack_args(batched_args):
  """Returns the `(lhs, rhs)` pair from a 3-element `batched_args`.

  The third element is dropped.
  """
  lhs, rhs, _unused = batched_args
  return lhs, rhs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ragged_dot_batch_unpack_dims(batch_dims):
  """Returns the `(lhs, rhs)` batch dims from a 3-element `batch_dims`.

  Raises:
    NotImplementedError: if any entry of `batch_dims` is not equal to 0.
  """
  if any(not (dim == 0) for dim in batch_dims):
    raise NotImplementedError('ragged_dot vmap over any dim but 0 - NYI')
  lhs_bdim, rhs_bdim, _unused = batch_dims
  return lhs_bdim, rhs_bdim
|
|
|
|
|
|
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
def _ragged_dot_general_invoke_prim(
    group_sizes,
    lhs,
    rhs,
    new_ragged_dot_dimension_numbers,
    precision,
    preferred_element_type,
    out_sharding,
):
  """Calls `ragged_dot_general` with `group_sizes` given as first argument.

  `out_sharding` is accepted and discarded; it is not forwarded.
  """
  del out_sharding
  result = ragged_dot_general(
      lhs,
      rhs,
      group_sizes,
      ragged_dot_dimension_numbers=new_ragged_dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )
  return result
|
|
|
|
|
|
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
def _ragged_dot_general_batch_rule(
    axis_data,
    batched_args,
    batch_dims,
    *,
    ragged_dot_dimension_numbers,
    precision,
    preferred_element_type: DTypeLike | None,
    **_,
):
  """Batching rule for `ragged_dot_general`, built on `_dot_batch_rule`.

  Returns:
    A `(batched_out, result_batch_dim)` pair. In `RAGGED_CONTRACTING` mode
    the batch dimension reported by `_dot_batch_rule` is incremented by one.
  """
  lhs, _, group_sizes = batched_args
  # Bind group_sizes so the callback has the signature of a plain dot.
  invoke = partial(_ragged_dot_general_invoke_prim, group_sizes)
  batched_out, result_batch_dim = _dot_batch_rule(
      _ragged_dot_batch_unpack_args,
      _ragged_dot_batch_unpack_dims,
      invoke,
      axis_data,
      batched_args,
      batch_dims,
      dimension_numbers=ragged_dot_dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      out_sharding=None,
  )
  # The mode is determined with the lhs rank reduced by one.
  unbatched_lhs_rank = lhs.ndim - 1
  if _is_ragged_contracting(unbatched_lhs_rank, ragged_dot_dimension_numbers):
    result_batch_dim += 1
  return batched_out, result_batch_dim
|
|
|
|
|
|
2024-10-03 16:41:31 -07:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
# The `ragged_dot_general` primitive, built from its shape and dtype rules.
ragged_dot_general_p = standard_primitive(
    _ragged_dot_general_shape_rule,
    _ragged_dot_general_dtype_rule,
    'ragged_dot_general',
)
# Differentiation rules.
ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule
ad.primitive_transposes[ragged_dot_general_p] = (
    _ragged_dot_general_transpose_rule
)
# Batching rules; the skippable-batchers entry always yields an empty tuple.
batching.fancy_primitive_batchers[ragged_dot_general_p] = (
    _ragged_dot_general_batch_rule
)
batching.skippable_batchers[ragged_dot_general_p] = lambda _: ()
|
2024-10-03 16:41:31 -07:00
|
|
|
|
|
2024-05-11 06:40:18 -07:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
def _ragged_dot_general_impl(
|
2024-05-11 06:40:18 -07:00
|
|
|
|
lhs: Array,
|
|
|
|
|
rhs: Array,
|
|
|
|
|
group_sizes: Array,
|
2025-03-10 12:24:38 -07:00
|
|
|
|
ragged_dot_dimension_numbers: RaggedDotDimensionNumbers,
|
2024-05-11 06:40:18 -07:00
|
|
|
|
precision: PrecisionLike = None,
|
|
|
|
|
preferred_element_type: DTypeLike | None = None,
|
|
|
|
|
group_offset: Array | None = None,
|
|
|
|
|
) -> Array:
|
|
|
|
|
if group_offset is not None:
|
|
|
|
|
raise NotImplementedError("Unimplemented group_offset support.")
|
2024-10-03 16:41:31 -07:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
def ragged_to_dense(x: Array, gs: Array, *, dim: int):
|
|
|
|
|
from jax._src.lax import control_flow # avoid circular imports
|
|
|
|
|
assert gs.ndim == 1
|
|
|
|
|
shape = gs.shape + x.shape
|
|
|
|
|
x = broadcast_in_dim(x, shape, list(range(1, len(shape))))
|
|
|
|
|
iota = broadcasted_iota(gs.dtype, shape, dim+1)
|
|
|
|
|
group_ends = control_flow.cumsum(gs)
|
|
|
|
|
group_starts = concatenate(
|
|
|
|
|
[_zeros(gs)[:1], group_ends[:-1]],
|
|
|
|
|
dimension=0,
|
|
|
|
|
)
|
|
|
|
|
group_ends = broadcast_in_dim(group_ends, shape, (0,))
|
|
|
|
|
group_starts = broadcast_in_dim(group_starts, shape, (0,))
|
|
|
|
|
mask = bitwise_and(group_starts <= iota, iota < group_ends)
|
|
|
|
|
x = select(mask, x, _zeros(x))
|
|
|
|
|
return x
|
2024-10-03 16:41:31 -07:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
def batched_ragged_to_dense(dim, *x_in_axes: int):
|
|
|
|
|
if not x_in_axes:
|
|
|
|
|
return partial(ragged_to_dense, dim=dim)
|
|
|
|
|
x_axis, *rest = x_in_axes
|
|
|
|
|
decr = lambda d: d - 1 if d >= x_axis else d
|
|
|
|
|
return api.vmap(
|
|
|
|
|
batched_ragged_to_dense(decr(dim), *[decr(ax) for ax in rest]),
|
|
|
|
|
in_axes=(x_axis, 0),
|
|
|
|
|
)
|
2024-10-03 16:41:31 -07:00
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
incr = lambda dims: [d + 1 for d in dims]
|
|
|
|
|
|
|
|
|
|
# Expand the ragged `dim` of `x`, given its batching `axes`.
|
|
|
|
|
# The group axis from `gs` becomes the outermost axis of the result.
|
|
|
|
|
# Some examples:
|
|
|
|
|
# x: [m,k] , gs: [g] ==> expand(x, 0, gs): [g,m,k]
|
|
|
|
|
# x: [b1,m,b2,k], gs: [b1,b2,g] ==> expand(x, 1, gs, 0, 2): [g,b1,m,b2,k]
|
|
|
|
|
def expand(x, dim, gs, *axes):
|
|
|
|
|
expanded = batched_ragged_to_dense(dim, *axes)(x, gs)
|
|
|
|
|
unsorted_dims = incr(axes) + [0] + incr(remaining(range(x.ndim), axes))
|
|
|
|
|
return transpose(expanded, np.argsort(unsorted_dims))
|
|
|
|
|
|
|
|
|
|
mode, lhs_ragged_dim = _ragged_dot_mode_and_dim(
|
|
|
|
|
lhs.ndim, ragged_dot_dimension_numbers
|
|
|
|
|
)
|
|
|
|
|
(l_contract, r_contract), (l_batch, r_batch) = (
|
|
|
|
|
ragged_dot_dimension_numbers.dot_dimension_numbers
|
|
|
|
|
)
|
|
|
|
|
l_prefix = _ragged_dot_prefix_dims(
|
|
|
|
|
mode, lhs.ndim, lhs_ragged_dim, l_batch, l_contract
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
_dot_general = partial(
|
|
|
|
|
dot_general,
|
2024-05-11 06:40:18 -07:00
|
|
|
|
precision=precision,
|
|
|
|
|
preferred_element_type=preferred_element_type,
|
|
|
|
|
)
|
2025-03-10 12:24:38 -07:00
|
|
|
|
# TODO(pravnar): Permit other broadcastable shapes.
|
|
|
|
|
if group_sizes.ndim == 1:
|
|
|
|
|
group_sizes = broadcast(group_sizes, [lhs.shape[i] for i in l_prefix])
|
|
|
|
|
|
|
|
|
|
match mode:
|
|
|
|
|
case RaggedDotMode.RAGGED_NONCONTRACTING:
|
|
|
|
|
rhs_group_dims = ragged_dot_dimension_numbers.rhs_group_dimensions
|
|
|
|
|
assert len(rhs_group_dims) == 1
|
|
|
|
|
return _dot_general(
|
|
|
|
|
expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
|
|
|
|
|
rhs,
|
|
|
|
|
dimension_numbers=(
|
|
|
|
|
(incr(l_contract) + [0], list(r_contract) + [rhs_group_dims[0]]),
|
|
|
|
|
(incr(l_batch), r_batch),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
case RaggedDotMode.RAGGED_CONTRACTING:
|
|
|
|
|
rhs_ragged_dim = r_contract[l_contract.index(lhs_ragged_dim)]
|
|
|
|
|
r_prefix = _ragged_dot_prefix_dims(
|
|
|
|
|
mode, rhs.ndim, rhs_ragged_dim, r_batch, r_contract
|
|
|
|
|
)
|
|
|
|
|
return _dot_general(
|
|
|
|
|
expand(lhs, lhs_ragged_dim, group_sizes, *l_prefix),
|
|
|
|
|
expand(rhs, rhs_ragged_dim, group_sizes, *r_prefix),
|
|
|
|
|
dimension_numbers=(
|
|
|
|
|
(incr(l_contract), incr(r_contract)),
|
|
|
|
|
([0] + incr(l_batch), [0] + incr(r_batch)),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
case RaggedDotMode.RAGGED_BATCH:
|
|
|
|
|
return _dot_general(
|
|
|
|
|
lhs,
|
|
|
|
|
rhs,
|
|
|
|
|
dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers,
|
|
|
|
|
)
|
|
|
|
|
|
2024-05-11 06:40:18 -07:00
|
|
|
|
|
2025-03-21 17:35:37 -07:00
|
|
|
|
def _ragged_dot_general_lower(
    ctx,
    lhs,
    rhs,
    group_sizes,
    *,
    ragged_dot_dimension_numbers,
    precision,
    preferred_element_type: np.dtype | None,
    group_offset: Array | None = None,
    platform: str = 'default',
):
  """Lowers `ragged_dot_general` to a `chlo.ragged_dot` op.

  When the axis context is an `SPMDAxisContext`, or a `ShardingContext`
  with more than one device, the rule instead lowers through
  `_ragged_dot_general_impl`.

  Raises:
    NotImplementedError: if `group_offset` is not None.
  """
  if group_offset is not None:
    raise NotImplementedError('Unimplemented group_offset support.')

  # TODO(pravnar): Remove this once we have sharding support.
  axis_context = ctx.module_context.axis_context
  use_default_lowering = isinstance(axis_context, SPMDAxisContext) or (
      isinstance(axis_context, ShardingContext)
      and axis_context.num_devices > 1
  )
  if use_default_lowering:
    default_rule = mlir.lower_fun(
        _ragged_dot_general_impl, multiple_results=False
    )
    result = default_rule(
        ctx, lhs, rhs, group_sizes,
        ragged_dot_dimension_numbers=ragged_dot_dimension_numbers,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset
    )
    (aval_out,) = ctx.avals_out
    return mlir.lower_with_sharding_in_types(ctx, result, aval_out)

  del preferred_element_type  # Implied by the output aval
  lhs, rhs, accumulation_aval, _ = _handle_dot_precision(
      ctx, lhs, rhs, precision, platform
  )
  dot_dnums = ragged_dot_dimension_numbers.dot_dimension_numbers
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dot_dnums
  lhs_ragged = ragged_dot_dimension_numbers.lhs_ragged_dimensions
  rhs_group = ragged_dot_dimension_numbers.rhs_group_dimensions
  ragged_dot_dnums = chlo.RaggedDotDimensionNumbers.get(
      lhs_batching_dimensions=list(lhs_batch),
      rhs_batching_dimensions=list(rhs_batch),
      lhs_contracting_dimensions=list(lhs_contracting),
      rhs_contracting_dimensions=list(rhs_contracting),
      lhs_ragged_dimensions=list(lhs_ragged),
      rhs_group_dimensions=list(rhs_group),
  )
  # The op is emitted with the accumulation type as its result type.
  accumulation_type = mlir.aval_to_ir_type(accumulation_aval)
  result = chlo.ragged_dot(
      accumulation_type,
      lhs,
      rhs,
      group_sizes,
      ragged_dot_dnums,
      precision_config=chlo_precision_attr(precision),
  )
  (aval_out,) = ctx.avals_out
  result = mlir.lower_with_sharding_in_types(ctx, result, aval_out)
  # Convert to the output dtype when it differs from the accumulation dtype.
  if accumulation_aval.dtype != aval_out.dtype:
    result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
  return [result]
|
|
|
|
|
|
|
|
|
|
|
2025-03-10 12:24:38 -07:00
|
|
|
|
# Without a platform restriction, `ragged_dot_general_p` lowers through
# `_ragged_dot_general_impl`.
mlir.register_lowering(
    ragged_dot_general_p,
    mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False),
)

# The listed platforms use `_ragged_dot_general_lower` instead, with the
# platform name bound as a keyword argument.
for platform in ['tpu']:
  mlir.register_lowering(
      ragged_dot_general_p,
      partial(_ragged_dot_general_lower, platform=platform),
      platform=platform,
  )
|
|
|
|
|
|
2024-05-11 06:40:18 -07:00
|
|
|
|
|
2024-10-25 10:34:33 -07:00
|
|
|
|
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions,
                                 sharding):
  """Validates the arguments of `broadcast_in_dim` and returns `shape`.

  Checks, in order, that:
    * `shape` and `broadcast_dimensions` pass `_check_shapelike`;
    * `broadcast_dimensions` has one entry per operand dimension;
    * `shape` has at least as many dimensions as the operand;
    * every entry of `broadcast_dimensions` indexes into `shape`;
    * every operand dimension size is 1 or equals its target dimension size;
    * `broadcast_dimensions` is strictly increasing.

  Raises:
    TypeError: if one of the checks above fails.
  """
  _check_shapelike('broadcast_in_dim', 'shape', shape)
  _check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
                   broadcast_dimensions)
  operand_ndim = np.ndim(operand)
  out_ndim = len(shape)

  if operand_ndim != len(broadcast_dimensions):
    msg = ('broadcast_in_dim broadcast_dimensions must have length equal to '
           'operand ndim; got broadcast_dimensions {} for operand ndim {}.')
    raise TypeError(msg.format(broadcast_dimensions, operand_ndim))

  if out_ndim < operand_ndim:
    msg = ('broadcast_in_dim target broadcast shape must have equal or higher rank '
           'to the operand shape; got operand ndim {} and target broadcast ndim {}.')
    raise TypeError(msg.format(operand_ndim, out_ndim))

  valid_out_dims = set(range(out_ndim))
  if not set(broadcast_dimensions).issubset(valid_out_dims):
    msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
           'dimensions, got {} for operand ndim {} and shape {}.')
    raise TypeError(msg.format(broadcast_dimensions, operand_ndim, shape))

  def _size_is_compatible(i):
    # Operand dim `i` must be 1 or match the output dim it is mapped to.
    target = shape[broadcast_dimensions[i]]
    return core.definitely_equal_one_of_dim(operand.shape[i], [1, target])

  if not all(_size_is_compatible(i) for i in range(operand_ndim)):
    msg = (
        "broadcast_in_dim operand dimension sizes must either be 1, or be "
        "equal to their corresponding dimensions in the target broadcast "
        "shape; got operand of shape {}, target broadcast shape {}, "
        "broadcast_dimensions {} ")
    printable_operand_shape = tuple(
        core.replace_tracer_for_error_message(d) for d in operand.shape)
    raise TypeError(msg.format(
        printable_operand_shape, shape, broadcast_dimensions))

  # Strictly increasing means: no repeated entries, and already sorted.
  has_duplicates = len(broadcast_dimensions) != len(set(broadcast_dimensions))
  if (has_duplicates or
      tuple(broadcast_dimensions) != tuple(sorted(broadcast_dimensions))):
    msg = ("broadcast_in_dim broadcast_dimensions must be strictly increasing; "
           "got broadcast_dimensions {}")
    raise TypeError(msg.format(broadcast_dimensions))

  return shape
|
|
|
|
|
|
2024-10-25 10:34:33 -07:00
|
|
|
|
def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions,
                                    sharding):
  """Returns the output sharding of `broadcast_in_dim`.

  An explicitly given `sharding` is returned unchanged. Otherwise the
  operand's spec entries are placed, in order, at the output positions
  listed in `broadcast_dimensions`; every other output position gets `None`.
  """
  if sharding is not None:
    return sharding
  mapped_dims = set(broadcast_dimensions)
  spec_entries = iter(operand.sharding.spec)
  new_spec = []
  for out_dim in range(len(shape)):
    if out_dim in mapped_dims:
      new_spec.append(next(spec_entries))
    else:
      new_spec.append(None)
  # The next spec entry, if there is one, must be None.
  assert next(spec_entries, None) is None
  return operand.sharding.with_spec(new_spec)
|
2024-10-17 14:54:26 -07:00
|
|
|
|
|
2022-06-11 15:46:05 -07:00
|
|
|
|
def _broadcast_in_dim_typecheck_rule(
    _, operand, *dyn_shape, shape, broadcast_dimensions, sharding):
  """Typecheck rule for `broadcast_in_dim`.

  Without dynamic shape operands the result comes from the primitive's
  `abstract_eval`. With them, the output aval is a `core.DShapedArray` whose
  shape merges `shape` with `dyn_shape`, and no effects are reported.

  Returns:
    A `([out_aval], effects)` pair.
  """
  if dyn_shape:
    # TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule
    merged_shape = _merge_dyn_shape(shape, dyn_shape)
    # Literals are replaced by the value they wrap.
    merged_shape = [d.val if type(d) is core.Literal else d for d in merged_shape]  # pytype: disable=attribute-error
    out_aval = core.DShapedArray(tuple(merged_shape), operand.aval.dtype,
                                 operand.aval.weak_type)
    return [out_aval], core.no_effects

  out_aval, effects = broadcast_in_dim_p.abstract_eval(
      operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions,
      sharding=sharding)
  return [out_aval], effects
|
|
|
|
|
|
2022-06-23 15:29:46 -07:00
|
|
|
|
def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
                                     shape, broadcast_dimensions, sharding):
  """Transpose rule for `broadcast_in_dim`.

  The cotangent is summed over every output axis that is either not listed
  in `broadcast_dimensions` or is mapped from a size-1 operand dimension;
  the size-1 operand dimensions are then re-inserted with `expand_dims`.

  Returns:
    A list with the operand cotangent followed by one `None` per entry of
    `dyn_shape`. The list has the same length on both return paths,
    including when `ct` is a symbolic `Zero`.
  """
  # The dynamic shape operands never receive a cotangent.
  dyn_shape_cts = [None] * len(dyn_shape)
  if type(ct) is ad_util.Zero:
    return [ad_util.Zero(operand.aval)] + dyn_shape_cts
  # Operand dimensions of size 1 are broadcast, so their output axes must be
  # reduced as well.
  unit_dims = [i for i, s in enumerate(operand.aval.shape)
               if core.definitely_equal(s, 1)]
  bdims = tuple(np.delete(broadcast_dimensions, unit_dims))
  axes = tuple(np.delete(range(len(shape)), bdims))
  return [expand_dims(reduce_sum(ct, axes), unit_dims)] + dyn_shape_cts
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-01-29 09:33:44 -08:00
|
|
|
|
def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape,
                                 broadcast_dimensions, sharding):
  """Batching rule for ``broadcast_in_dim_p``.

  The operand's batch axis (if any) is moved to position 0 and
  ``broadcast_dimensions`` is shifted up by one to make room for it. The
  result is broadcast to ``(stacked_size,) + merged_shape``. Batched entries
  of ``dyn_shape`` are collected as ragged sizes for the matching output axes.

  Returns:
    A pair ``(result, out_bdim)`` where ``out_bdim`` comes from
    ``batching.make_batch_axis`` with the batch axis at position 0.
  """
  # `dyn_shape` is the dynamic portion of the target shape. `shape`
  # is the target shape, with `None` for dynamic sections.
  # broadcast_dimensions gives indices where dimensions of the input
  # have to go: dimension i of the input becomes dimension
  # broadcast_dimensions[i] of the output.
  operand, *dyn_shape = batched_args
  operand_bdim, *dyn_shape_bdims = batch_dims

  stacked_size = None
  if operand_bdim is not None:
    if isinstance(operand_bdim, RaggedAxis):
      stacked_axis = operand_bdim.stacked_axis
      stacked_size = operand_bdim.size
    else:
      stacked_axis = operand_bdim
      stacked_size = operand.shape[stacked_axis]
    new_operand = batching.moveaxis(operand, stacked_axis, 0)
    # The batch axis maps to output axis 0; all other targets shift by one.
    new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
  else:
    new_operand = operand
    new_broadcast_dimensions = tuple(np.add(1, broadcast_dimensions))

  # TODO(mattjj,axch) This section assumes that the shape of the operand is
  # broadcast-compatible with the requested shape. We should tweak vmap to run
  # the abstract_eval rule so this can be checked while the raggedness
  # information is available.
  dyn_limits = []
  out_ragged_sizes = []
  for sizes, bdim in zip(dyn_shape, dyn_shape_bdims):
    if bdim is None:
      # TODO(mattjj,axch) Is this what bdim == None means?
      assert isinstance(sizes, int)
      bound = sizes
    else:
      # presumably `sizes` has a bounded-int dtype exposing `.bound` -- TODO confirm
      bound = sizes.dtype.bound
      out_ragged_sizes.append(sizes)
      if stacked_size is None:
        stacked_size = len(sizes)
      else:
        msg = "All segments lengths arrays must be the same length"
        assert len(sizes) == stacked_size, msg
    dyn_limits.append(bound)
  new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits)

  if sharding is not None:
    sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)

  result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions,
                            out_sharding=sharding)
  # Output axes that were dynamic (`None` in `shape`), offset by the batch axis.
  out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
  out_bdim = batching.make_batch_axis(
      result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes))
  return result, out_bdim
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-11-15 21:21:29 -08:00
|
|
|
|
def _broadcast_in_dim_fwd_rule(eqn):
|
2022-01-20 22:58:09 -08:00
|
|
|
|
v, *dyn = eqn.invars
|
2023-06-30 12:31:47 +03:00
|
|
|
|
if not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape):
|
2021-11-15 21:21:29 -08:00
|
|
|
|
return [v], None
|
|
|
|
|
else:
|
|
|
|
|
return [None], eqn
|
|
|
|
|
|
2022-01-20 22:58:09 -08:00
|
|
|
|
def _broadcast_in_dim_staging_rule(
    trace, x, *dyn, shape, broadcast_dimensions, sharding):
  """Staging rule for ``broadcast_in_dim_p``.

  With no dynamic shape operands the primitive is staged through
  ``trace.default_process_primitive``. Otherwise the output aval is a
  ``core.DShapedArray`` over the merged shape and staging is delegated to
  ``_dyn_shape_staging_rule``.
  """
  params = {
      'shape': shape,
      'broadcast_dimensions': broadcast_dimensions,
      'sharding': sharding,
  }
  if dyn:
    out_aval = core.DShapedArray(
        _merge_dyn_shape(shape, dyn), x.dtype, x.weak_type)
    return _dyn_shape_staging_rule(
        trace, broadcast_in_dim_p, out_aval, x, *dyn, **params)
  return trace.default_process_primitive(broadcast_in_dim_p, (x,), params)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-03-30 17:52:55 -07:00
|
|
|
|
def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
                                   shape, broadcast_dimensions, sharding=None):
  """Padding rule for ``broadcast_in_dim_p``.

  Rebuilds the target shape from the output aval: bounded axis sizes are
  replaced by their bound, ints are kept, and tracers become dynamic (``None``
  in the static shape plus an extra operand).

  Args:
    in_avals: unused.
    out_avals: one-element sequence holding the output aval.
    x: the operand to broadcast.
    *dyn_shape: unused; the dynamic sizes are re-derived from the output aval.
    shape: unused directly; kept for the primitive's parameter signature.
    broadcast_dimensions: forwarded to the primitive.
    sharding: forwarded to the primitive. Defaults to ``None`` so existing
      callers that do not pass it keep working.

  Returns:
    A one-element list with the result of re-binding ``broadcast_in_dim_p``.
  """
  del in_avals, dyn_shape
  out_aval, = out_avals
  new_shape = []
  new_dyn_shape = []
  for d in out_aval.shape:
    if type(d) is pe.BoundedAxisSize:
      new_shape.append(d.bound)
    elif type(d) is int:
      new_shape.append(d)
    else:
      assert isinstance(d, core.Tracer)
      new_shape.append(None)
      new_dyn_shape.append(d)
  # `sharding` is a required keyword of the primitive's abstract-eval rule,
  # so it has to be accepted here and passed through like in sibling rules.
  return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape),
                                  broadcast_dimensions=broadcast_dimensions,
                                  sharding=sharding)]
|
|
|
|
|
|
2024-10-25 10:34:33 -07:00
|
|
|
|
def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions,
                               sharding):
  """JVP rule for ``broadcast_in_dim_p``.

  The primal output is the broadcast of the operand. The tangent output is
  the same broadcast applied to the operand tangent, or a symbolic zero built
  from the primal output when the operand tangent is ``ad_util.Zero``.

  Returns:
    A pair ``(y, y_dot)``.
  """
  operand, *dyn_shape = primals
  operand_dot, *_ = tangents
  bind_broadcast = partial(
      broadcast_in_dim_p.bind, shape=shape,
      broadcast_dimensions=broadcast_dimensions, sharding=sharding)
  y = bind_broadcast(operand, *dyn_shape)
  if type(operand_dot) is ad_util.Zero:
    return y, ad_util.Zero.from_primal_value(y)
  return y, bind_broadcast(operand_dot, *dyn_shape)
|
2022-06-17 15:53:53 -07:00
|
|
|
|
|
|
|
|
|
def _broadcast_in_dim_partial_eval(
    trace, operand, *dyn_shape, shape, broadcast_dimensions, sharding):
  """Custom partial-evaluation rule for ``broadcast_in_dim_p``.

  With no dynamic shape operands this defers to
  ``trace.default_process_primitive``. Otherwise all dynamic shape tracers
  must be known; they and the operand are instantiated as constants, and an
  unknown output tracer with a ``core.DShapedArray`` aval is created and
  given an equation recipe for the primitive.

  Returns:
    The result of the default rule, or the new unknown output tracer.
  """
  if not dyn_shape:
    return trace.default_process_primitive(
        broadcast_in_dim_p, (operand, *dyn_shape),
        dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
             sharding=sharding))
  assert all(t.pval.is_known() for t in dyn_shape)
  operand_tracer = trace.instantiate_const(operand)
  # NOTE(review): the result is consumed twice below (iter(...) and
  # unpacking), so `map` presumably returns a list here -- confirm against the
  # module-level definition of `map`.
  dyn_shape_tracers = map(trace.instantiate_const, dyn_shape)
  dyn_shape_tracers_ = iter(dyn_shape_tracers)
  # Fill each dynamic (`None`) slot of `shape` with the next dynamic tracer.
  shape_ = [next(dyn_shape_tracers_) if d is None else d for d in shape]
  out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type)
  out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
  # NOTE(review): the recipe records sharding=None rather than the `sharding`
  # argument -- confirm this is intentional for the dynamic-shape path.
  eqn = pe.new_eqn_recipe(
      [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
      dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
           sharding=None),
      core.no_effects, source_info_util.current())
  out_tracer.recipe = eqn
  return out_tracer
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-10-25 10:34:33 -07:00
|
|
|
|
def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
                            sharding) -> Sequence[ir.Value]:
  """MLIR lowering rule for ``broadcast_in_dim_p``.

  When dynamic shape operands are present, the output aval's shape is
  replaced by the merged static/dynamic shape before lowering.

  Returns:
    A one-element list with the lowered value, passed through
    ``mlir.lower_with_sharding_in_types``.
  """
  (out_aval,) = ctx.avals_out
  if dyn_shape:
    merged_shape = _merge_dyn_shape(shape, dyn_shape)
    out_aval = out_aval.update(shape=merged_shape)
  lowered = mlir.broadcast_in_dim(
      ctx, x, out_aval, broadcast_dimensions=broadcast_dimensions)
  return [mlir.lower_with_sharding_in_types(ctx, lowered, out_aval)]
|
2022-06-17 15:53:53 -07:00
|
|
|
|
|
2024-10-25 10:34:33 -07:00
|
|
|
|
def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
                                    sharding):
  """Abstract evaluation rule for ``broadcast_in_dim_p``.

  Returns a ``core.DShapedArray`` when there are dynamic shape operands or a
  ``core.DArray`` with ``bint`` dtype in ``shape``. Otherwise returns a
  ``core.ShapedArray`` whose shape, sharding and vma come from the
  corresponding rules.
  """
  # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
  # (even if x is a ShapedArray)
  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
  if dyn_shape or any(
      isinstance(d, core.DArray) and type(core.get_aval(d).dtype) is core.bint
      for d in shape):
    return core.DShapedArray(
        _merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type)

  out_shape = _broadcast_in_dim_shape_rule(  # error checking
      x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None)
  out_sharding = _broadcast_in_dim_sharding_rule(
      x, shape=out_shape, broadcast_dimensions=broadcast_dimensions,
      sharding=sharding)
  if config.varying_axes_in_types.value:
    out_vma = core.standard_vma_rule('broadcast_in_dim', x)
  else:
    out_vma = frozenset()
  return core.ShapedArray(out_shape, x.dtype, x.weak_type,
                          sharding=out_sharding, vma=out_vma)
|
2022-06-29 13:55:30 -07:00
|
|
|
|
|
2024-10-14 14:00:58 -07:00
|
|
|
|
|
2024-10-25 12:06:59 -07:00
|
|
|
|
def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
|
2024-10-14 14:00:58 -07:00
|
|
|
|
assert len(invar_raggedness) == 1
|
|
|
|
|
assert not isinstance(invar_raggedness[0], core.Var)
|
|
|
|
|
return invar_raggedness, [None] * len(outvars)
|
|
|
|
|
|
|
|
|
|
|
2022-06-17 15:53:53 -07:00
|
|
|
|
# Create the primitive, install its custom abstract-eval rule, and register
# the rules defined in this section in the corresponding tables.
broadcast_in_dim_p = standard_primitive(
    _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule
ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule
batching.fancy_primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
batching.skippable_batchers[broadcast_in_dim_p] = lambda _: ()
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
pe.custom_partial_eval_rules[broadcast_in_dim_p] = _broadcast_in_dim_partial_eval
pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule
mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)
batching.ragged_prop_rules[broadcast_in_dim_p] = (
    _broadcast_in_dim_ragged_prop_rule
)
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _clamp_shape_rule(min, operand, max):
|
|
|
|
|
if min.shape and min.shape != operand.shape:
|
2021-11-11 06:36:31 -08:00
|
|
|
|
raise TypeError("clamp requires min.shape == operand.shape or min.shape == "
|
2022-12-01 09:12:01 -08:00
|
|
|
|
f"(), got min.shape={min.shape}, {operand.shape=}.")
|
2020-10-17 14:33:26 -04:00
|
|
|
|
if max.shape and max.shape != operand.shape:
|
2021-11-11 06:36:31 -08:00
|
|
|
|
raise TypeError("clamp requires max.shape == operand.shape or max.shape == "
|
2022-12-01 09:12:01 -08:00
|
|
|
|
f"(), got max.shape={max.shape}, {operand.shape=}.")
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return operand.shape
|
|
|
|
|
|
2025-01-27 21:57:36 -08:00
|
|
|
|
def _clamp_sharding_rule(min, operand, max):
|
|
|
|
|
return operand.sharding
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# Dtype rule for `clamp`: `naryop_dtype_rule` specialised for three operands
# (min, operand, max), each accepting `_any`, with `_input_dtype` as the
# result-dtype function.
_clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any],
                            'clamp')
|
|
|
|
|
|
2021-06-20 19:09:13 -07:00
|
|
|
|
def _clamp_batch_rule(batched_args, batch_dims, **params):
  """Batching rule for ``clamp_p``.

  Fast paths bind ``clamp_p`` directly (keeping ``x``'s batch dim) when all
  three batch dims agree and shapes/ranks line up, or when both bounds are
  scalars and ``x`` is batched. Otherwise every non-scalar argument has its
  batch dim moved to the front and the arguments are broadcast to compatible
  shapes; the output batch dim is then 0.

  Returns:
    A pair ``(result, out_batch_dim)``.
  """
  min, x, max = batched_args
  min_bdim, x_bdim, max_bdim = batch_dims
  # Batch size, read from the first argument that is actually batched.
  size = next(x.shape[i] for x, i in zip(batched_args, batch_dims)
              if i is not None)

  # avoid transposes and some broadcasts in special cases
  if min_bdim == x_bdim == max_bdim:
    if np.shape(min) == np.shape(x) == np.shape(max):
      return clamp_p.bind(min, x, max), x_bdim
    elif np.ndim(min) == np.ndim(max) == 0:
      return clamp_p.bind(min, x, max), x_bdim
    elif np.ndim(min) == np.ndim(max) == 1:
      # Rank-1 bounds hold only the batch axis; expand them along it.
      min = broadcast_in_dim(min, x.shape, [min_bdim])
      max = broadcast_in_dim(max, x.shape, [max_bdim])
      return clamp_p.bind(min, x, max), x_bdim
  elif np.ndim(min) == 0 and np.ndim(max) == 0 and x_bdim is not None:
    return clamp_p.bind(min, x, max), x_bdim

  # General case: move batch dims to the front (scalars are left alone).
  min = batching.bdim_at_front(min, min_bdim, size) if np.shape(min) else min
  max = batching.bdim_at_front(max, max_bdim, size) if np.shape(max) else max
  x = batching.bdim_at_front(x, x_bdim, size) if np.shape(x) else x
  if np.ndim(min) == 0 and np.ndim(x) > 0:
    min = broadcast(min, x.shape)
  if np.ndim(max) == 0 and np.ndim(x) > 0:
    max = broadcast(max, x.shape)
  if 0 < np.ndim(min) < np.ndim(x):
    assert np.ndim(min) == 1, np.ndim(min)
    min = broadcast_in_dim(min, x.shape, [0])
  if 0 < np.ndim(max) < np.ndim(x):
    assert np.ndim(max) == 1, np.ndim(max)
    max = broadcast_in_dim(max, x.shape, [0])
  if np.ndim(min) > np.ndim(x):
    assert np.ndim(x) == 0, np.ndim(x)
    x = broadcast(x, min.shape)
  return clamp_p.bind(min, x, max), 0
|
|
|
|
|
|
2025-01-27 21:57:36 -08:00
|
|
|
|
# Create the `clamp` primitive and register its JVP, batching, lowering and
# padding rules.
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp',
                             sharding_rule=_clamp_sharding_rule,
                             vma_rule=partial(core.standard_vma_rule, 'clamp'))
# One JVP per argument, in the order (min, operand, max). Each selects the
# incoming tangent `g` where its condition holds and zeros elsewhere.
ad.defjvp(clamp_p,
          # min: selected where min > operand and min < max.
          lambda g, min, operand, max:
          select(bitwise_and(gt(min, operand), lt(min, max)),
                 g, _zeros(operand)),
          # operand: selected where min < operand < max.
          lambda g, min, operand, max:
          select(bitwise_and(gt(operand, min), lt(operand, max)),
                 g, _zeros(operand)),
          # max: selected where max < operand.
          lambda g, min, operand, max:
          select(lt(max, operand), g, _zeros(operand)))
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
mlir.register_lowering(clamp_p, partial(_nary_lower_hlo, hlo.clamp))
pe.def_trivial_padding(clamp_p)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _concatenate_shape_rule(*operands, **kwargs):
  """Shape rule for ``concatenate``.

  Validates that there is at least one operand, that all operands are
  ``UnshapedArray`` instances of equal rank, that ``dimension`` is in bounds,
  and that the shapes agree in every dimension except ``dimension``.

  Raises:
    TypeError: if any of the checks above fails.

  Returns:
    The first operand's shape with the size along ``dimension`` replaced by
    the sum of all operands' sizes along it.
  """
  dimension = kwargs.pop('dimension')
  if not operands:
    raise TypeError("concatenate expects at least one operand, got 0.")
  for candidate in operands:
    if not isinstance(candidate, UnshapedArray):
      msg = "All objects to concatenate must be arrays, got {}."
      raise TypeError(msg.format(type(candidate)))
  ranks = {o.ndim for o in operands}
  if len(ranks) != 1:
    msg = "Cannot concatenate arrays with different numbers of dimensions: got {}."
    raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))
  if not 0 <= dimension < operands[0].ndim:
    msg = "concatenate dimension out of bounds: dimension {} for shapes {}."
    raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands])))

  # Shapes with the concatenated dimension removed must all coincide.
  other_dims = [o.shape[:dimension] + o.shape[dimension+1:] for o in operands]
  if other_dims[:-1] != other_dims[1:]:
    msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
           "other than the one being concatenated: concatenating along "
           "dimension {} for shapes {}.")
    full_shapes = [o.shape for o in operands]
    raise TypeError(msg.format(dimension, ", ".join(map(str, full_shapes))))

  total = sum(o.shape[dimension] for o in operands)
  first_shape = operands[0].shape
  return first_shape[:dimension] + (total,) + first_shape[dimension+1:]
|
|
|
|
|
|
2024-11-20 20:12:01 -08:00
|
|
|
|
def _concatenate_sharding_rule(*operands, **kwargs):
  """Sharding rule for ``concatenate``.

  Only operands whose sharding has a non-empty mesh are considered. With none,
  the current mesh sharding is returned; otherwise all considered shardings
  must be equal and the common one is returned.

  Raises:
    core.ShardingTypeError: if the considered shardings are not all equal.
  """
  candidates = [o.sharding for o in operands if not o.sharding.mesh.empty]
  if not candidates:
    return core.get_cur_mesh_sharding()
  reference = candidates[0]
  for s in candidates:
    if not (s == reference):
      ss = ", ".join(str(o.sharding) for o in operands)
      raise core.ShardingTypeError(
          f"All operands should have the same sharding. Got shardings {ss}")
  return reference
|
2024-11-20 20:12:01 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _concatenate_dtype_rule(*operands, **kwargs):
  """Dtype rule for ``concatenate``: check dtypes, return the first one."""
  check_same_dtypes('concatenate', *operands)
  first = operands[0]
  return first.dtype
|
|
|
|
|
|
|
|
|
|
def _concatenate_transpose_rule(t, *operands, dimension):
  """Transpose rule for ``concatenate``.

  For a ``Zero`` cotangent, returns a ``Zero`` for each undefined-primal
  operand and ``None`` for the others. Otherwise the cotangent is split along
  ``dimension`` into pieces sized like the operands.
  """
  operand_shapes = []
  for o in operands:
    operand_shapes.append(
        o.aval.shape if ad.is_undefined_primal(o) else o.shape)
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
            for o in operands]
  piece_sizes = tuple(s[dimension] for s in operand_shapes)
  return split(t, piece_sizes, axis=dimension)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
  """Batching rule for ``concatenate``.

  Batched operands have their batch axis moved to the front. Unbatched
  operands are broadcast with a new leading axis of the batch size, using the
  sharding spec entry taken from the first batched operand's batch axis. The
  concatenation then runs along ``dimension + 1``.

  Returns:
    A pair ``(result, 0)``.
  """
  # The first batched operand supplies both the batch size and the spec entry.
  ref_op, ref_bdim = next(
      (op, bdim) for op, bdim in zip(batched_args, batch_dims)
      if bdim is not None)
  size = ref_op.shape[ref_bdim]
  spec = core.get_aval(ref_op).sharding.spec[ref_bdim]

  operands = []
  for op, bdim in zip(batched_args, batch_dims):
    if bdim is not None:
      operands.append(batching.moveaxis(op, bdim, 0))
      continue
    op_sharding = core.get_aval(op).sharding
    out_sharding = op_sharding.with_spec((spec, *op_sharding.spec))
    operands.append(broadcast(op, (size,), out_sharding=out_sharding))
  return concatenate(operands, dimension + 1), 0
|
|
|
|
|
|
2022-10-10 18:51:04 -07:00
|
|
|
|
def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension):
  """Padding rule for ``concatenate``.

  Only supported when every input aval has an integer size along
  ``dimension``; in that case the operands are concatenated as-is.

  Raises:
    NotImplementedError: if any input size along ``dimension`` is not an
      ``int`` or ``np.integer``.
  """
  for aval in in_avals:
    if not isinstance(aval.shape[dimension], (int, np.integer)):
      raise NotImplementedError  # TODO(mattjj)
  return [concatenate(operands, dimension)]
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# Create the `concatenate` primitive and register its AD, batching and
# padding rules.
concatenate_p = standard_primitive(
    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
    sharding_rule=_concatenate_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'concatenate'))
ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
# NOTE(review): this direct assignment follows `ad.deflinear2` for the same
# primitive, so the raw rule is what ends up in the transpose table --
# confirm that overriding whatever `deflinear2` registered is intended.
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
pe.padding_rules[concatenate_p] = _concatenate_pad_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _concatenate_lower(ctx, *xs, dimension):
  """MLIR lowering rule for ``concatenate`` via ``hlo.concatenate``.

  Returns:
    A one-element list with the lowered value, passed through
    ``mlir.lower_with_sharding_in_types``.
  """
  (out_aval,) = ctx.avals_out
  lowered = hlo.concatenate(xs, mlir.i64_attr(dimension))
  return [mlir.lower_with_sharding_in_types(ctx, lowered, out_aval)]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
# Register the MLIR lowering for `concatenate`.
mlir.register_lowering(concatenate_p, _concatenate_lower)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-12-17 10:05:58 -08:00
|
|
|
|
def _split_shape_rule(operand, *, sizes, axis):
|
|
|
|
|
shapes = []
|
|
|
|
|
shape = list(operand.shape)
|
|
|
|
|
if any(s < 0 for s in sizes):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Sizes passed to split must be nonnegative, got {list(sizes)}")
|
|
|
|
|
if operand.shape[axis] != np.sum(sizes):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the "
|
|
|
|
|
f"operand shape {list(operand.shape)}")
|
|
|
|
|
for size in sizes:
|
|
|
|
|
shape[axis] = size
|
|
|
|
|
shapes.append(tuple(shape))
|
|
|
|
|
return shapes
|
|
|
|
|
|
|
|
|
|
def _split_dtype_rule(operand, *, sizes, axis):
|
|
|
|
|
return (operand.dtype,) * len(sizes)
|
|
|
|
|
|
|
|
|
|
def _split_weak_type_rule(operand, *, sizes, axis):
|
|
|
|
|
return (operand.weak_type,) * len(sizes)
|
|
|
|
|
|
|
|
|
|
def _split_transpose_rule(cotangents, operand, *, sizes, axis):
  """Transpose rule for ``split``: concatenate the cotangents along ``axis``.

  If every cotangent is a symbolic ``Zero``, a single ``Zero`` for the operand
  is returned. Otherwise ``Zero`` cotangents are materialised with ``_zeros``
  before concatenation.

  Returns:
    A one-element tuple holding the operand cotangent.
  """
  assert ad.is_undefined_primal(operand)
  if all(type(ct) is ad_util.Zero for ct in cotangents):
    return (ad_util.Zero(operand.aval),)
  pieces = []
  for ct in cotangents:
    pieces.append(_zeros(ct.aval) if type(ct) is ad_util.Zero else ct)
  return (concatenate(pieces, dimension=axis),)
|
|
|
|
|
|
|
|
|
|
def _split_batch_rule(batched_args, batch_dims, *, sizes, axis):
  """Batching rule for ``split``.

  The split axis is shifted by one when it sits at or after the batch
  dimension. Every output keeps the operand's batch dimension.

  Returns:
    A pair ``(outputs, out_batch_dims)``.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  if axis >= bdim:
    batched_axis = axis + 1
  else:
    batched_axis = axis
  outputs = split(operand, sizes=sizes, axis=batched_axis)
  return outputs, (bdim,) * len(sizes)
|
|
|
|
|
|
|
|
|
|
def _split_lower(ctx, x, *, sizes, axis):
  """MLIR lowering rule for ``split``: emit one slice per output aval.

  The slices are consecutive along ``axis``; each starts where the previous
  one ended and has the extent given by its output aval.

  Returns:
    A list of lowered values, one per entry of ``ctx.avals_out``.
  """
  (in_aval,) = ctx.avals_in
  rank = in_aval.ndim
  lo = [0] * rank
  hi = list(in_aval.shape)
  unit_strides = (1,) * rank
  lowered = []
  for out_aval in ctx.avals_out:
    hi[axis] = lo[axis] + out_aval.shape[axis]
    piece = mlir.slice_op(ctx, x, out_aval, start_indices=lo,
                          limit_indices=hi, strides=unit_strides)
    lowered.append(mlir.lower_with_sharding_in_types(ctx, piece, out_aval))
    # The next slice begins at the end of this one.
    lo[axis] = hi[axis]
  return lowered
|
|
|
|
|
|
|
|
|
|
def _split_sharding_rule(operand, *, sizes, axis):
  """Sharding rule for ``split``: one sharding per output shape."""
  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
  # change this logic to `return operand.sharding` directly.
  shardings = []
  for out_shape in _split_shape_rule(operand, sizes=sizes, axis=axis):
    shardings.append(
        slicing._get_sharding_for_varying_out_shape(out_shape, operand, 'split'))
  return shardings
|
|
|
|
|
|
2025-03-25 17:02:45 -07:00
|
|
|
|
def _split_vma_rule(operand, *, sizes, axis):
  """VMA rule for ``split``: each output gets the operand's standard vma."""
  vma = core.standard_vma_rule('split', operand)
  n_outputs = len(_split_shape_rule(operand, sizes=sizes, axis=axis))
  return [vma] * n_outputs
|
|
|
|
|
|
2024-12-17 10:05:58 -08:00
|
|
|
|
# `split` produces several outputs, so it is built as a plain `core.Primitive`
# with `multiple_results` set and a multi-result abstract-eval assembled from
# the per-output shape/dtype/weak-type/sharding/vma rules.
split_p = core.Primitive('split')
split_p.multiple_results = True
split_p.def_abstract_eval(
    partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule,
            _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule,
            _split_vma_rule))
split_p.def_impl(partial(dispatch.apply_primitive, split_p))
ad.deflinear2(split_p, _split_transpose_rule)
batching.primitive_batchers[split_p] = _split_batch_rule
mlir.register_lowering(split_p, _split_lower)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _pad_dtype_rule(operand, padding_value, *, padding_config):
|
|
|
|
|
if operand.dtype != padding_value.dtype:
|
|
|
|
|
msg = "pad operand and padding_value must be same dtype: got {} and {}."
|
|
|
|
|
raise TypeError(msg.format(operand.dtype, padding_value.dtype))
|
|
|
|
|
|
|
|
|
|
return _input_dtype(operand, padding_value)
|
|
|
|
|
|
|
|
|
|
def _pad_shape_rule(operand, padding_value, *, padding_config):
  """Shape rule for `pad`: validates the arguments and returns the padded shape.

  `padding_config` holds one `(low, high, interior)` triple per operand axis.
  Each output dimension is `low + high + core.dilate_dim(dim, interior + 1)`.

  Raises:
    ValueError: if `padding_value` is not a scalar, if `padding_config` does
      not have one entry per operand axis, if any interior padding is
      negative, or if any resulting dimension is negative.
  """
  if np.ndim(padding_value) != 0:
    raise ValueError(f"padding_value must be a scalar; got {np.shape(padding_value)=}")

  op_shape = np.shape(operand)
  if len(padding_config) != np.ndim(operand):
    raise ValueError("length of padding_config must equal the number of axes "
                     f"of operand, got padding_config {padding_config} "
                     f"for operand shape {op_shape}")

  interiors_ok = all(interior >= 0 for _, _, interior in padding_config)
  if not interiors_ok:
    raise ValueError("interior padding in padding_config must be nonnegative, "
                     f"got padding_config {padding_config}")

  padded_dims = []
  for (low, high, interior), dim in zip(padding_config, op_shape):
    padded_dims.append(low + high + core.dilate_dim(dim, interior + 1))
  result = tuple(padded_dims)

  # Negative low/high padding can shrink a dimension below zero.
  sizes_ok = all(d >= 0 for d in result)
  if not sizes_ok:
    raise ValueError(
        f"Dimension size after padding is not at least 0, "
        f"got result shape {result}, for padding_config {padding_config}"
        f" and operand shape {op_shape}")
  return result
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-11-20 17:12:29 -08:00
|
|
|
|
def _pad_sharding_rule(operand, padding_value, *, padding_config):
  """Sharding rule for `pad`.

  Computes the padded shape with `_pad_shape_rule` and derives the output
  sharding from the operand via
  `slicing._get_sharding_for_varying_out_shape`.
  """
  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
  # change this logic to `return operand.sharding` directly.
  padded_shape = _pad_shape_rule(
      operand, padding_value, padding_config=padding_config)
  out_sharding = slicing._get_sharding_for_varying_out_shape(
      padded_shape, operand, 'padding')
  return out_sharding
|
2024-11-20 17:12:29 -08:00
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _pad_transpose(t, operand, padding_value, *, padding_config):
  """Transpose rule for `pad`.

  Args:
    t: cotangent of the padded output, or a symbolic `ad_util.Zero`.
    operand: the padded array, possibly an undefined primal.
    padding_value: the scalar fill value, possibly an undefined primal.
    padding_config: sequence of `(low, high, interior)` triples.

  Returns:
    `[t_operand, t_padv]`; an entry is `None` when the corresponding input is
    not an undefined primal.
  """
  operand_is_linear = ad.is_undefined_primal(operand)
  padv_is_linear = ad.is_undefined_primal(padding_value)

  if type(t) is ad_util.Zero:
    t_operand = ad_util.Zero(operand.aval) if operand_is_linear else None
    t_padv = ad_util.Zero(padding_value.aval) if padv_is_linear else None
    return [t_operand, t_padv]

  lo, hi, interior = util.unzip3(padding_config)
  total = lambda x: reduce_sum(x, list(range(t.ndim)))

  def t_op():
    # Undo the edge padding by padding with the negated low/high amounts,
    # then step over the interior padding with a strided slice.
    unpad_config = safe_zip(np.negative(lo), np.negative(hi),
                            np.zeros_like(interior))
    unpadded = pad(t, np.array(0., t.dtype), unpad_config)
    return slicing.slice(unpadded, np.zeros_like(lo), unpadded.shape,
                         np.add(interior, 1))

  # The part of `t` that lands on operand positions is needed both for the
  # operand cotangent and for the padding-value cotangent (which is the sum of
  # `t` minus the sum over operand positions). Compute it whenever either
  # input needs it; previously `total(None)` was evaluated when only
  # `padding_value` was an undefined primal.
  t_inner = t_op() if (operand_is_linear or padv_is_linear) else None
  t_operand = t_inner if operand_is_linear else None
  t_padv = sub(total(t), total(t_inner)) if padv_is_linear else None
  return [t_operand, t_padv]
|
|
|
|
|
|
|
|
|
|
def _pad_batch_rule(batched_args, batch_dims, *, padding_config):
  """Batching rule for `pad`.

  The batch axis gets a `(0, 0, 0)` entry in the padding config so it is left
  unpadded. When the padding value itself is batched, the operand is padded
  with zeros, a boolean mask is padded with `False`, and `select` fills the
  padded positions with the per-example padding value.
  """
  operand, padding_value = batched_args
  operand_bdim, padding_value_bdim = batch_dims

  if operand_bdim is None:
    # Only the padding value is batched: give the operand a leading batch axis.
    batch_size = padding_value.shape[padding_value_bdim]
    operand = broadcast(operand, (batch_size,))
    operand_bdim = 0

  batched_config = list(padding_config)
  batched_config.insert(operand_bdim, (0, 0, 0))

  if padding_value_bdim is None:
    padded = pad(operand, padding_value, batched_config)
    return padded, operand_bdim

  assert padding_value_bdim == 0, padding_value_bdim

  zero_padded = pad(operand, _zero(operand), batched_config)
  all_true = full_like(operand, True, np.bool_)
  keep_mask = pad(all_true, False, batched_config)
  per_example_padding = broadcast_in_dim(
      padding_value, zero_padded.shape, (operand_bdim,))
  return select(keep_mask, zero_padded, per_example_padding), operand_bdim
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-11-20 17:12:29 -08:00
|
|
|
|
# Definition of the `pad` primitive from its shape/dtype rules, together with
# its sharding and vma rules.
pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad',
                           sharding_rule=_pad_sharding_rule,
                           vma_rule=partial(core.standard_vma_rule, 'pad'))
# Registered as linear, with `_pad_transpose` as its transpose.
ad.deflinear2(pad_p, _pad_transpose)
batching.primitive_batchers[pad_p] = _pad_batch_rule
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _pad_lower(ctx, x, padding_value, *, padding_config):
  """MLIR lowering for `pad`.

  Splits `padding_config` into its low/high/interior components, emits the pad
  through `mlir.pad`, and passes the result through
  `mlir.lower_with_sharding_in_types` with the output aval.
  """
  (out_aval,) = ctx.avals_out
  low, high, interior = util.unzip3(padding_config)
  padded = mlir.pad(ctx, out_aval, x, padding_value, low, high, interior)
  lowered = mlir.lower_with_sharding_in_types(ctx, padded, out_aval)
  return [lowered]
|
2023-01-04 15:47:36 +02:00
|
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
|
# Register `_pad_lower` as the MLIR lowering rule for `pad_p`.
mlir.register_lowering(pad_p, _pad_lower)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# The squeeze primitive exists for the benefit of masking and other
|
|
|
|
|
# transformations that need to keep track of axis identity.
|
|
|
|
|
# For example, consider reshaping a 2D array with shape (1, N) into a 1D array
|
|
|
|
|
# with shape (N,). This results in the following JAXpr:
|
|
|
|
|
# reshape[ dimension=None new_sizes=(N,) ]
|
|
|
|
|
# For N > 1, we can match up the output array axis with the second axis of the
|
|
|
|
|
# input. But for N = 1, it is not clear how axes match up: all we know from the
|
|
|
|
|
# JAXpr is that we are reshaping from (1, 1) to (1,).
|
2023-09-22 14:54:31 -07:00
|
|
|
|
# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _squeeze_dtype_rule(operand, *, dimensions):
|
|
|
|
|
return operand.dtype
|
|
|
|
|
|
|
|
|
|
def _squeeze_shape_rule(operand, *, dimensions):
  """Shape rule for `squeeze`: delegates to `_compute_squeeze_shape`."""
  operand_shape = np.shape(operand)
  return _compute_squeeze_shape(operand_shape, dimensions)
|
|
|
|
|
|
2024-11-20 14:29:59 -08:00
|
|
|
|
def _squeeze_sharding_rule(operand, *, dimensions):
|
|
|
|
|
dims_set = set(dimensions)
|
|
|
|
|
new_spec = tuple(s for i, s in enumerate(operand.sharding.spec)
|
|
|
|
|
if i not in dims_set)
|
2024-11-22 11:01:20 -08:00
|
|
|
|
return operand.sharding.with_spec(new_spec)
|
2024-11-20 14:29:59 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _compute_squeeze_shape(shape, dimensions):
|
|
|
|
|
dims_set = set(dimensions)
|
|
|
|
|
if len(dims_set) != len(dimensions):
|
|
|
|
|
raise ValueError(f"dimensions are not unique: {dimensions}")
|
|
|
|
|
if not all(0 <= d < len(shape) for d in dims_set):
|
|
|
|
|
raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")
|
2023-06-30 12:31:47 +03:00
|
|
|
|
if any(not core.definitely_equal(shape[d], 1) for d in dimensions):
|
2020-10-17 14:33:26 -04:00
|
|
|
|
raise ValueError(
|
|
|
|
|
"cannot select an axis to squeeze out which has size not equal to "
|
2022-12-01 09:12:01 -08:00
|
|
|
|
f"one, got {shape=} and {dimensions=}")
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return tuple(s for i, s in enumerate(shape) if i not in dims_set)
|
|
|
|
|
|
|
|
|
|
def _squeeze_transpose_rule(t, operand, *, dimensions):
  """Transpose rule for `squeeze`: re-inserts the squeezed axes into `t`."""
  assert ad.is_undefined_primal(operand)
  operand_cotangent = expand_dims(t, dimensions)
  return [operand_cotangent]
|
|
|
|
|
|
|
|
|
|
def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for `squeeze`.

  The stacked axis is moved to position 0 with `batching.move_stacked_axis`,
  so every squeezed dimension index is shifted up by one before squeezing.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  operand, bdim = batching.move_stacked_axis(operand, bdim, 0)
  dimensions = tuple(np.add(1, dimensions))

  if isinstance(bdim, RaggedAxis):
    out_stack_dim = bdim.stacked_axis
  else:
    out_stack_dim = bdim

  batched_shape = batching.bdim_as_shape(bdim, operand.shape)
  squeezed_shape = _compute_squeeze_shape(batched_shape, dimensions)
  bdim_out = batching.shape_as_bdim(out_stack_dim, squeezed_shape)
  return squeeze(operand, dimensions=dimensions), bdim_out
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# Definition of the `squeeze` primitive from its shape/dtype rules, together
# with its sharding and vma rules.
squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
                               'squeeze', sharding_rule=_squeeze_sharding_rule,
                               vma_rule=partial(core.standard_vma_rule, 'squeeze'))
# Registered as linear, with `_squeeze_transpose_rule` as its transpose.
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
pe.def_trivial_padding(squeeze_p)
batching.ragged_prop_rules[squeeze_p] = batching.ragged_mask_no_op_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _squeeze_lower(ctx, operand, *, dimensions):
  """MLIR lowering for `squeeze`: a reshape to the output aval.

  The result is passed through `mlir.lower_with_sharding_in_types` with the
  output aval before being returned.
  """
  del dimensions  # Implied by the output aval.
  (out_aval,) = ctx.avals_out
  reshaped = mlir.reshape(ctx, operand, out_aval)
  lowered = mlir.lower_with_sharding_in_types(ctx, reshaped, out_aval)
  return [lowered]
|
[jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.
For native serialization we will support two lowering implementations:
* one is using the growing support in JAX for dynamic shapes,
of which shape polymorphism is a special case.
This implementation is enabled with the --jax_dynamic_shapes flag.
At the moment, the JAX dynamic shapes support is still
incomplete and over 300 jax2tf shape polymorphism tests fail.
* a new one (added) here in which we form a Jaxpr using abstract
values that express dimension sizes as dimension polynomials
(as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
This implementation is enabled when --jax_dynamic_shapes is off.
With this implementation only 50 jax2tf tests fail (to be fixed
separately).
The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.
The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.
Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.
The key code pattern used in the lowering rule is::
if not core.is_constant_shape(shape): # Handles both Var, and polynomials
shape = mlir.eval_dynamic_shape(ctx, shape)
return mhlo.DynamicXXX(..., shape)
else:
return mhlo.XXX(..., shape)
with `mlir.eval_dynamic_shape` handling both cases::
def eval_dynamic_shape(ctx, shape):
if config.jax_dynamic_shapes:
# Using Var
return ... subst using ctx.axis_size_env ...
else:
# Using polynomials
return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values
In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.
I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-11-28 13:16:07 +01:00
|
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
|
# Register `_squeeze_lower` as the MLIR lowering rule for `squeeze_p`.
mlir.register_lowering(squeeze_p, _squeeze_lower)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-04-28 06:01:22 -07:00
|
|
|
|
def shape_as_value(shape: core.Shape):
  """Converts a shape that may contain Poly values into a JAX value.

  An empty shape yields an empty int64 array and a constant shape yields a
  NumPy int64 array. Otherwise each dimension is converted with
  `core.dimension_as_value`, cast to int64, given a leading axis, and the
  pieces are concatenated along axis 0.
  """
  if len(shape) == 0:
    return full((0,), np.array(0, np.int64))
  if core.is_constant_shape(shape):
    return np.asarray(shape, dtype=np.int64)

  dim_values = []
  for d in shape:
    as_int64 = convert_element_type(core.dimension_as_value(d), np.int64)
    dim_values.append(expand_dims(as_int64, (0,)))
  return concatenate(dim_values, dimension=0)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-11-25 18:14:30 -08:00
|
|
|
|
def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding):
  """Shape rule for `reshape`: validates the request and returns the new shape.

  Raises:
    TypeError: if any entry of `new_sizes` is negative (zero is accepted, even
      though the message says "positive"), if the total size changes while
      `config.dynamic_shapes` is off, or if `dimensions` is given and is not a
      permutation of the operand's axes.
  """
  sizes_ok = all(d >= 0 for d in new_sizes)
  if not sizes_ok:
    raise TypeError(
        f'reshape new_sizes must all be positive, got {new_sizes}.')

  # TODO(necula): re-enable this check
  operand_shape = np.shape(operand)
  operand_size = math.prod(operand_shape)
  new_size = math.prod(new_sizes)
  if not config.dynamic_shapes.value:
    if not operand_size == new_size:
      raise TypeError(
          f"reshape total size must be unchanged, got new_sizes {new_sizes} "
          f"(of total size {new_size}) for shape {np.shape(operand)} "
          f"(of total size {operand_size}).")

  if dimensions is not None:
    expected_axes = set(range(np.ndim(operand)))
    if set(dimensions) != expected_axes:
      raise TypeError(
          'reshape dimensions must be a permutation of operand dimensions, '
          f'got dimensions {dimensions} for shape {np.shape(operand)}.')
  return tuple(new_sizes)
|
|
|
|
|
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
def _split_on_one_axis(op_shape, new_sizes, name):
|
|
|
|
|
if len(new_sizes) <= len(op_shape):
|
|
|
|
|
return False, []
|
|
|
|
|
i, j, count, out = 0, 0, 0, []
|
|
|
|
|
while j < len(new_sizes):
|
|
|
|
|
if op_shape[i] == new_sizes[j]:
|
|
|
|
|
out.append(op_shape[i])
|
|
|
|
|
else:
|
|
|
|
|
count += 1
|
|
|
|
|
if count > 1:
|
2025-03-21 09:25:11 -07:00
|
|
|
|
raise core.ShardingTypeError(
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
f'{name} on more than 1 axis is not supported. Please specify'
|
|
|
|
|
' the sharding of the output via the `sharding` argument of'
|
|
|
|
|
f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}')
|
|
|
|
|
temp = [new_sizes[j]]
|
|
|
|
|
while math.prod(temp) != op_shape[i]:
|
2025-01-17 06:57:57 -08:00
|
|
|
|
if math.prod(temp) > op_shape[i]:
|
|
|
|
|
return False, []
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
j += 1
|
|
|
|
|
temp.append(new_sizes[j])
|
|
|
|
|
out.append(temp)
|
|
|
|
|
i += 1
|
|
|
|
|
j += 1
|
|
|
|
|
assert len(op_shape) == len(out)
|
|
|
|
|
return True, out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _merge_on_one_axis(operand, new_sizes):
  """Checks whether `new_sizes` merges axes of `operand` into a single axis.

  A merge is treated as a split seen in reverse: `new_sizes` plays the role of
  the shape being split and `operand.shape` the role of the split result.
  """
  operand_shape = operand.shape
  if len(new_sizes) < len(operand_shape):
    return _split_on_one_axis(new_sizes, operand_shape, 'Merging')
  return False, []
|
|
|
|
|
|
|
|
|
|
|
2024-11-25 18:14:30 -08:00
|
|
|
|
def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding):
|
|
|
|
|
if sharding is not None:
|
|
|
|
|
return sharding
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
non_1s_op_shape = [s for s in operand.shape if s != 1]
|
|
|
|
|
non_1s_new_shape = [s for s in new_sizes if s != 1]
|
|
|
|
|
if non_1s_op_shape == non_1s_new_shape:
|
|
|
|
|
return _split_merge_singleton_dim_sharding_rule(operand, new_sizes)
|
|
|
|
|
|
|
|
|
|
is_split, out_split = _split_on_one_axis(operand.shape, new_sizes, 'Splitting')
|
|
|
|
|
if is_split:
|
2025-02-23 21:38:40 -08:00
|
|
|
|
return _split_an_axis_sharding_rule(operand, out_split, new_sizes,
|
|
|
|
|
dimensions)
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
|
|
|
|
|
is_merge, operand_merge = _merge_on_one_axis(operand, new_sizes)
|
|
|
|
|
if is_merge:
|
2025-02-23 21:38:40 -08:00
|
|
|
|
return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes,
|
|
|
|
|
dimensions)
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
|
2025-03-21 09:25:11 -07:00
|
|
|
|
raise core.ShardingTypeError(
|
2025-02-23 21:38:40 -08:00
|
|
|
|
'This reshape is not supported. Please specify the sharding of'
|
|
|
|
|
' the output via the `out_sharding` argument of jax.lax.reshape. Got'
|
|
|
|
|
f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
|
|
|
|
|
f' operand spec: {operand.sharding.spec}')
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
|
|
|
|
|
def _split_merge_singleton_dim_sharding_rule(operand, new_sizes):
|
|
|
|
|
filtered_spec = [sp for sh, sp in zip(operand.shape, operand.sharding.spec)
|
|
|
|
|
if sh != 1]
|
2024-10-25 10:34:33 -07:00
|
|
|
|
fs = iter(filtered_spec)
|
|
|
|
|
new_spec = []
|
|
|
|
|
for n in new_sizes:
|
|
|
|
|
if n == 1:
|
|
|
|
|
new_spec.append(None)
|
|
|
|
|
else:
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
sp = next(fs)
|
|
|
|
|
new_spec.append(sp)
|
|
|
|
|
return operand.sharding.with_spec(new_spec)
|
|
|
|
|
|
2025-02-23 21:38:40 -08:00
|
|
|
|
def _get_spec_size(sp, mesh):
|
|
|
|
|
tup_sp = sp if isinstance(sp, tuple) else (sp,)
|
|
|
|
|
return math.prod(mesh.shape[t] for t in tup_sp)
|
|
|
|
|
|
|
|
|
|
def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions):
  """Sharding for a reshape that splits exactly one operand axis.

  `out_split` has one entry per operand axis; the entry for the split axis is
  a list of the sizes it is split into. Unsplit axes keep their spec entry.
  For the split axis, an unsharded axis stays unsharded in every piece; a
  sharded axis keeps its spec on the first piece only, provided `dimensions`
  is None and the first piece's size is divisible by the spec's mesh size.

  Raises:
    core.ShardingTypeError: if the split axis is sharded and the conditions
      above do not hold.
  """
  mesh = operand.sharding.mesh
  new_spec = []
  for group, entry in safe_zip(out_split, operand.sharding.spec):
    if not isinstance(group, list):
      new_spec.append(entry)
      continue
    if entry is None:
      new_spec.extend([None] * len(group))
      continue
    if dimensions is None and group[0] % _get_spec_size(entry, mesh) == 0:
      new_spec.append(entry)
      new_spec.extend([None] * (len(group) - 1))
      continue
    raise core.ShardingTypeError(
        'This reshape is not supported. Please specify the sharding of the'
        ' output via the `sharding` argument of jax.lax.reshape. Got'
        f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
        f' operand spec: {operand.sharding.spec}')
  assert len(new_spec) == len(new_sizes), (new_spec, new_sizes)
  return operand.sharding.with_spec(new_spec)
|
2024-10-25 10:34:33 -07:00
|
|
|
|
|
[sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.
`operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`
* Merging into 1 dimension only and all the merging dimensions should be unsharded.
`operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`
* Split into singleton dimensions i.e. adding extra dims of size 1
`operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`
* Merge singleton dimensions i.e. removing extra dims of size 1
`operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`
* Identity reshape
`operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`
These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.
PiperOrigin-RevId: 716216240
2025-01-16 06:46:35 -08:00
|
|
|
|
|
2025-02-23 21:38:40 -08:00
|
|
|
|
def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions):
  """Derive the output sharding of a reshape that merges operand dimensions.

  Walks `new_sizes` and `operand_merge` in lockstep while consuming the
  operand's sharding spec one entry per operand dimension. An `operand_merge`
  entry that is a `list` stands for a group of operand dimensions merged into
  one output dimension; any other entry consumes exactly one spec entry, which
  is forwarded unchanged.

  A merged group is accepted when every dimension in it is unsharded (output
  dimension is unsharded), or when only the leading dimension is sharded and
  `dimensions` is None (output dimension inherits the leading sharding).

  Raises:
    core.ShardingTypeError: for any other merged-group sharding pattern.
  """
  out_spec = []
  mesh = operand.sharding.mesh
  remaining = iter(operand.sharding.spec)
  for new_size, group in zip(new_sizes, operand_merge):
    if not isinstance(group, list):
      # Dimension is carried over as-is; keep its sharding.
      out_spec.append(next(remaining))
      continue

    # One spec entry per operand dimension participating in the merge.
    merged = [next(remaining) for _ in group]
    if all(s is None for s in merged):
      out_spec.append(None)
      continue

    only_leading_sharded = (merged[0] is not None and
                            all(s is None for s in merged[1:]))
    if only_leading_sharded and dimensions is None:
      # The merged size must still be divisible by the leading shard count.
      assert new_size % _get_spec_size(merged[0], mesh) == 0
      out_spec.append(merged[0])
      continue

    raise core.ShardingTypeError(
        'This reshape is not supported. Please specify the sharding of the'
        ' output via the `sharding` argument of jax.lax.reshape. Got'
        f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
        f' operand spec: {operand.sharding.spec}')

  # Every operand spec entry must have been consumed exactly once.
  assert next(remaining, None) is None
  assert len(out_spec) == len(new_sizes), (out_spec, new_sizes)
  return operand.sharding.with_spec(out_spec)
|
|
|
|
|
|
|
|
|
|
|
2024-11-25 18:14:30 -08:00
|
|
|
|
def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions,
                            sharding):
  """Typecheck rule for `reshape_p`; returns `([out_aval], effects)`.

  With no dynamic shape operands the result comes from
  `reshape_p.abstract_eval`. Otherwise the output shape is assembled from
  `new_sizes` and `dyn_shape` and wrapped in a `core.DShapedArray`.
  """
  if dyn_shape:
    # TODO(mattjj, necula): perform more checks like _reshape_shape_rule
    merged = _merge_dyn_shape(new_sizes, dyn_shape)
    # Literals are replaced by the value they carry.
    concrete = tuple(d.val if type(d) is core.Literal else d for d in merged)  # pytype: disable=attribute-error
    dyn_aval = core.DShapedArray(concrete, operand.aval.dtype,
                                 operand.aval.weak_type)
    return [dyn_aval], core.no_effects

  static_aval, effects = reshape_p.abstract_eval(
      operand.aval, new_sizes=new_sizes, dimensions=dimensions,
      sharding=sharding)
  return [static_aval], effects
|
|
|
|
|
|
|
|
|
|
|
2024-11-25 18:14:30 -08:00
|
|
|
|
def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding):
  """Dtype rule for reshape: the result has the operand's dtype."""
  del new_sizes, dimensions, sharding  # Unused: reshape never changes dtype.
  result_dtype = operand.dtype
  return result_dtype
|
|
|
|
|
|
2024-11-25 18:14:30 -08:00
|
|
|
|
def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
  """Transpose (linear) rule for reshape.

  Args:
    t: cotangent of the reshape output.
    operand: the undefined primal whose aval gives the target shape/sharding.
    new_sizes, dimensions, sharding: the reshape primitive's parameters.

  Returns:
    A one-element list holding the cotangent for `operand`.
  """
  assert ad.is_undefined_primal(operand)
  operand_aval = operand.aval
  if dimensions is None:
    return [reshape(t, operand_aval.shape, out_sharding=operand_aval.sharding)]

  # Permute the spec by direct indexing (as `_transpose_sharding_rule` does)
  # rather than routing it through `np.take`: converting the spec to a NumPy
  # array mangles tuple-valued entries and required a `str()` round-trip that
  # would stringify any non-None entry.
  op_spec = operand_aval.sharding.spec
  t_s = operand_aval.sharding.with_spec(tuple(op_spec[d] for d in dimensions))
  # Undo the forward pass: reshape back to the permuted operand shape, then
  # apply the inverse permutation.
  return [transpose(reshape(t, np.take(operand_aval.shape, dimensions),
                            out_sharding=t_s),
                    np.argsort(dimensions))]
|
|
|
|
|
|
2025-01-29 09:33:44 -08:00
|
|
|
|
def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes,
                        dimensions, sharding):
  """Batching rule for reshape.

  Moves the batch axis to the front, prepends it to `new_sizes` (and to
  `dimensions`, shifting the others by one), and reshapes. The result is
  batched along axis 0.
  """
  [batched_operand] = batched_args
  [batch_axis] = batch_dims
  fronted = batching.moveaxis(batched_operand, batch_axis, 0)

  # Axis 0 is now the batch axis, so every original axis index shifts by one.
  batched_dimensions = (
      None if dimensions is None else (0,) + tuple(np.add(1, dimensions)))
  batched_sharding = (
      None if sharding is None
      else batching.get_sharding_for_vmap(axis_data, sharding, 0))

  batched_sizes = fronted.shape[:1] + new_sizes
  result = reshape(fronted, batched_sizes, batched_dimensions,
                   out_sharding=batched_sharding)
  return result, 0
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
2024-11-25 18:14:30 -08:00
|
|
|
|
def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
  """MLIR lowering rule for reshape.

  Applies the optional `dimensions` transpose first, then reshapes to the
  output aval (with its shape rebuilt from `new_sizes` and `dyn_shape` when
  dynamic shape operands are present).
  """
  [result_aval] = ctx.avals_out
  transposed = x
  if dimensions is not None:
    transposed = hlo.transpose(x, mlir.dense_int_array(dimensions))
  if dyn_shape:
    result_aval = result_aval.update(
        shape=_merge_dyn_shape(new_sizes, dyn_shape))
  reshaped = mlir.reshape(ctx, transposed, result_aval)
  return [mlir.lower_with_sharding_in_types(ctx, reshaped, result_aval)]
|
2021-11-16 11:17:42 +02:00
|
|
|
|
|
|
|
|
|
def _reshape_staging_rule(
    trace, x, *dyn, new_sizes, dimensions, sharding):
  """Custom staging rule for reshape.

  Without dynamic shape operands this defers to the trace's default primitive
  processing; otherwise it builds a `core.DShapedArray` output aval and
  delegates to `_dyn_shape_staging_rule`.
  """
  params = {'new_sizes': new_sizes, 'dimensions': dimensions,
            'sharding': sharding}
  if dyn:
    dyn_aval = core.DShapedArray(
        _merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type)
    return _dyn_shape_staging_rule(trace, reshape_p, dyn_aval, x, *dyn,
                                   **params)
  return trace.default_process_primitive(reshape_p, (x,), params)
|
2021-11-16 11:17:42 +02:00
|
|
|
|
|
2022-06-29 13:55:30 -07:00
|
|
|
|
# The `reshape` primitive, built from its shape, dtype, sharding and vma rules.
reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule,
                               'reshape', sharding_rule=_reshape_sharding_rule,
                               vma_rule=partial(core.standard_vma_rule, 'reshape'))
# Per-transformation rules for `reshape_p`: transpose, batching, lowering,
# typechecking and staging.
ad.deflinear2(reshape_p, _reshape_transpose_rule)
batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule
# NOTE(review): the callback returns an empty tuple for any argument;
# presumably this lets batching skip the rule when no axis is involved —
# confirm against `batching.skippable_batchers`.
batching.skippable_batchers[reshape_p] = lambda _: ()
mlir.register_lowering(reshape_p, _reshape_lower)
core.custom_typechecks[reshape_p] = _reshape_typecheck_rule
pe.custom_staging_rules[reshape_p] = _reshape_staging_rule
|
|
|
|
|
|
2022-06-29 13:55:30 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _rev_shape_rule(operand, *, dimensions):
  """Shape rule for rev: validates `dimensions` and returns `operand.shape`.

  Raises:
    TypeError: if `dimensions` contains duplicates or an entry that is not
      less than `operand.ndim`.
  """
  _check_shapelike('rev', 'dimensions', dimensions)
  has_duplicates = len(set(dimensions)) != len(dimensions)
  if has_duplicates:
    raise TypeError('rev dimensions must be unique, got {}.'.format(dimensions))
  if dimensions and not _max(dimensions) < operand.ndim:
    template = ('rev dimensions must all be less than operand ndim, got dimensions '
                '{} for operand ndim {}.')
    raise TypeError(template.format(dimensions, operand.ndim))
  # Reversal never changes the shape.
  return operand.shape
|
|
|
|
|
|
2025-01-27 20:29:25 -08:00
|
|
|
|
def _rev_sharding_rule(operand, *, dimensions):
  """Sharding rule for rev: the output keeps the operand's sharding."""
  del dimensions  # Unused.
  # TODO(yashkatariya): Will lead to data movement. Maybe just error out and
  # require the operand to be unsharded?
  out_sharding = operand.sharding
  return out_sharding
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for rev: reverse in place, leaving the batch axis put."""
  [batched_operand] = batched_args
  [batch_axis] = batch_dims
  # Axes at or after the batch axis are shifted by one in the batched operand.
  shifted = []
  for dim in dimensions:
    shifted.append(dim if dim < batch_axis else dim + 1)
  return rev(batched_operand, shifted), batch_axis
|
|
|
|
|
|
2025-01-27 20:29:25 -08:00
|
|
|
|
# The `rev` primitive; its dtype rule is the shared `_input_dtype`.
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev',
                           sharding_rule=_rev_sharding_rule,
                           vma_rule=partial(core.standard_vma_rule, 'rev'))
# The transpose of a reversal is the same reversal applied to the cotangent.
ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _rev_lower(ctx, x, *, dimensions):
  """MLIR lowering rule for rev: emits `hlo.reverse` over `dimensions`."""
  [result_aval] = ctx.avals_out
  reversed_x = hlo.reverse(x, mlir.dense_int_array(dimensions))
  return [mlir.lower_with_sharding_in_types(ctx, reversed_x, result_aval)]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
# Use `_rev_lower` as the MLIR lowering for `rev_p`.
mlir.register_lowering(rev_p, _rev_lower)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _transpose_shape_rule(operand, *, permutation):
  """Shape rule for transpose: the operand shape reordered by `permutation`.

  Raises:
    TypeError: if `permutation` is not a tuple/list/ndarray, or is not a
      permutation of `range(operand.ndim)`.
  """
  accepted_types = (tuple, list, np.ndarray)
  if not isinstance(permutation, accepted_types):
    raise TypeError(
        "transpose permutation must be a tuple/list/ndarray, got {}.".format(
            type(permutation)))

  expected_axes = tuple(range(operand.ndim))
  if tuple(sorted(permutation)) != expected_axes:
    template = ("transpose permutation isn't a permutation of operand dimensions, "
                "got permutation {} for operand shape {}.")
    raise TypeError(template.format(permutation, operand.shape))

  in_shape = operand.shape
  return tuple(in_shape[axis] for axis in permutation)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-10-17 17:07:19 -07:00
|
|
|
|
def _transpose_sharding_rule(operand, *, permutation):
  """Sharding rule for transpose: permute the spec like the shape."""
  in_sharding = operand.sharding
  in_spec = in_sharding.spec
  permuted_spec = []
  for axis in permutation:
    permuted_spec.append(in_spec[axis])
  return in_sharding.with_spec(permuted_spec)
|
2024-10-17 17:07:19 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
  """Batching rule for transpose.

  The batched permutation places the batch (stacked) axis first and shifts
  every original axis index at or beyond it by one. For a plain batch axis the
  result is batched along axis 0; for a `RaggedAxis` the output batch
  description is computed by `batching.transpose_ragged_axes`.
  """
  [batched_operand] = batched_args
  [bdim] = batch_dims
  ragged = isinstance(bdim, RaggedAxis)
  stack_dim = bdim.stacked_axis if ragged else bdim
  shifted = tuple(axis + 1 if axis >= stack_dim else axis
                  for axis in permutation)
  perm = (stack_dim,) + shifted
  if ragged:
    out_bdim = batching.transpose_ragged_axes(bdim.move_stacked_axis(0), perm)
  else:
    out_bdim = 0
  return transpose(batched_operand, perm), out_bdim
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-08-05 07:37:55 -07:00
|
|
|
|
def _transpose_lower(ctx, x, *, permutation):
  """MLIR lowering rule for transpose.

  For extended dtypes the permutation is extended with identity entries for
  the trailing dimensions of the physical element shape.
  """
  [result_aval] = ctx.avals_out
  if dtypes.issubdtype(result_aval.dtype, dtypes.extended):
    n_elt_dims = len(core.physical_element_aval(result_aval.dtype).shape)
    # Trailing physical dims stay in place, after the logical dims.
    permutation = [*permutation,
                   *range(result_aval.ndim, result_aval.ndim + n_elt_dims)]
  transposed = hlo.transpose(x, mlir.dense_int_array(permutation))
  return [mlir.lower_with_sharding_in_types(ctx, transposed, result_aval)]
|
2022-08-05 07:37:55 -07:00
|
|
|
|
|
2024-10-17 17:07:19 -07:00
|
|
|
|
# The `transpose` primitive; its dtype rule is the shared `_input_dtype`.
transpose_p = standard_primitive(
    _transpose_shape_rule, _input_dtype, 'transpose',
    sharding_rule=_transpose_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'transpose'))
# The transpose of a transpose applies the inverse permutation
# (`np.argsort(permutation)`) to the cotangent.
ad.deflinear2(transpose_p,
              lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
mlir.register_lowering(transpose_p, _transpose_lower)
pe.def_trivial_padding(transpose_p)
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-02-09 11:02:31 -08:00
|
|
|
|
def _select_shape_rule(which, *cases):
  """Shape rule for select_n: returns the common shape of `cases`.

  Raises:
    TypeError: if there are no cases, the cases' shapes differ, or `which` is
      neither scalar nor shaped like the cases.
  """
  if not cases:
    raise TypeError("select must have at least one case")

  first, *rest = cases
  case_shape = first.shape
  for other in rest:
    if other.shape != case_shape:
      shapes = ", ".join([str(c.shape) for c in cases])
      raise TypeError(
          "select cases must have the same shapes, got [{}].".format(shapes))

  # A scalar `which` (empty shape) is always allowed.
  if which.shape and which.shape != case_shape:
    template = ("select `which` must be scalar or have the same shape as cases, "
                "got `which` shape {} but case shape {}.")
    raise TypeError(template.format(which.shape, case_shape))
  return case_shape
|
|
|
|
|
|
2024-10-25 10:34:33 -07:00
|
|
|
|
def _select_sharding_rule(which, *cases):
  """Sharding rule for select_n.

  Cases whose sharding has an empty mesh are ignored. If every case has an
  empty mesh the result is `core.get_cur_mesh_sharding()`; otherwise all
  remaining case shardings must be equal and that common sharding is returned.
  A non-scalar `which` with a non-empty mesh must match it as well.

  Raises:
    core.ShardingTypeError: if the non-empty-mesh case shardings differ, or
      `which` is non-scalar, has a non-empty mesh, and a different sharding.
  """
  non_empty_s = [c.sharding for c in cases if not c.sharding.mesh.empty]
  if not non_empty_s:
    return core.get_cur_mesh_sharding()

  expected = non_empty_s[0]
  if any(s != expected for s in non_empty_s[1:]):
    msg = "select cases must have the same shardings, got [{}]."
    raise core.ShardingTypeError(
        msg.format(", ".join([str(c.sharding) for c in cases])))

  if (which.shape and not which.sharding.mesh.empty and
      which.sharding != expected):
    # Report the sharding that `which` was actually compared against;
    # `cases[0]` may be an empty-mesh case that was skipped above.
    raise core.ShardingTypeError(
        'select `which` must be scalar or have the same sharding as cases, got'
        f' `which` sharding {which.sharding} but case sharding'
        f' {expected}.')
  return expected
|
2024-10-25 10:34:33 -07:00
|
|
|
|
|
|
|
|
|
|
2022-02-09 11:02:31 -08:00
|
|
|
|
def _select_dtype_rule(which, *cases):
  """Dtype rule for select_n: returns the dtype of the first case.

  Raises:
    TypeError: if `which` is neither boolean nor integer, or is boolean with
      more than two cases. `check_same_dtypes` validates the cases first.
  """
  check_same_dtypes("select", *cases)
  which_is_bool = dtypes.issubdtype(which.dtype, np.bool_)
  if not (which_is_bool or dtypes.issubdtype(which.dtype, np.integer)):
    raise TypeError("select `which` must be boolean or integer type, got "
                    f"{which.dtype}.")
  if which_is_bool and len(cases) > 2:
    raise TypeError("select with boolean `which` cannot have > 2 cases.")
  first_case = cases[0]
  return first_case.dtype
|
|
|
|
|
|
|
|
|
|
def _select_weak_type_rule(which, *cases):
  """Weak-type rule for select_n: weak only if every case is weakly typed."""
  del which  # The selector does not influence the result's weak type.
  for case in cases:
    if not case.weak_type:
      return False
  return True
|
|
|
|
|
|
|
|
|
|
def _select_transpose_rule(t, which, *cases):
  """Transpose rule for select_n.

  Returns one cotangent slot per input: `None` for `which` and for every case
  that is not an undefined primal. Each remaining case receives `t` where
  `which` selects that case and zeros elsewhere (or a symbolic zero when `t`
  itself is a symbolic zero).
  """
  assert not ad.is_undefined_primal(which)

  if type(t) is ad_util.Zero:
    zero_cts = [ad_util.Zero(c.aval) if ad.is_undefined_primal(c) else None
                for c in cases]
    return [None] + zero_cts

  zeros = full_like(t, 0)
  if dtypes.dtype(which) == np.dtype(np.bool_):
    # Boolean selector: case 0 is picked where `which` is False, case 1 where
    # it is True.
    ct_false = (select(which, zeros, t)
                if ad.is_undefined_primal(cases[0]) else None)
    ct_true = (select(which, t, zeros)
               if ad.is_undefined_primal(cases[1]) else None)
    return (None, ct_false, ct_true)

  # Integer selector: case i is picked where `which == i`.
  case_cts = [
      select(eq(which, _const(which, index)), t, zeros)
      if ad.is_undefined_primal(case) else None
      for index, case in enumerate(cases)
  ]
  return [None] + case_cts
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
  """Batching rule for select_n.

  Args:
    batched_args: `(which, *cases)`.
    batch_dims: the batch axis of each argument, or None if unbatched.

  Returns:
    A `(result, batch_dim)` pair. Two special cases avoid moving axes; the
    general path puts every batch axis at the front and returns batch dim 0.
  """
  which, *cases = batched_args
  which_bdim, *case_bdims = batch_dims
  # Batch size, read from the first argument that is actually batched.
  size = next(x.shape[i] for x, i in zip(batched_args, batch_dims)
              if i is not None)

  # avoid transposes and some broadcasts in special cases
  if all(which_bdim == bdim for bdim in case_bdims):
    # Every argument shares the same batch axis.
    if np.shape(which) == np.shape(cases[0]):
      return select_n(which, *cases), which_bdim
    else:
      # vmapped function had a scalar which with nonscalar args
      assert np.ndim(which) == 1
      which = broadcast_in_dim(which, cases[0].shape, [which_bdim])
      return select_n(which, *cases), which_bdim
  elif np.ndim(which) == 0 and all(bdim is not None for bdim in case_bdims):
    # Scalar `which`, all cases batched.
    if all(case_bdims[0] == bdim for bdim in case_bdims[1:]):
      return select_n(which, *cases), case_bdims[0]
    elif all(np.shape(cases[0]) == np.shape(c) for c in cases):
      # Same shapes but different batch axes: align to the first case's axis.
      bdim = case_bdims[0]
      other_cases = [batching.moveaxis(c, c_bdim, bdim)
                     for c, c_bdim in zip(cases[1:], case_bdims[1:])]
      return select_n(which, cases[0], *other_cases), bdim

  # General path: batch axis at the front, for `which` only when non-scalar.
  which = (batching.bdim_at_front(which, which_bdim, size) if np.shape(which)
           else which)
  # Cases are moved unless every case is a scalar.
  if not all(() == np.shape(c) for c in cases):
    cases = [batching.bdim_at_front(c, bdim, size)
             for c, bdim in zip(cases, case_bdims)]
  assert all(np.shape(cases[0]) == np.shape(c) for c in cases[1:])
  if 0 < np.ndim(which) < np.ndim(cases[0]):
    # vmapped function had a scalar which with nonscalar args
    assert np.ndim(which) == 1
    which = broadcast_in_dim(which, cases[0].shape, [0])
  if np.ndim(which) > np.ndim(cases[0]):
    # Scalar cases with a batched `which`: broadcast the cases up to it.
    assert np.ndim(cases[0]) == 0
    cases = [broadcast(c, which.shape) for c in cases]
  return select_n(which, *cases), 0
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2020-10-26 15:32:31 -07:00
|
|
|
|
def _select_jvp(primals, tangents):
  """JVP rule for select_n: select among the case tangents with `which`.

  The tangent of `which` is ignored. If every case tangent is a symbolic zero
  the output tangent is a symbolic zero; otherwise symbolic zeros are
  materialized (via `_zeros` of the first non-zero tangent) before selecting.
  """
  which, *case_primals = primals
  case_tangents = tangents[1:]
  primal_out = select_n(which, *case_primals)

  nonzero_tangents = [t for t in case_tangents if type(t) is not ad_util.Zero]
  if not nonzero_tangents:
    tangent_out = ad_util.Zero(case_tangents[0].aval)
    return primal_out, tangent_out

  zero_fill = _zeros(nonzero_tangents[0])
  dense_tangents = [zero_fill if type(t) is ad_util.Zero else t
                    for t in case_tangents]
  tangent_out = select_n(which, *dense_tangents)
  return primal_out, tangent_out
|
|
|
|
|
|
2023-05-12 14:16:54 -07:00
|
|
|
|
def _select_hlo_lowering_opaque(ctx, which, *cases):
  """Lower select_n for extended dtypes by selecting on the physical aval.

  `which` is broadcast along its leading dimensions to the shape of the
  physical output aval, then lowering is delegated to `_select_hlo_lowering`
  with physical avals. Returns the first result of that delegated lowering.
  """
  aval_which, *case_avals = ctx.avals_in
  [aval_out] = ctx.avals_out
  assert all(aval_case == aval_out for aval_case in case_avals)

  physical_out = core.physical_aval(aval_out)
  # Same physical shape as the output, but with the selector's dtype.
  which_bcast_aval = physical_out.update(dtype=aval_which.dtype)
  assert which_bcast_aval.shape[:aval_which.ndim] == aval_which.shape

  which_bcast = mlir.broadcast_in_dim(
      ctx, which, which_bcast_aval,
      broadcast_dimensions=list(range(aval_which.ndim)))

  physical_case_avals = [physical_out] * len(case_avals)
  lowered = mlir.delegate_lowering(
      ctx, _select_hlo_lowering, which_bcast, *cases,
      avals_in=[which_bcast_aval, *physical_case_avals],
      avals_out=[physical_out])
  return lowered[0]
|
|
|
|
|
|
|
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
|
def _select_hlo_lowering(ctx, which, *cases):
  """Lower ``select_n`` to HLO ``select`` operations.

  Three paths are handled:
    * extended output dtypes are forwarded to ``_select_hlo_lowering_opaque``;
    * a boolean predicate with two cases becomes a single ``hlo.select``
      (a single case is returned as-is);
    * any other predicate dtype is lowered to a balanced tree of
      comparisons against the case index, one ``hlo.select`` per split.
  """
  which_aval = ctx.avals_in[0]
  aval_out, = ctx.avals_out

  def _with_sharding(op):
    return [mlir.lower_with_sharding_in_types(ctx, op, aval_out)]

  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    return _with_sharding(_select_hlo_lowering_opaque(ctx, which, *cases))

  if which_aval.dtype == np.dtype(np.bool_):
    assert len(cases) <= 2
    if len(cases) == 1:
      return cases
    # hlo.select takes (pred, on_true, on_false): True picks cases[1].
    return _with_sharding(hlo.select(which, cases[1], cases[0]))

  is_signed = dtypes.issubdtype(which_aval.dtype, np.signedinteger)
  compare_type = 'SIGNED' if is_signed else 'UNSIGNED'

  def _select_range(offset, candidates):
    # `candidates[i]` corresponds to predicate value `offset + i`.
    count = len(candidates)
    assert count > 0
    if count == 1:
      return candidates[0]
    mid = count // 2
    threshold = mlir.full_like_aval(ctx, offset + mid, which_aval)
    pred = mlir.compare_hlo(which, threshold, 'LT', compare_type)
    lower = _select_range(offset, candidates[:mid])
    upper = _select_range(offset + mid, candidates[mid:])
    return hlo.select(pred, lower, upper)

  return _with_sharding(_select_range(0, cases))
|
2022-02-09 11:02:31 -08:00
|
|
|
|
|
|
|
|
|
# The `select_n` primitive, followed by the registration of its JVP,
# transpose, batching, lowering and padding rules.
select_n_p = standard_primitive(
    _select_shape_rule,
    _select_dtype_rule,
    'select_n',
    weak_type_rule=_select_weak_type_rule,
    sharding_rule=_select_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'select_n'),
)
ad.primitive_jvps[select_n_p] = _select_jvp
ad.primitive_transposes[select_n_p] = _select_transpose_rule
batching.primitive_batchers[select_n_p] = _select_batch_rule
mlir.register_lowering(select_n_p, _select_hlo_lowering)
pe.def_trivial_padding(select_n_p)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
2023-12-21 22:15:12 -08:00
|
|
|
|
def _reduce_shape_rule(*avals, computation, jaxpr, dimensions):
  """Shape rule for ``reduce``: operand shapes with ``dimensions`` removed.

  ``avals`` is split in the middle: the first half are the operand avals, the
  rest are the initial value avals.

  Raises:
    ValueError: if any initial value has a non-scalar shape.
  """
  num_operands = len(avals) // 2
  operand_avals, init_val_avals = split_list(avals, [num_operands])
  init_val_shapes = [a.shape for a in init_val_avals]
  if not all(shape == () for shape in init_val_shapes):
    raise ValueError(f'reduce found non-scalar initial value: {init_val_shapes}')
  return [tuple(np.delete(op.shape, dimensions)) for op in operand_avals]
|
|
|
|
|
|
2024-11-22 11:59:39 -08:00
|
|
|
|
def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions):
  """Sharding rule for ``reduce``: drop ``dimensions`` from each operand spec.

  Only the operand avals (the first half of ``avals``) are consulted.
  """
  operand_avals = split_list(avals, [len(avals) // 2])[0]
  out_shardings = []
  for op in operand_avals:
    reduced_spec = tuple_delete(op.sharding.spec, dimensions)
    out_shardings.append(op.sharding.with_spec(reduced_spec))
  return out_shardings
|
|
|
|
|
|
2023-12-21 22:15:12 -08:00
|
|
|
|
def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions):
  """Dtype rule for ``reduce``: the canonicalized operand dtypes.

  Raises:
    TypeError: if the canonicalized operand dtypes differ from the
      canonicalized dtypes of the corresponding initial values.
  """
  operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
  canonicalize = dtypes.canonicalize_dtype
  operand_dtypes = [canonicalize(aval.dtype) for aval in operand_avals]
  init_val_dtypes = [canonicalize(aval.dtype) for aval in init_val_avals]
  if operand_dtypes == init_val_dtypes:
    return operand_dtypes
  raise TypeError(
      "reduce operand dtypes should match corresponding initial value dtypes, "
      f"got operands={operand_avals} and initial_values={init_val_avals}")
|
2020-11-10 15:57:19 -08:00
|
|
|
|
|
2023-12-21 22:15:12 -08:00
|
|
|
|
def _reduce_weak_type_rule(*avals, computation, jaxpr, dimensions):
  """Weak-type rule for ``reduce``.

  An output is weakly typed only when both its operand and the matching
  initial value are weakly typed.
  """
  operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
  weak_types = []
  for op, init_val in safe_zip(operand_avals, init_val_avals):
    weak_types.append(op.weak_type and init_val.weak_type)
  return weak_types
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2020-11-10 15:57:19 -08:00
|
|
|
|
def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
                       dimensions):
  """Batching rule for ``reduce``.

  Only the case where no initial value is batched is supported: the batch
  dimension of every operand is moved to the front and the reduction
  dimensions are shifted by one to make room for it.

  Raises:
    NotImplementedError: if any initial value is batched.
  """
  # TODO(mattjj,frostig): use batch_jaxpr, delete computation (assumes poly??)
  num_operands = len(batched_args) // 2
  operands, init_values = split_list(batched_args, [num_operands])
  operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands])

  if any(bdim is not batching.not_mapped for bdim in init_value_bdims):
    raise NotImplementedError  # loop and stack

  # Batch size, taken from the first argument that is actually batched.
  size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims)
              if ax is not None)
  operands = [batching.bdim_at_front(arg, bdim, size)
              for arg, bdim in zip(operands, operand_bdims)]
  shifted_dimensions = tuple(d + 1 for d in dimensions)
  outs = reduce_p.bind(*(operands + init_values),
                       computation=computation,
                       dimensions=shifted_dimensions,
                       jaxpr=jaxpr)
  return outs, [0] * num_operands
|
|
|
|
|
|
2021-03-23 10:31:02 -04:00
|
|
|
|
def _reduce_jvp(reducer, init_values, primals, tangents, axes):
  """Differentiate a general reduction by evaluating it as a pairwise tree.

  The reduced ``axes`` of every operand are moved to the front and flattened
  into a single leading axis. That axis is then halved repeatedly, combining
  the two halves with ``reducer`` until one element remains, and the JVP is
  obtained by applying ``api.jvp`` to this tree reduction.

  Args:
    reducer: the binary reduction; called as ``reducer(*lhs, *rhs)`` on the
      two halves, returning one output per operand.
    init_values: one initial value per operand. Used to pad the shorter half
      when the size is odd and to fill the result of an empty reduction.
    primals: the operands being reduced.
    tangents: the tangents of the operands.
    axes: tuple of axes being reduced.

  Returns:
    The ``(primals_out, tangents_out)`` pair produced by ``api.jvp``.
  """
  input_shape = np.array(primals[0].shape, dtype=int)

  n = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  permutation = axes + tuple(non_axes)
  new_shape = (n,) + tuple(input_shape[non_axes])
  primals = tuple(reshape(x, new_shape, permutation) for x in primals)
  tangents = tuple(reshape(t, new_shape, permutation) for t in tangents)

  # Vectorize the reducer once per dimension of the flattened operands.
  for _ in range(len(non_axes) + 1):
    reducer = api.vmap(reducer)

  def _reduce_tree(*xs, axis=0):
    """Reduce along `axis` by repeatedly halving and applying `reducer`."""
    while xs[0].shape[axis] > 1:
      size = xs[0].shape[axis]
      n1 = (size + 1) // 2
      n2 = size - n1
      # Slice along `axis` so that the split agrees with the padding and the
      # squeeze below (previously the slices always used the default axis).
      xs1 = [slicing.slice_in_dim(x, 0, n1, axis=axis) for x in xs]
      xs2 = [slicing.slice_in_dim(x, n1, None, axis=axis) for x in xs]
      if n2 != n1:
        # Odd size: pad the shorter half with the initial value so both
        # halves have the same extent along `axis`.
        paddings = [(0, 0, 0)] * len(xs[0].shape)
        paddings[axis] = (0, 1, 0)
        xs2 = [pad(x2, i, paddings) for x2, i in zip(xs2, init_values)]
      xs = reducer(*(xs1 + xs2))
    if xs[0].shape[axis] == 0:
      # Empty reduction: the result is the initial value everywhere.
      return [full(input_shape[non_axes], i) for i in init_values]
    return tuple(squeeze(x, (axis,)) for x in xs)

  return api.jvp(_reduce_tree, primals, tangents)
|
|
|
|
|
|
2023-12-21 22:15:12 -08:00
|
|
|
|
def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions):
  """JVP rule for the general ``reduce`` primitive.

  Raises:
    NotImplementedError: if any initial value has a tangent that is not a
      symbolic zero.
  """
  num_operands = len(primals) // 2
  primal_xs, init_values = split_list(primals, [num_operands])
  tangent_xs, tangent_init = split_list(tangents, [len(tangents) // 2])
  # This test may be too strict, if a value is actually zero but we cannot prove
  # it is symbolically zero.
  if not all(type(t) is ad_util.Zero for t in tangent_init):
    raise NotImplementedError(
        "Gradient of general lax.reduce with non-zero tangents for "
        "initial values to reduction not implemented")
  return _reduce_jvp(core.jaxpr_as_fun(jaxpr), init_values, primal_xs,
                     tangent_xs, dimensions)
|
|
|
|
|
|
2021-02-12 10:30:46 -08:00
|
|
|
|
# The general (variadic) `reduce` primitive and its impl, abstract-eval,
# batching and JVP rules.
reduce_p = core.Primitive('reduce')
reduce_p.multiple_results = True
reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
reduce_p.def_abstract_eval(
    partial(
        standard_multi_result_abstract_eval,
        reduce_p,
        _reduce_shape_rule,
        _reduce_dtype_rule,
        _reduce_weak_type_rule,
        _reduce_sharding_rule,
        None,
    )
)
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2023-12-21 22:15:12 -08:00
|
|
|
|
def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
  """Lower the general ``reduce`` primitive to an ``hlo.ReduceOp``.

  ``values`` holds the operands followed by the initial values. The reduction
  body is produced by lowering ``jaxpr`` into the op's region, whose block
  takes two arguments per initial value.

  Raises:
    NotImplementedError: if ``jaxpr`` has effects.
  """
  assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
  # Reject the unsupported case up front, before any IR has been emitted;
  # previously the ReduceOp and its region block were created first.
  if jaxpr.effects:
    raise NotImplementedError('Cannot lower effectful `reduce`.')

  num_operands = len(values) // 2
  operands, init_values = util.split_list(values, [num_operands])
  init_value_avals = ctx.avals_in[num_operands:]
  op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
                    operands, init_values, mlir.dense_int_array(dimensions))
  # The reducer block receives the scalar types twice, once per side of the
  # binary reduction.
  ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
  with ir.InsertionPoint(reducer):
    name_stack = source_info_util.new_name_stack()
    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
                                      name_stack, mlir.TokenSet(),
                                      jaxpr.consts,
                                      *reducer.arguments,
                                      dim_var_values=ctx.dim_var_values)
    hlo.return_(mlir.flatten_ir_values(out_nodes))
  return [mlir.lower_with_sharding_in_types(ctx, r, aval)
          for r, aval in safe_zip(op.results, ctx.avals_out)]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
|
|
|
|
# Register `_reduce_lower` as the MLIR lowering rule for `reduce_p`.
mlir.register_lowering(reduce_p, _reduce_lower)
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _reduce_number_dtype_rule(name, operand, *args, **kw):
  """Dtype rule for numeric reductions: the canonicalized operand dtype.

  Raises:
    TypeError: if the operand dtype is not a subtype of ``np.number``.
  """
  if dtypes.issubdtype(operand.dtype, np.number):
    return dtypes.canonicalize_dtype(operand.dtype)
  dtype_name = dtype_to_string(operand.dtype)
  raise TypeError(f"{name} does not accept dtype {dtype_name}. Accepted dtypes are subtypes "
                  "of number.")
|
|
|
|
|
|
|
|
|
|
def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
  """Transpose of ``reduce_sum``: broadcast the cotangent to the input shape.

  The cotangent's dimensions are mapped onto the operand dimensions that were
  not reduced.
  """
  assert ad.is_undefined_primal(operand)
  operand_aval = operand.aval
  input_shape = operand_aval.shape
  kept_dims = tuple(np.delete(np.arange(len(input_shape)), axes))
  result = broadcast_in_dim(cotangent, input_shape, kept_dims,
                            out_sharding=operand_aval.sharding)
  assert result.shape == input_shape
  return [result]
|
|
|
|
|
|
2022-10-10 18:51:04 -07:00
|
|
|
|
def _reducer_padding(traceable, ident, in_avals, out_avals, operand, *, axes):
  """Padding rule shared by the monoid reductions.

  Values lying beyond the bound of any ``pe.BoundedAxisSize`` axis are
  replaced by ``ident(aval.dtype)`` before ``traceable`` is applied.
  """
  del out_avals
  aval, = in_avals
  padded_axes = []
  for i, d in enumerate(aval.shape):
    if isinstance(d, pe.BoundedAxisSize):
      padded_axes.append((i, d.val))
  masked_operand = _replace_masked_values(operand, ident(aval.dtype),
                                          padded_axes)
  return [traceable(masked_operand, axes)]
|
2022-03-30 17:52:55 -07:00
|
|
|
|
|
|
|
|
|
def _replace_masked_values(x, val, padded_axes):
|
|
|
|
|
if not padded_axes: return x
|
2022-06-29 13:55:30 -07:00
|
|
|
|
dtype = dtypes._scalar_type_to_dtype(int)
|
|
|
|
|
masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
|
2022-03-30 17:52:55 -07:00
|
|
|
|
return select(_reduce(operator.and_, masks), x, full_like(x, val))
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
|
|
|
|
|
del input_shape # Unused.
|
|
|
|
|
if len(axes) != len(set(axes)):
|
|
|
|
|
raise ValueError(f"duplicate value in 'axes' of reduction: {axes}")
|
2021-01-13 14:16:54 -08:00
|
|
|
|
if not all(0 <= a < operand.ndim for a in axes):
|
|
|
|
|
raise ValueError(f"reduction axes {axes} contains out-of-bounds indices for {operand}.")
|
2021-12-10 23:22:11 -08:00
|
|
|
|
axes = frozenset(axes)
|
|
|
|
|
return tuple(d for i, d in enumerate(operand.shape) if i not in axes)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2024-10-17 15:54:42 -07:00
|
|
|
|
def _reduce_op_sharding_rule(operand, *, axes):
  """Sharding rule for single-operand reductions.

  Returns the operand sharding with the spec entries of the reduced axes
  removed.
  """
  reduced = frozenset(axes)
  kept_entries = [s for i, s in enumerate(operand.sharding.spec)
                  if i not in reduced]
  return operand.sharding.with_spec(P(*kept_entries))
|
2024-10-17 15:54:42 -07:00
|
|
|
|
|
|
|
|
|
# The `reduce_sum` primitive and its transpose, batching, padding and
# ragged-propagation rules.
reduce_sum_p = standard_primitive(
    _reduce_op_shape_rule,
    partial(_reduce_number_dtype_rule, 'reduce_sum'),
    'reduce_sum',
    sharding_rule=_reduce_op_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'reduce_sum'),
)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(
    _reducer_padding, reduce_sum, _get_sum_identity)
batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule
|
2024-10-17 15:54:42 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
  """JVP rule for ``reduce_prod``, via the generic tree-reduction JVP.

  The reducer is elementwise multiplication and the initial value is a
  constant 1 built from the first primal.
  """
  reducer = lambda x, y: [mul(x, y)]
  init_values = [_const(primals[0], 1)]
  outs = _reduce_jvp(reducer, init_values, primals, tangents, axes)
  primals_out, tangents_out = outs
  return primals_out[0], tangents_out[0]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# The `reduce_prod` primitive and its JVP, batching and padding rules.
reduce_prod_p = standard_primitive(
    _reduce_op_shape_rule,
    partial(_reduce_number_dtype_rule, 'reduce_prod'),
    'reduce_prod',
    sharding_rule=_reduce_op_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'reduce_prod'),
)
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
batching.defreducer(reduce_prod_p, _get_prod_identity)
pe.padding_rules[reduce_prod_p] = partial(
    _reducer_padding, reduce_prod, _get_prod_identity)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
  """JVP rule shared by ``reduce_max`` and ``reduce_min``.

  The tangent ``g`` is averaged over the positions where ``operand`` matches
  the reduction result ``ans`` (per ``_eq_meet``).
  """
  # TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
  # locations in a single pass (rather than comparing equality) and use a
  # gather, and/or even push along the chosen elements of g (b/112040122)
  keepdims_shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
  ans_keepdims = reshape(ans, keepdims_shape)
  is_chosen = _eq_meet(operand, ans_keepdims)
  location_indicators = convert_element_type(is_chosen, g.dtype)
  counts = reduce_sum(location_indicators, axes)
  chosen_tangent_sum = reduce_sum(mul(g, location_indicators), axes)
  return div(chosen_tangent_sum, counts)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-10-10 18:51:04 -07:00
|
|
|
|
|
2024-10-17 15:54:42 -07:00
|
|
|
|
# The `reduce_max` primitive and its JVP, batching, padding and
# ragged-propagation rules.
reduce_max_p = standard_primitive(
    _reduce_op_shape_rule,
    _input_dtype,
    'reduce_max',
    sharding_rule=_reduce_op_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'reduce_max'),
)
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p, _get_max_identity)
pe.padding_rules[reduce_max_p] = partial(
    _reducer_padding, reduce_max, _get_max_identity)
batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule
|
2022-10-10 18:51:04 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2025-01-22 16:47:58 -08:00
|
|
|
|
# The `reduce_min` primitive and its JVP, batching and padding rules.
reduce_min_p = standard_primitive(
    _reduce_op_shape_rule,
    _input_dtype,
    'reduce_min',
    sharding_rule=_reduce_op_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'reduce_min'),
)
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p, _get_min_identity)
pe.padding_rules[reduce_min_p] = partial(
    _reducer_padding, reduce_min, _get_min_identity)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _argminmax_shape_rule(operand, *, axes, index_dtype):
  """Shape rule for ``argmin``/``argmax``: the operand shape minus one axis.

  ``axes`` must contain exactly one axis.

  Raises:
    ValueError: if the axis is out of range for the operand, or the reduced
      dimension has size less than 1.
  """
  axis, = axes
  rank = len(operand.shape)
  if not (0 <= axis < rank):
    raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}")
  reduced_size = operand.shape[axis]
  if reduced_size < 1:
    raise ValueError("argmin and argmax require non-empty reduced dimension. "
                     f"operand.shape={operand.shape} {axis=}")
  return util.tuple_delete(operand.shape, axis)
|
|
|
|
|
|
|
|
|
|
def _argminmax_sharding_rule(operand, *, axes, index_dtype):
  """Sharding rule for ``argmin``/``argmax``: drop the reduced axis' spec entry."""
  axis, = axes
  sharding = operand.sharding
  reduced_spec = util.tuple_delete(sharding.spec, axis)
  return sharding.with_spec(reduced_spec)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _argminmax_dtype_rule(operand, *, axes, index_dtype):
  """Dtype rule for ``argmin``/``argmax``: the requested ``index_dtype``.

  Raises:
    TypeError: if ``index_dtype`` is not an integer type.
  """
  if dtypes.issubdtype(index_dtype, np.integer):
    return index_dtype
  raise TypeError(
      f"index_dtype must be an integer type, but got {dtype_to_string(index_dtype)}")
|
|
|
|
|
|
2024-01-26 17:43:27 -08:00
|
|
|
|
class _ArgMinMaxReducer:
|
|
|
|
|
|
2025-02-05 19:17:47 +02:00
|
|
|
|
def __init__(self, value_comparator: Callable[[Any, Any], Any]):
|
2024-01-26 17:43:27 -08:00
|
|
|
|
self._value_comparator = value_comparator
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
# Override the repr so that the metadata attached to the lowered op does not
|
|
|
|
|
# contain unstable function ids. This plays more nicely with computation
|
|
|
|
|
# fingerprint calculation in the compilation cache.
|
|
|
|
|
return f'_ArgMinMaxReducer({self._value_comparator.__name__})'
|
|
|
|
|
|
|
|
|
|
def __call__(self, op_val_index, acc_val_index):
|
2021-07-12 01:11:17 -07:00
|
|
|
|
op_val, op_index = op_val_index
|
|
|
|
|
acc_val, acc_index = acc_val_index
|
|
|
|
|
# Pick op_val if Lt (for argmin) or if NaN
|
2024-01-26 17:43:27 -08:00
|
|
|
|
pick_op_val = bitwise_or(self._value_comparator(op_val, acc_val),
|
2021-07-12 01:11:17 -07:00
|
|
|
|
ne(op_val, op_val))
|
|
|
|
|
# If x and y are not NaN and x = y, then pick the first
|
|
|
|
|
pick_op_index = bitwise_or(pick_op_val,
|
|
|
|
|
bitwise_and(eq(op_val, acc_val),
|
|
|
|
|
lt(op_index, acc_index)))
|
|
|
|
|
return (select(pick_op_val, op_val, acc_val),
|
|
|
|
|
select(pick_op_index, op_index, acc_index))
|
2024-01-26 17:43:27 -08:00
|
|
|
|
|
|
|
|
|
def _compute_argminmax(value_comparator, get_identity,
                       operand, *, index_dtype, axes):
  """Compute argmin/argmax as a variadic reduce over (value, index) pairs.

  Returns the index component of the reduction result.
  """
  # value_comparator is either lax.lt (for argmin) or lax.gt
  # get_identity(operand.dtype) is inf for argmin or -inf for argmax
  axis, = axes
  # Index of every element along the reduced axis.
  indices = broadcasted_iota(
      index_dtype, np.shape(operand), axis,
      out_sharding=operand.aval.sharding)
  init_values = [get_identity(operand.dtype), np.array(0, index_dtype)]
  reducer = _ArgMinMaxReducer(value_comparator)
  res = reduce([operand, indices], init_values, reducer, axes)
  return res[1]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# The `argmin` primitive and its batching and (zero) JVP rules.
argmin_p = standard_primitive(
    _argminmax_shape_rule,
    _argminmax_dtype_rule,
    'argmin',
    weak_type_rule=_strip_weak_type,
    sharding_rule=_argminmax_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'argmin'),
)
batching.defreducer(argmin_p, _get_min_identity)
ad.defjvp_zero(argmin_p)

# The `argmax` primitive and its batching and (zero) JVP rules.
argmax_p = standard_primitive(
    _argminmax_shape_rule,
    _argminmax_dtype_rule,
    'argmax',
    weak_type_rule=_strip_weak_type,
    sharding_rule=_argminmax_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'argmax'),
)
batching.defreducer(argmax_p, _get_max_identity)
ad.defjvp_zero(argmax_p)

# Both are lowered through `_compute_argminmax`, differing only in the
# comparator and the identity.
mlir.register_lowering(
    argmin_p,
    mlir.cache_lowering(
        mlir.lower_fun(partial(_compute_argminmax, lt, _get_min_identity),
                       multiple_results=False)))

mlir.register_lowering(
    argmax_p,
    mlir.cache_lowering(
        mlir.lower_fun(partial(_compute_argminmax, gt, _get_max_identity),
                       multiple_results=False)))
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _reduce_logical_shape_rule(operand, *, axes):
|
2022-06-15 21:27:42 +02:00
|
|
|
|
if operand.dtype != np.bool_ and not np.issubdtype(operand.dtype, np.integer):
|
|
|
|
|
raise TypeError(f"logical reduction requires operand dtype bool or int, got {operand.dtype}.")
|
2020-10-17 14:33:26 -04:00
|
|
|
|
return tuple(np.delete(operand.shape, axes))
|
|
|
|
|
|
2025-01-22 16:47:58 -08:00
|
|
|
|
def _reduce_logical_sharding_rule(operand, *, axes):
  """Sharding rule for logical reductions: drop ``axes`` from the operand spec."""
  sharding = operand.sharding
  reduced_spec = tuple_delete(sharding.spec, axes)
  return sharding.with_spec(reduced_spec)
|
|
|
|
|
|
2022-04-18 08:28:08 -07:00
|
|
|
|
# The `reduce_or` primitive and its batching rule.
reduce_or_p = standard_primitive(
    _reduce_logical_shape_rule,
    _input_dtype,
    'reduce_or',
    weak_type_rule=_strip_weak_type,
    sharding_rule=_reduce_logical_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'reduce_or'),
)
batching.defreducer(reduce_or_p, _get_bitwise_or_identity)


# The `reduce_and` primitive and its batching and ragged-propagation rules.
reduce_and_p = standard_primitive(
    _reduce_logical_shape_rule,
    _input_dtype,
    'reduce_and',
    weak_type_rule=_strip_weak_type,
    sharding_rule=_reduce_logical_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'reduce_and'),
)
batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule


# The `reduce_xor` primitive and its batching rule; it is registered with the
# same identity function as `reduce_or`.
reduce_xor_p = standard_primitive(
    _reduce_logical_shape_rule,
    _input_dtype,
    'reduce_xor',
    weak_type_rule=_strip_weak_type,
    sharding_rule=_reduce_logical_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'reduce_xor'),
)
batching.defreducer(reduce_xor_p, _get_bitwise_or_identity)
|
2022-06-09 20:38:53 +02:00
|
|
|
|
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
  """Lower a single-operand reduction to an ``hlo.ReduceOp``.

  Args:
    reducer: callable applied to the two arguments of the reduction body's
      block; its result is returned from the body.
    unit_factory: maps the output dtype to the initial value of the reduction.
    ctx: the lowering rule context.
    x: the operand value.
    axes: the dimensions to reduce over.
  """
  aval_out, = ctx.avals_out
  dtype = aval_out.dtype
  out_type = mlir.aval_to_ir_type(aval_out)
  init_value = mlir.ir_constant(unit_factory(dtype))
  op = hlo.ReduceOp([out_type], [x], [init_value], mlir.dense_int_array(axes))
  # The reduction body operates on a pair of scalars of the output dtype.
  scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
  body = op.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(body):
    hlo.return_([reducer(*body.arguments)])
  return [mlir.lower_with_sharding_in_types(ctx, op.result, aval_out)]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2022-12-15 20:59:34 -08:00
|
|
|
|
# Lowerings for the single-operand reductions: each binds `_unary_reduce_lower`
# to the HLO combiner and the identity function for that reduction.
mlir.register_lowering(
    reduce_sum_p,
    partial(_unary_reduce_lower, hlo.AddOp, _get_sum_identity))
mlir.register_lowering(
    reduce_prod_p,
    partial(_unary_reduce_lower, hlo.MulOp, _get_prod_identity))
mlir.register_lowering(
    reduce_or_p,
    partial(_unary_reduce_lower, hlo.OrOp, _get_bitwise_or_identity))
mlir.register_lowering(
    reduce_and_p,
    partial(_unary_reduce_lower, hlo.AndOp, _get_bitwise_and_identity))
mlir.register_lowering(
    reduce_xor_p,
    partial(_unary_reduce_lower, hlo.XorOp, _get_bitwise_or_identity))
mlir.register_lowering(
    reduce_min_p,
    partial(_unary_reduce_lower, mlir.min_hlo, _get_min_identity))
mlir.register_lowering(
    reduce_max_p,
    partial(_unary_reduce_lower, mlir.max_hlo, _get_max_identity))
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
|
|
|
|
|
2021-04-05 09:54:14 -07:00
|
|
|
|
def _reduce_precision_shape_rule(operand, *, exponent_bits, mantissa_bits):
  """Shape rule for reduce_precision: validate the bit counts, keep the shape.

  Both bit counts are coerced with ``operator.index`` (so non-integral values
  raise ``TypeError``) before being range checked.

  Raises:
    ValueError: if ``exponent_bits`` < 1 or ``mantissa_bits`` < 0.
  """
  n_exponent = operator.index(exponent_bits)
  n_mantissa = operator.index(mantissa_bits)
  if n_exponent < 1:
    raise ValueError(
        f"reduce_precision: exponent_bits must be positive; got {n_exponent}")
  if n_mantissa < 0:
    raise ValueError(
        f"reduce_precision: mantissa_bits must be non-negative; got {n_mantissa}")
  return operand.shape
|
|
|
|
|
|
2025-01-17 17:30:24 -08:00
|
|
|
|
def _reduce_precision_sharding_rule(operand, *, exponent_bits, mantissa_bits):
  """Sharding rule for reduce_precision: the operand's sharding is returned."""
  del exponent_bits, mantissa_bits  # The bit counts play no part here.
  return operand.sharding
|
2021-04-05 09:54:14 -07:00
|
|
|
|
|
|
|
|
|
# The reduce_precision primitive, built from the shape/dtype/sharding rules.
reduce_precision_p = standard_primitive(
    _reduce_precision_shape_rule,
    partial(unop_dtype_rule, _identity, _float, 'reduce_precision'),
    name='reduce_precision',
    sharding_rule=_reduce_precision_sharding_rule,
    vma_rule=partial(core.standard_vma_rule, 'reduce_precision'))
# Registered as linear: the transpose re-binds the primitive on the cotangent
# with the same parameters.
ad.deflinear(
    reduce_precision_p,
    lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)])
batching.defvectorized(reduce_precision_p)
|
2021-04-05 09:54:14 -07:00
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
  """Lowers reduce_precision to ``hlo.reduce_precision``.

  The two bit counts are passed as i32 attributes and the result goes through
  ``mlir.lower_with_sharding_in_types`` with the single output aval.
  """
  (result_aval,) = ctx.avals_out
  exponent_attr = mlir.i32_attr(exponent_bits)
  mantissa_attr = mlir.i32_attr(mantissa_bits)
  reduced = hlo.reduce_precision(operand, exponent_attr, mantissa_attr)
  return [mlir.lower_with_sharding_in_types(ctx, reduced, result_aval)]

mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
|
2021-04-05 09:54:14 -07:00
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# Maps a bit width to the NumPy unsigned integer dtype of that width.
_UINT_DTYPES = {
    bits: np.dtype(f"uint{bits}") for bits in (16, 32, 64)
}
|
|
|
|
|
|
|
|
|
|
# Maps a bit width to the NumPy signed integer dtype of that width.
_INT_DTYPES = {
    bits: np.dtype(f"int{bits}") for bits in (16, 32, 64)
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sort_abstract_eval(*args, **kwargs):
  """Abstract evaluation for sort: outputs mirror the inputs.

  All operands must share one shape; keyword parameters are not inspected.

  Raises:
    TypeError: if any operand's shape differs from the first operand's.
  """
  avals = tuple(args)
  for other in avals[1:]:
    if other.shape != avals[0].shape:
      shapes = " ".join(str(aval.shape) for aval in avals)
      raise TypeError(
          f"Arguments to sort must have equal shapes, got: {shapes}")
  return avals
|
|
|
|
|
|
|
|
|
|
|
2023-10-13 12:20:22 -07:00
|
|
|
|
def _canonicalize_float_for_sort(x):
  """Returns ``x`` with zeros and NaNs replaced by one standard representation."""
  # The sort comparator uses a comparison operator under which -0 would be
  # ordered before 0, and -NaN and NaN would appear at the beginning and end of
  # the ordering. That causes issues for stable sorts, so the representation of
  # zeros and NaNs is standardized here.
  is_zero = eq(x, _zero(x))
  canonical = select(is_zero, _zeros(x), x)

  # NOTE(review): debug_nans is switched off around the NaN fill, presumably so
  # that deliberately materializing NaN does not trip NaN debugging -- confirm.
  with config.debug_nans(False):
    is_nan = _isnan(x)
    canonical = select(is_nan, full_like(canonical, np.nan), canonical)

  return canonical
|
2022-01-13 13:03:41 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# Default comparator that sorts the operands lexicographically on the
|
|
|
|
|
# first `num_keys` arguments.
|
|
|
|
|
# For floating point types, a total order is created where
|
2022-01-13 13:03:41 -08:00
|
|
|
|
# -infinity < ... < 0 < ... < infinity < NaN.
|
|
|
|
|
# 0.0 and -0.0 are treated as equivalent, as are all NaN representations.
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# For complex types, the (real, imag) pairs are sorted lexicographically
|
|
|
|
|
# (following NumPy's semantics).
|
|
|
|
|
# This code adds complex-number support and lexicographic ordering to the algorithm from:
|
|
|
|
|
# https://github.com/tensorflow/tensorflow/blob/ba43780830f09da72081fe5061c436f1c6203a92/tensorflow/compiler/xla/client/lib/comparators.h#L33
|
|
|
|
|
def _sort_lt_comparator(*operands, num_keys=1):
  """Lexicographic "less than" over the first ``num_keys`` operand pairs.

  Keys are combined from the last key to the first: the running predicate is
  ``lt(k) | (eq(k) & predicate_of_later_keys)``. Returns ``None`` when there
  are no keys.
  """
  lhs_keys, rhs_keys = _operands_to_keys(*operands, num_keys=num_keys)
  result = None
  for lhs, rhs in zip(reversed(lhs_keys), reversed(rhs_keys)):
    if result is None:
      # Least-significant key: a plain strict comparison.
      result = lt_to_p.bind(lhs, rhs)
    else:
      strictly_less = lt_to_p.bind(lhs, rhs)
      tied = eq_to_p.bind(lhs, rhs)
      result = bitwise_or(strictly_less, bitwise_and(tied, result))
  return result
|
|
|
|
|
|
|
|
|
|
# Similar to sort_lt_comparator, but implements less than or equal. Used by
|
|
|
|
|
# the searchsorted() implementation.
|
|
|
|
|
def _sort_le_comparator(*operands, num_keys=1):
  """Lexicographic "less than or equal" over the first ``num_keys`` pairs.

  Same fold as the strict comparator, except that the last key (processed
  first) is compared with ``le`` instead of ``lt``. Returns ``None`` when there
  are no keys.
  """
  lhs_keys, rhs_keys = _operands_to_keys(*operands, num_keys=num_keys)
  result = None
  for lhs, rhs in zip(reversed(lhs_keys), reversed(rhs_keys)):
    if result is None:
      # Least-significant key: ties count as "less than or equal".
      result = le_to_p.bind(lhs, rhs)
    else:
      strictly_less = lt_to_p.bind(lhs, rhs)
      tied = eq_to_p.bind(lhs, rhs)
      result = bitwise_or(strictly_less, bitwise_and(tied, result))
  return result
|
|
|
|
|
|
|
|
|
|
def _operands_to_keys(*operands, num_keys=1):
  """Splits interleaved comparator operands into left and right key lists.

  ``operands`` is laid out as ``(x0, y0, x1, y1, ...)``; only the first
  ``num_keys`` pairs contribute. Complex keys expand into canonicalized
  (real, imag) parts, floating keys are canonicalized, and all other dtypes
  are passed through unchanged.

  Returns:
    A ``(x_keys, y_keys)`` pair of lists.
  """
  n_operands = len(operands)
  assert n_operands >= 2 and n_operands % 2 == 0, operands
  assert n_operands // 2 >= num_keys, (operands, num_keys)
  lhs_keys, rhs_keys = [], []
  key_operands = operands[:2 * num_keys]
  for lhs, rhs in zip(key_operands[0::2], key_operands[1::2]):
    assert lhs.dtype == rhs.dtype, (lhs.dtype, rhs.dtype)
    if dtypes.issubdtype(lhs.dtype, np.complexfloating):
      lhs_keys.append(_canonicalize_float_for_sort(real(lhs)))
      lhs_keys.append(_canonicalize_float_for_sort(imag(lhs)))
      rhs_keys.append(_canonicalize_float_for_sort(real(rhs)))
      rhs_keys.append(_canonicalize_float_for_sort(imag(rhs)))
    elif dtypes.issubdtype(lhs.dtype, np.floating):
      lhs_keys.append(_canonicalize_float_for_sort(lhs))
      rhs_keys.append(_canonicalize_float_for_sort(rhs))
    else:
      lhs_keys.append(lhs)
      rhs_keys.append(rhs)
  return lhs_keys, rhs_keys
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
  """JVP rule for sort.

  The primals are sorted together with an extra iota operand along
  ``dimension``; the sorted iota is then used as gather indices so that every
  non-zero tangent is permuted the same way as the primals.
  """
  shape = primals[0].shape
  rank = len(shape)
  iota = broadcasted_iota(np.uint64, shape, dimension)
  sorted_with_idx = sort_p.bind(
      *primals, iota,
      dimension=dimension, is_stable=is_stable, num_keys=num_keys)
  # Every dimension except the sorted one is a gather batching dimension.
  batch_dims = tuple(np.delete(np.arange(rank, dtype=np.int64), dimension))
  dnums = slicing.GatherDimensionNumbers(
      offset_dims=(),
      collapsed_slice_dims=(dimension,),
      start_index_map=(dimension,),
      operand_batching_dims=batch_dims,
      start_indices_batching_dims=batch_dims,
  )
  idx = expand_dims(sorted_with_idx[-1], (rank,))

  def permute(tangent):
    return slicing.gather(
        tangent,
        start_indices=idx, dimension_numbers=dnums, slice_sizes=(1,) * rank,
        mode=slicing.GatherScatterMode.PROMISE_IN_BOUNDS)

  tangents_out = [t if type(t) is ad_util.Zero else permute(t)
                  for t in tangents]
  return tuple(sorted_with_idx[:-1]), tangents_out
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
  """Batching rule for sort.

  The first batched operand fixes the batch position. Unbatched operands are
  broadcast to that operand's shape, batched ones have their batch axis moved
  there, and the sort dimension is shifted by one when the batch axis sits at
  or before it.
  """
  prototype_arg, new_bdim = next(
      (arg, bdim) for arg, bdim in zip(batched_args, batch_dims)
      if bdim is not None)

  def align(arg, bdim):
    if bdim is not None:
      return batching.moveaxis(arg, bdim, new_bdim)
    kept_dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
    return broadcast_in_dim(arg, prototype_arg.shape, kept_dims)

  new_args = [align(arg, bdim) for arg, bdim in zip(batched_args, batch_dims)]
  new_dimension = dimension + (new_bdim <= dimension)
  outs = sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable,
                     num_keys=num_keys)
  return outs, (new_bdim,) * len(new_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The multi-result sort primitive and its impl/abstract-eval/JVP/batching rules.
sort_p = Primitive('sort')
sort_p.multiple_results = True
sort_p.def_impl(partial(dispatch.apply_primitive, sort_p))
sort_p.def_abstract_eval(_sort_abstract_eval)
ad.primitive_jvps[sort_p] = _sort_jvp
batching.primitive_batchers[sort_p] = _sort_batch_rule
|
|
|
|
|
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
  """Lowers sort to ``hlo.SortOp`` with a lexicographic less-than comparator.

  The comparator region takes one pair of scalar arguments per operand and its
  body is produced by lowering ``_sort_lt_comparator`` in a sub-context.
  """
  assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
  result_types = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
  sort = hlo.SortOp(result_types,
                    mlir.flatten_ir_values(operands),
                    dimension=mlir.i64_attr(dimension),
                    is_stable=ir.BoolAttr.get(is_stable))
  # Scalar (rank-0) counterparts of the inputs, with an empty sharding spec.
  scalar_avals = [
      aval.update(shape=(), sharding=aval.sharding.with_spec(P()))
      for aval in ctx.avals_in
  ]
  scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
  # Each operand contributes two block arguments: (lhs, rhs).
  comparator = sort.comparator.blocks.append(
      *util.flatten(zip(scalar_types, scalar_types)))
  with ir.InsertionPoint(comparator):
    lower_comparator = mlir.lower_fun(partial(_sort_lt_comparator),
                                      multiple_results=False)
    paired_avals = util.flatten(zip(scalar_avals, scalar_avals))
    sub_ctx = ctx.replace(primitive=None,
                          avals_in=paired_avals,
                          avals_out=[core.ShapedArray((), np.bool_)])
    out = lower_comparator(sub_ctx, *comparator.arguments, num_keys=num_keys)
    hlo.return_(mlir.flatten_ir_values(out))
  return sort.results

mlir.register_lowering(sort_p, _sort_lower)
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _top_k_abstract_eval(operand, *, k):
  """Abstract evaluation for top_k.

  Returns:
    A ``(values, indices)`` pair of avals whose last dimension is ``k``; the
    values keep the operand's dtype and weak type, the indices are int32.

  Raises:
    ValueError: for complex operands, negative ``k``, or ``k`` larger than the
      operand's last dimension.
    TypeError: if the operand has no dimensions.
  """
  if dtypes.issubdtype(operand.dtype, np.complexfloating):
    raise ValueError("top_k is not compatible with complex inputs.")
  if k < 0:
    raise ValueError(f"k argument to top_k must be nonnegative, got {k}")
  if len(operand.shape) == 0:
    raise TypeError(
        f"top_k operand must have >= 1 dimension, got {operand.shape}")
  in_shape = list(operand.shape)
  if in_shape[-1] < k:
    raise ValueError(
        "k argument to top_k must be no larger than minor dimension; "
        f"{k} vs {in_shape}")
  out_shape = in_shape[:-1] + [k]
  values_aval = operand.update(shape=out_shape, dtype=operand.dtype,
                               weak_type=operand.weak_type)
  indices_aval = operand.update(shape=out_shape, dtype=np.dtype(np.int32))
  return (values_aval, indices_aval)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _top_k_jvp(primals, tangents, *, k):
  """JVP rule for top_k.

  The values tangent is the input tangent gathered at the selected indices
  along the last axis; the indices output gets a symbolic zero tangent.
  """
  (operand,) = primals
  (tangent,) = tangents
  primals_out = top_k(operand, k)
  if type(tangent) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_primal_value(primals_out[0])
  else:
    _, indices = primals_out
    rank = len(indices.shape)
    last = rank - 1
    # Trailing unit axis holds the single (last-axis) coordinate per element.
    gather_indices = reshape(indices, indices.shape + (1,))
    leading_dims = tuple(range(last))
    dnums = slicing.GatherDimensionNumbers(
        offset_dims=(),
        collapsed_slice_dims=(last,),
        operand_batching_dims=leading_dims,
        start_indices_batching_dims=leading_dims,
        start_index_map=(last,),
    )
    tangent_out = slicing.gather(tangent, gather_indices, dnums, (1,) * rank)
  index_tangent = ad_util.Zero.from_primal_value(primals_out[1])
  return primals_out, (tangent_out, index_tangent)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _top_k_batch_rule(batched_args, batch_dims, *, k):
  """Batching rule for top_k.

  top_k is applied with the batch axis left in place, unless the batch axis is
  the last one: it is then swapped with its neighbour before the call and both
  results are swapped back afterwards.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  if bdim != operand.ndim - 1:
    return top_k(operand, k=k), (bdim, bdim)
  perm = np.arange(operand.ndim)
  perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1]
  values, indices = top_k(transpose(operand, perm), k=k)
  return (transpose(values, perm), transpose(indices, perm)), (bdim, bdim)
|
|
|
|
|
|
|
|
|
|
# The top_k primitive: two results (values and indices).
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
|
2022-07-08 00:21:16 +00:00
|
|
|
|
def _top_k_lower(ctx, operand, k):
  """Lowers top_k.

  A constant ``k`` lowers to ``chlo.TopKOp``; otherwise ``k`` is evaluated as
  an IR value and passed to a ``stablehlo.dynamic_top_k`` custom call.
  """
  if core.is_constant_dim(k):
    return chlo.TopKOp(operand, mlir.i64_attr(k)).results
  (k_value,) = mlir.eval_dynamic_shape_as_vals(ctx, (k,))
  values_aval, indices_aval = ctx.avals_out
  result_types = [mlir.aval_to_ir_type(values_aval),
                  mlir.aval_to_ir_type(indices_aval)]
  return mlir.custom_call(
      "stablehlo.dynamic_top_k",
      result_types=result_types,
      operands=[operand, k_value]).results

mlir.register_lowering(top_k_p, _top_k_lower)
ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
|
|
|
|
|
|
|
|
|
|
def _stop_gradient_jvp_rule(primals, tangents):
  """JVP rule for stop_gradient: primal passes through, tangent is zero."""
  del tangents  # The incoming tangent is discarded.
  (primal,) = primals
  # if we don't call stop_gradient here, we'd only peel off one autodiff tracer
  primal_out = stop_gradient(primal)
  return primal_out, ad_util.Zero.from_primal_value(primal)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _stop_gradient_batch_rule(batched_args, batch_dims):
  """Batching rule for stop_gradient: apply it and keep the batch dimension."""
  (operand,) = batched_args
  (bdim,) = batch_dims
  return stop_gradient(operand), bdim
|
|
|
|
|
|
|
|
|
|
# Hook the stop_gradient rules into the autodiff, batching and padding tables.
ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule
pe.def_trivial_padding(ad_util.stop_gradient_p)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
2020-12-17 19:49:30 -08:00
|
|
|
|
def create_token(_=None):
  """Creates an XLA token value with no preconditions for sequencing effects.

  Experimental.

  Args:
    _: ignored; accepted only for backward compatibility.

  Returns:
    The result of binding ``create_token_p`` with no operands.
  """
  del _  # Unused.
  return create_token_p.bind()
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# Primitive behind create_token(); abstract evaluation always yields
# abstract_token regardless of arguments.
create_token_p = Primitive("create_token")
create_token_p.def_impl(partial(dispatch.apply_primitive, create_token_p))
create_token_p.def_abstract_eval(lambda *_: abstract_token)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _create_token_lowering(ctx, *operands):
  """Lowers create_token to ``hlo.create_token``; operands are not used."""
  del operands
  # Unpacking checks that exactly one output aval is present.
  (_,) = ctx.avals_out
  return [hlo.create_token()]

mlir.register_lowering(create_token_p, _create_token_lowering)
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def after_all(*operands):
  """Merges one or more XLA token values. Experimental.

  Wraps the XLA AfterAll operator.

  Args:
    *operands: the tokens to merge; they are first passed through
      ``core.standard_insert_pbroadcast``.

  Returns:
    The result of binding ``after_all_p`` to the operands.
  """
  tokens = core.standard_insert_pbroadcast(*operands)
  return after_all_p.bind(*tokens)
|
|
|
|
|
|
|
|
|
|
def _after_all_abstract_eval(*operands):
  """Abstract evaluation for after_all: tokens in, one token out.

  Raises:
    TypeError: if any operand is not ``abstract_token``.
  """
  for operand in operands:
    if operand is not abstract_token:
      raise TypeError("Arguments to after_all must be tokens")
  return abstract_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Primitive behind after_all().
after_all_p = Primitive("after_all")
after_all_p.def_impl(partial(dispatch.apply_primitive, after_all_p))
after_all_p.def_abstract_eval(_after_all_abstract_eval)
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _after_all_lowering(ctx, *operands):
  """Lowers after_all to ``hlo.after_all`` applied to the operand tokens."""
  # Unpacking checks that exactly one output aval is present.
  (_,) = ctx.avals_out
  merged = hlo.after_all(operands)
  return [merged]

mlir.register_lowering(after_all_p, _after_all_lowering)
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2023-02-01 17:50:00 -08:00
|
|
|
|
class InOutFeedEffect(effects.Effect):
  """Effect type of the infeed and outfeed effect instances below."""


# One distinct instance per direction.
infeed_effect = InOutFeedEffect()
outfeed_effect = InOutFeedEffect()
|
2022-09-08 08:49:12 -07:00
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def infeed(token, shape=None, partitions=None):
  """Consumes an infeed value of `shape` from the host. Experimental.

  `token` is used to sequence infeed and outfeed effects.
  `partitions` may be specified inside a `sharded_jit` function.

  Args:
    token: the sequencing token.
    shape: a pytree whose leaves are all ``ShapedArray`` values.
    partitions: ``None`` or a plain ``tuple``; a ``None`` entry for the token
      is appended before binding.

  Returns:
    A pair ``(values, token)`` where ``values`` has the pytree structure of
    ``shape``.

  Raises:
    TypeError: if a leaf of ``shape`` is not a ``ShapedArray``.
    ValueError: if ``partitions`` is neither ``None`` nor exactly a ``tuple``.
  """
  flat_shapes, treedef = tree_util.tree_flatten(shape)
  for leaf in flat_shapes:
    if not isinstance(leaf, ShapedArray):
      raise TypeError("shape argument to infeed must be a pytree of "
                      f"ShapedArray values, got {leaf}")
  if partitions is not None:
    # We specifically use type() to raise an error for PartitionSpecs.
    if type(partitions) is not tuple:  # pylint: disable=unidiomatic-typecheck
      raise ValueError(f"'partitions' argument to infeed should be a tuple, "
                       f"got {partitions}")
    # Always replicate token.
    partitions = partitions + (None,)
  xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes),
                               partitions=partitions)
  values = treedef.unflatten(xs_and_token[:-1])
  return (values, xs_and_token[-1])
|
|
|
|
|
|
|
|
|
|
def _infeed_abstract_eval(token, *, shapes, partitions):
  """Effectful abstract evaluation for infeed.

  Returns:
    ``((*shapes, abstract_token), {infeed_effect})``.

  Raises:
    TypeError: if ``token`` is not ``abstract_token``.
  """
  del partitions  # Not needed to determine the output avals.
  if token is not abstract_token:
    raise TypeError("First argument to infeed must be a token")
  out_avals = (*shapes, abstract_token)
  return out_avals, {infeed_effect}
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Primitive behind infeed(): multiple results and an effectful abstract eval.
infeed_p = Primitive("infeed")
infeed_p.multiple_results = True
infeed_p.def_impl(partial(dispatch.apply_primitive, infeed_p))
infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
# Mark the in/out-feed effect type as lowerable.
mlir.lowerable_effects.add_type(InOutFeedEffect)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2022-03-04 10:25:22 -05:00
|
|
|
|
def _infeed_lowering(ctx, token, *, shapes, partitions):
  """Lowers infeed to ``hlo.InfeedOp``.

  Returns the infed values (regrouped to match the output types) followed by
  the op's output token.
  """
  # The last output aval is the token; everything before it is data.
  output_types = safe_map(mlir.aval_to_ir_type, ctx.avals_out[:-1])
  flat_output_types = mlir.flatten_ir_types(output_types)
  # TODO(phawkins): verify `shapes` have a major-to-minor layout.
  per_value_layouts = []
  for aval in shapes:
    # Dimension numbers listed from rank-1 down to 0.
    minor_to_major = [mlir.i64_attr(i)
                      for i in range(len(aval.shape) - 1, -1, -1)]
    per_value_layouts.append(ir.ArrayAttr.get(minor_to_major))
  layouts = ir.ArrayAttr.get(per_value_layouts)
  infeed = hlo.InfeedOp(
      flat_output_types + [hlo.TokenType.get()],
      token,
      infeed_config=ir.StringAttr.get(''),
      layout=layouts)
  if partitions is not None:
    mlir.set_sharding(infeed, xla.sharding_to_proto(partitions))
  out_token = infeed.results[-1]
  outs = infeed.results[:-1]
  return mlir.unflatten_ir_values_like_types(outs, output_types) + [
      out_token,
  ]

mlir.register_lowering(infeed_p, _infeed_lowering)
|
|
|
|
|
|
|
|
|
|
|
2021-07-01 11:59:13 -07:00
|
|
|
|
def outfeed(token, xs, partitions = None):
  """Outfeeds value `xs` to the host. Experimental.

  `token` is used to sequence infeed and outfeed effects.
  `partitions` may be specified inside a `sharded_jit` or `pjit` function.

  Args:
    token: the sequencing token.
    xs: a pytree of values; its leaves are passed to the primitive.
    partitions: ``None`` or a plain ``tuple``.

  Returns:
    The result of binding ``outfeed_p``.

  Raises:
    ValueError: if ``partitions`` is neither ``None`` nor exactly a ``tuple``.
  """
  # We specifically use type() to raise an error for PartitionSpecs.
  if partitions is not None and type(partitions) is not tuple:  # pylint: disable=unidiomatic-typecheck
    raise ValueError(f"'partitions' argument to outfeed should be a tuple, "
                     f"got {partitions}")
  leaves = tree_util.tree_flatten(xs)[0]
  return outfeed_p.bind(token, *leaves, partitions=partitions)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-07-01 11:59:13 -07:00
|
|
|
|
def _outfeed_abstract_eval(token, *xs, partitions):
  """Effectful abstract evaluation for outfeed.

  Returns:
    ``(abstract_token, {outfeed_effect})``.

  Raises:
    TypeError: if ``token`` is not ``abstract_token``.
  """
  del xs, partitions  # Only the token is checked.
  if token is not abstract_token:
    raise TypeError("First argument to outfeed must be a token")
  return abstract_token, {outfeed_effect}
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# Primitive behind outfeed(), with an effectful abstract eval.
outfeed_p = Primitive("outfeed")
outfeed_p.def_impl(partial(dispatch.apply_primitive, outfeed_p))
outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
# Mark the in/out-feed effect type as lowerable.
mlir.lowerable_effects.add_type(InOutFeedEffect)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
2022-03-04 10:25:22 -05:00
|
|
|
|
def _outfeed_lowering(ctx, token, *xs, partitions):
  """Lowers outfeed to ``hlo.OutfeedOp`` and returns the op's results."""
  flat_values = mlir.flatten_ir_values(xs)
  outfeed = hlo.OutfeedOp(
      flat_values,
      token,
      outfeed_config=ir.StringAttr.get(''))
  if partitions is not None:
    mlir.set_sharding(outfeed, xla.sharding_to_proto(partitions))
  return outfeed.results

mlir.register_lowering(outfeed_p, _outfeed_lowering)
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def rng_uniform(a, b, shape):
  """Stateful PRNG generator. Experimental and its use is discouraged.

  Returns uniformly distributed random numbers in the range [a, b). If
  b <= a, then the result is undefined, and different implementations may
  return different results.

  You should use jax.random for most purposes; this function exists only for
  niche use cases with special performance requirements.

  This API may be removed at any time.

  Args:
    a: lower bound of the range.
    b: upper bound of the range.
    shape: the output shape; converted to a tuple before binding.
  """
  low, high = core.standard_insert_pbroadcast(a, b)
  return rng_uniform_p.bind(low, high, shape=tuple(shape))
|
|
|
|
|
|
|
|
|
|
def _rng_uniform_abstract_eval(a, b, *, shape):
  """Abstract evaluation for rng_uniform.

  Returns:
    ``a`` updated to ``shape``, keeping its dtype; the weak type is
    ``a.weak_type and b.weak_type``.

  Raises:
    ValueError: if the bounds differ in dtype or are not both scalars.
  """
  if a.dtype != b.dtype:
    raise ValueError(
        "Arguments to rng_uniform must have identical dtypes, got "
        f"{a.dtype} and {b.dtype}.")
  if any(bound.shape != () for bound in (a, b)):
    raise ValueError(
        "Arguments to rng_uniform must be scalars; got shapes "
        f"{a.shape} and {b.shape}.")
  weak = a.weak_type and b.weak_type
  return a.update(shape=shape, dtype=a.dtype, weak_type=weak)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# Primitive behind rng_uniform().
rng_uniform_p = Primitive("rng_uniform")
rng_uniform_p.def_impl(partial(dispatch.apply_primitive, rng_uniform_p))
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
|
def _rng_uniform_lowering(ctx, a, b, *, shape):
  """Lowers rng_uniform to ``hlo.rng`` with the UNIFORM distribution.

  The shape operand is built from the output aval's shape as an int64
  constant; the ``shape`` parameter itself is not read.
  """
  del shape
  (result_aval,) = ctx.avals_out
  shape_operand = mlir.ir_constant(np.array(result_aval.shape, np.int64))
  distribution = hlo.RngDistributionAttr.get('UNIFORM')
  return [hlo.rng(a, b, shape_operand, distribution)]

mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
|
|
|
|
|
|
2020-10-23 07:34:32 -07:00
|
|
|
|
|
2025-03-25 14:47:39 -07:00
|
|
|
|
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm, out_sharding):
  """Shape rule: returns ``(key.shape, tuple(shape))`` for the two outputs."""
  del dtype, algorithm, out_sharding  # Shapes depend on `key` and `shape` only.
  bits_shape = tuple(shape)
  return (key.shape, bits_shape)
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
2025-03-25 14:47:39 -07:00
|
|
|
|
def _rng_bit_generator_sharding_rule(key, *, shape, dtype, algorithm,
                                     out_sharding):
  """Sharding rule: returns ``(key.sharding, out_sharding)``."""
  del shape, dtype, algorithm  # Not relevant to sharding.
  return (key.sharding, out_sharding)
|
|
|
|
|
|
|
|
|
|
def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm, out_sharding):
  """Dtype rule: returns ``(key.dtype, dtype)`` for the two outputs."""
  del shape, algorithm, out_sharding  # Dtypes depend on `key` and `dtype` only.
  return (key.dtype, dtype)
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
2025-03-25 14:47:39 -07:00
|
|
|
|
def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm,
                                      out_sharding):
  """Weak-type rule: returns ``(key.weak_type, False)``."""
  del shape, dtype, algorithm, out_sharding
  return (key.weak_type, False)
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
2024-10-11 11:02:39 -07:00
|
|
|
|
|
|
|
|
|
class RandomAlgorithm(enum.IntEnum):
  """Describes which PRNG algorithm to use for rng_bit_generator.

  ``str()`` of a member is the bare member name; see the ``__str__``
  assignment that follows the class.
  """

  RNG_DEFAULT = 0
  "The platform's default algorithm."

  RNG_THREE_FRY = 1
  "The Threefry-2x32 PRNG algorithm."

  RNG_PHILOX = 2
  "The Philox-4x32 PRNG algorithm."


# Installed after class creation: str(member) returns member.name.
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name  # type: ignore[method-assign]
|
2022-05-04 11:18:28 -07:00
|
|
|
|
|
|
|
|
|
def _rng_algorithm(algorithm: RandomAlgorithm):
  """Translate a ``RandomAlgorithm`` into the matching ``hlo.RngAlgorithmAttr``.

  Args:
    algorithm: the algorithm selector; compared by equality against the
      ``RandomAlgorithm`` members.

  Returns:
    ``hlo.RngAlgorithmAttr.get(...)`` for ``"THREE_FRY"``, ``"PHILOX"`` or
    ``"DEFAULT"``.

  Raises:
    AssertionError: if ``algorithm`` matches none of the known members.
  """
  if algorithm == RandomAlgorithm.RNG_THREE_FRY:
    return hlo.RngAlgorithmAttr.get("THREE_FRY")
  if algorithm == RandomAlgorithm.RNG_PHILOX:
    return hlo.RngAlgorithmAttr.get("PHILOX")
  if algorithm == RandomAlgorithm.RNG_DEFAULT:
    return hlo.RngAlgorithmAttr.get("DEFAULT")
  # Raise explicitly instead of `assert False`: assert statements are removed
  # under `python -O`, which would make this function silently return None.
  raise AssertionError(f"Unsupported RNG algorithm: {algorithm!r}")
|
|
|
|
|
|
2022-04-13 18:09:08 -07:00
|
|
|
|
def _rng_bit_generator_lowering(
    ctx, key, *, shape, dtype, algorithm, out_sharding):
  """MLIR lowering rule for ``rng_bit_generator_p``.

  Emits either ``hlo.RngBitGeneratorOp`` (all output shapes constant) or a
  ``stablehlo.dynamic_rng_bit_generator`` custom call (some output shape is
  not constant), converting the key to/from u64[2] around the op when it is
  passed as u32[4].

  Returns:
    A two-element list ``[out_key, out_vals]``, each passed through
    ``mlir.lower_with_sharding_in_types`` with its output aval.
  """
  key_type = ir.RankedTensorType(key.type)
  key_shape, key_etype = key_type.shape, key_type.element_type
  # While the RngBitGenerator HLO accepts a u64[2] key on all backends, we
  # typically represent the key argument to this primitive as a u32[4] so as to
  # sidestep issues with the jax_enable_x64=False configuration. As a result, we
  # need to convert u32[4] -> u64[2] here in the translation rule. However, we
  # also polymorphically allow a u64[2] for backward compatibility.
  #
  # Separately, RngBitGenerator doesn't support generating u8 or
  # u16, so we request u32 and truncate in that case.
  # NOTE(review): the dtype test below passes uint8/uint16 straight through to
  # the op, so the u32-and-convert fallback is only taken for other dtypes;
  # the sentence above may be stale -- confirm.
  u32_type = ir.IntegerType.get_unsigned(32)
  u64_type = ir.IntegerType.get_unsigned(64)
  # Only the two key layouts described above are accepted.
  assert ((key_shape == [4] and key_etype == u32_type) or
          (key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype)
  dtype = np.dtype(dtype)
  etype = mlir.dtype_to_ir_type(dtype)
  # rbg_etype / rbg_dtype: element type actually requested from the generator
  # op, which may differ from the type the caller asked for (etype / dtype).
  if dtype in (np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
               np.dtype('uint64')):
    rbg_etype = etype
    rbg_dtype = dtype
  else:
    rbg_etype = u32_type
    rbg_dtype = np.uint32
  if key_etype == u32_type:
    # u32[4] -> u32[2, 2] -> u64[2]
    key = hlo.bitcast_convert(
        ir.RankedTensorType.get([2], u64_type),
        hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key))
  algorithm_attr = _rng_algorithm(algorithm)
  out_key_aval, out_vals_aval = ctx.avals_out
  if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out):
    # Non-constant output shape: pass the evaluated shape as an operand to the
    # dynamic custom call.
    output_shape = mlir.shape_tensor(
        mlir.eval_dynamic_shape(ctx, out_vals_aval.shape))
    out_key, out_vals = mlir.custom_call(
        "stablehlo.dynamic_rng_bit_generator",
        result_types=[key.type,
                      mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))],
        operands=[key, output_shape],
        extra_attributes=dict(rng_algorithm=algorithm_attr)).results
  else:
    out_key, out_vals = hlo.RngBitGeneratorOp(
        key.type,
        ir.RankedTensorType.get(shape, rbg_etype),
        algorithm_attr, key).results
  if key_etype == u32_type:
    # Undo the key conversion: u64[2] -> u32[2, 2] -> u32[4]
    out_key = hlo.reshape(
        ir.RankedTensorType.get([4], u32_type),
        hlo.bitcast_convert(
            ir.RankedTensorType.get([2, 2], u32_type), out_key))
  if rbg_etype != etype:
    # The generator produced a different element type than requested; convert.
    out_vals = hlo.convert(
        ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype),
        out_vals)
  return [mlir.lower_with_sharding_in_types(ctx, out_key, out_key_aval),
          mlir.lower_with_sharding_in_types(ctx, out_vals, out_vals_aval)]
|
2022-04-13 18:09:08 -07:00
|
|
|
|
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
2021-03-16 12:13:41 -04:00
|
|
|
|
# The rng_bit_generator primitive: takes a key and produces two results.
rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
    partial(dispatch.apply_primitive, rng_bit_generator_p))
# Abstract evaluation is assembled from the per-result shape, dtype, weak-type
# and sharding rules.
# NOTE(review): the meaning of the trailing `None` argument is defined by
# `standard_multi_result_abstract_eval` -- confirm there.
rng_bit_generator_p.def_abstract_eval(
    partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
            _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
            _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule,
            None))
mlir.register_lowering(rng_bit_generator_p,
                       _rng_bit_generator_lowering)
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
|
|
|
|
|
2022-10-03 15:55:56 -07:00
|
|
|
|
def _array_copy(arr: ArrayLike) -> Array:
  """Bind the ``copy_p`` primitive to ``arr`` and return the result."""
  copied = copy_p.bind(arr)
  return copied
|
|
|
|
|
|
2022-12-07 15:41:13 -08:00
|
|
|
|
|
2023-12-08 12:09:04 +00:00
|
|
|
|
def _which_dim_sharded(s: PmapSharding) -> int | None:
  """Return the index of the first sharded dimension of ``s``, if any.

  A dimension counts as sharded when its entry in ``s.sharding_spec.sharding``
  is a ``pxla.Unstacked`` or ``pxla.Chunked`` instance. Returns ``None`` when
  no entry matches.
  """
  per_dim_specs = s.sharding_spec.sharding
  matching_indices = (
      idx for idx, dim_spec in enumerate(per_dim_specs)
      if isinstance(dim_spec, (pxla.Unstacked, pxla.Chunked)))
  return next(matching_indices, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _identity_fn(x): return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
  """Copy the arguments by running ``_identity_fn`` through the pmap machinery.

  ``sharded_dim`` is passed twice to both pmap helper calls below.
  NOTE(review): presumably these are the pmap in_axes and out_axes, and the
  remaining positional arguments (`None`, `()`, `()`, ...) are axis name,
  static/donated argnums, etc. -- confirm against `api._shared_code_pmap` and
  `api._prepare_pmap`.

  Returns:
    The flat pmap outputs unflattened with ``p.out_tree()``.
  """
  axis_name, static_broadcasted_tuple, donate_tuple = api._shared_code_pmap(
    _identity_fn, None, (), (), sharded_dim, sharded_dim)
  p = api._prepare_pmap(
      _identity_fn, sharded_dim, sharded_dim, static_broadcasted_tuple,
      donate_tuple, None, None, None, args, kwargs)
  # Run the prepared identity function directly through the pmap impl.
  out_flat = pxla.xla_pmap_impl(
      p.flat_fun, *p.flat_args, backend=None, axis_name=axis_name,
      axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
      devices=p.devices, in_axes=p.in_axes_flat,
      out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__,
      donated_invars=p.donated_invars,
      is_explicit_global_axis_size=p.is_explicit_global_axis_size,
  )
  return tree_util.tree_unflatten(p.out_tree(), out_flat)
|
|
|
|
|
|
|
|
|
|
|
2024-09-20 07:51:48 -07:00
|
|
|
|
# TODO(https://github.com/jax-ml/jax/issues/13552): Look into making this a
# method on jax.Array so that we can bypass the XLA compilation here.
def _copy_impl(prim, *args, **kwargs):
  """Implementation rule for the copy primitive.

  Expects exactly one positional operand. When that operand is an ``Array``
  with a ``PmapSharding`` for which ``_which_dim_sharded`` reports a sharded
  dimension, the copy goes through ``_copy_impl_pmap_sharding``; every other
  case is dispatched via ``dispatch.apply_primitive``.
  """
  (operand,) = args
  pmap_dim = None
  if isinstance(operand, Array) and isinstance(operand.sharding, PmapSharding):
    pmap_dim = _which_dim_sharded(operand.sharding)
  if pmap_dim is not None:
    return _copy_impl_pmap_sharding(pmap_dim, *args, **kwargs)
  return dispatch.apply_primitive(prim, *args, **kwargs)
|
2022-12-07 15:41:13 -08:00
|
|
|
|
|
2022-01-21 08:27:10 -08:00
|
|
|
|
# The copy_p primitive exists for expressing making copies of runtime arrays.
# For that reason we don't simplify it out of jaxprs (e.g. for jit invariance).
# It's used in jnp.array(x, copy=True), which is the user-facing API.
copy_p = core.Primitive('copy')
copy_p.def_impl(partial(_copy_impl, copy_p))
# Abstract evaluation returns the input aval unchanged.
copy_p.def_abstract_eval(lambda x: x)
# The lowering returns the operand value itself.
mlir.register_lowering(copy_p, lambda ctx, x: [x])
# The transpose rule binds copy_p to the cotangent.
ad.deflinear(copy_p, lambda t: [copy_p.bind(t)])
pe.def_trivial_padding(copy_p)
batching.defvectorized(copy_p)
|
2024-05-31 14:03:04 -07:00
|
|
|
|
def _propagate_mem_kind_copy(in_mem_kind):
  """Memory-kind propagation rule for ``copy_p``: output kind equals input kind."""
  out_mem_kind = in_mem_kind
  return out_mem_kind
pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
2021-10-05 13:46:57 -07:00
|
|
|
|
def rng_bit_generator(key, shape, dtype=np.uint32,
                      algorithm=RandomAlgorithm.RNG_DEFAULT, out_sharding=None):
  """Stateless PRNG bit generator. Experimental and its use is discouraged.

  Returns uniformly distributed random bits with the specified shape and dtype
  (which is required to be an integer type) using the platform specific
  default algorithm or the one specified.

  It provides direct access to the RngBitGenerator primitive exposed by XLA
  (https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
  level API access.

  Most users should use `jax.random` instead for a stable and more user
  friendly API.

  Raises:
    TypeError: if the canonicalized ``dtype`` is not one of uint8, uint16,
      uint32 or uint64.
  """
  shape = core.canonicalize_shape(shape)
  dtype = dtypes.canonicalize_dtype(dtype)
  out_sharding = canonicalize_sharding(out_sharding, 'rng_bit_generator')
  requested = np.dtype(dtype)
  supported = frozenset(
      np.dtype(name) for name in ('uint8', 'uint16', 'uint32', 'uint64'))
  if requested not in supported:
    raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}')
  results = rng_bit_generator_p.bind(
      key, shape=shape, dtype=dtype, algorithm=algorithm,
      out_sharding=out_sharding)
  return tuple(results)
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
|
|
|
|
|
2024-10-17 21:16:18 -07:00
|
|
|
|
def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding):
  """Abstract evaluation rule for ``iota_p``.

  Validates ``shape`` (only when there are no dynamic-shape operands),
  ``dtype`` and ``dimension``. Returns a ``ShapedArray`` in the fully static
  case, otherwise a ``core.DShapedArray`` built from the merged static and
  dynamic shape.

  Raises:
    TypeError: if ``dtype`` is not a subtype of any entry of ``_num``.
    ValueError: if ``dimension`` is not in ``[0, len(shape))``.
  """
  has_dyn_operands = bool(dyn_shape)
  if not has_dyn_operands:
    # TODO(mattjj) Generalize shape_like checking to permit dynamic shapes
    _check_shapelike("iota", "shape", shape)

  dtype_is_numeric = any(dtypes.issubdtype(dtype, t) for t in _num)
  if not dtype_is_numeric:
    typename = dtype_to_string(dtype)
    accepted = ', '.join(t.__name__ for t in _num)
    raise TypeError(
        'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
        .format(typename, accepted))

  if not 0 <= dimension < len(shape):
    raise ValueError("iota dimension must be between 0 and len(shape), got "
                     f"{dimension=} for {shape=}")

  def _is_bint_dim(d):
    return (isinstance(d, core.DArray) and
            type(core.get_aval(d).dtype) is core.bint)

  if has_dyn_operands or any(_is_bint_dim(d) for d in shape):
    # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
    return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)

  if sharding is None:
    # No sharding given: use the current mesh with an all-None spec.
    sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape)))
  return ShapedArray(shape, dtype, sharding=sharding)
|
2020-10-23 07:34:32 -07:00
|
|
|
|
|
2024-10-25 12:06:59 -07:00
|
|
|
|
|
2020-10-23 07:34:32 -07:00
|
|
|
|
# The iota primitive, with its eager implementation, abstract evaluation rule
# and ragged-propagation rule.
iota_p = Primitive('iota')
iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule
|
2020-10-23 07:34:32 -07:00
|
|
|
|
|
2024-10-17 21:16:18 -07:00
|
|
|
|
def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding):
  """Custom staging rule for ``iota_p``.

  Without dynamic-shape operands the primitive is staged through
  ``trace.default_process_primitive``; otherwise a ``core.DShapedArray`` output
  aval is built and staging is delegated to ``_dyn_shape_staging_rule``.
  """
  params = {'dtype': dtype, 'shape': shape, 'dimension': dimension,
            'sharding': sharding}
  if dyn_shape:
    merged_shape = _merge_dyn_shape(shape, dyn_shape)
    aval = core.DShapedArray(merged_shape, dtype, False)
    return _dyn_shape_staging_rule(trace, iota_p, aval, *dyn_shape, **params)
  return trace.default_process_primitive(iota_p, (), params)
pe.custom_staging_rules[iota_p] = _iota_staging_rule
|
|
|
|
|
|
2024-10-17 21:16:18 -07:00
|
|
|
|
def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension, sharding):
  """Custom typecheck rule for ``iota_p``.

  Returns ``([out_aval], effects)``. With no dynamic-shape operands this defers
  to ``iota_p.abstract_eval``; otherwise the output is a ``core.DShapedArray``
  over the merged shape (with ``core.Literal`` entries replaced by their
  ``.val``) and no effects.
  """
  if not dyn_shape:
    out_aval, effects = iota_p.abstract_eval(
        dtype=dtype, shape=shape, dimension=dimension, sharding=sharding)
    return [out_aval], effects

  merged = _merge_dyn_shape(shape, dyn_shape)
  unwrapped = tuple(
      x.val if type(x) is core.Literal else x for x in merged)  # pytype: disable=attribute-error
  out_aval = core.DShapedArray(unwrapped, dtype, False)
  return [out_aval], core.no_effects
core.custom_typechecks[iota_p] = _iota_typecheck_rule
|
|
|
|
|
|
2024-10-17 21:16:18 -07:00
|
|
|
|
def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding):
  """MLIR lowering rule for ``iota_p``.

  Uses the single output aval from ``ctx``; when dynamic-shape operands are
  present, its shape is first replaced by the merged static/dynamic shape.
  """
  del dtype  # The element type comes from the output aval.
  (result_aval,) = ctx.avals_out
  if dyn_shape:
    merged_shape = _merge_dyn_shape(shape, dyn_shape)
    result_aval = result_aval.update(shape=merged_shape)
  iota_value = mlir.iota(ctx, result_aval, dimension=dimension)
  return [mlir.lower_with_sharding_in_types(ctx, iota_value, result_aval)]
mlir.register_lowering(iota_p, _iota_lower)
|
|
|
|
|
|
2024-10-17 21:16:18 -07:00
|
|
|
|
def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension,
                        sharding):
  """Batching rule for ``iota_p`` with a single batched ``segment_lengths`` operand.

  ``shape`` must contain exactly one ``None`` entry (the ragged axis). The
  result is a broadcasted iota with a new leading axis of size
  ``len(segment_lengths)``, paired with a ``batching.RaggedAxis`` description.

  Raises:
    NotImplementedError: if ``sharding`` is not ``None``.
  """
  (segment_lengths,) = in_vals
  (ax,) = in_dims
  assert ax == 0
  bound = segment_lengths.dtype.bound
  # Exactly one dimension may be left unspecified (None).
  (ragged_axis,) = [i for i, dim in enumerate(shape) if dim is None]
  batched_shape = (len(segment_lengths),) + _merge_dyn_shape(shape, (bound,))
  if sharding is not None:
    raise NotImplementedError('Please file an issue if you want this support')
  # Both axis indices shift by one to account for the new leading axis.
  out = broadcasted_iota(dtype, batched_shape, dimension + 1)
  ragged = batching.RaggedAxis(ax, ((ragged_axis + 1, segment_lengths),))
  return out, ragged
batching.primitive_batchers[iota_p] = _iota_batching_rule
|
|
|
|
|
|
2024-10-17 21:16:18 -07:00
|
|
|
|
def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension,
                       sharding):
  """Padding rule for ``iota_p``.

  Rebuilds the static and dynamic shape from the single output aval:
  ``pe.BoundedAxisSize`` dims become their ``bound``, ``int`` dims are kept,
  and tracer dims become ``None`` in the static shape and are passed as
  dynamic operands.

  Raises:
    NotImplementedError: if ``sharding`` is not ``None``.
  """
  (out_aval,) = out_avals
  static_dims, dyn_dims = [], []
  for d in out_aval.shape:
    kind = type(d)
    if kind is pe.BoundedAxisSize:
      static_dims.append(d.bound)
      continue
    if kind is int:
      static_dims.append(d)
      continue
    assert isinstance(d, core.Tracer)
    static_dims.append(None)
    dyn_dims.append(d)
  if sharding is not None:
    raise NotImplementedError('Please file an issue if you want this support')
  padded = iota_p.bind(*dyn_dims, shape=tuple(static_dims),
                       dtype=dtype, dimension=dimension, sharding=sharding)
  return [padded]
pe.padding_rules[iota_p] = _iota_padding_rule
|
|
|
|
|
|
2020-10-23 07:34:32 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
### util
|
|
|
|
|
|
|
|
|
|
# Module-private alias for numpy's `ndim`.
_ndim = np.ndim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dilate_shape(shape, dilation):
|
|
|
|
|
"""Utility function for computing the shape resulting from a dilation."""
|
|
|
|
|
if not np.all(np.greater(dilation, 0)):
|
|
|
|
|
msg = "All dilations must be positive, got {}."
|
|
|
|
|
raise TypeError(msg.format(dilation))
|
|
|
|
|
dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation)
|
2023-07-11 14:03:52 +01:00
|
|
|
|
return tuple(map(core.dilate_dim, shape, dilation))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
def _ceil_divide(x1, x2):
|
|
|
|
|
return -np.floor_divide(np.negative(x1), x2)
|
|
|
|
|
|
2023-03-23 15:49:44 -07:00
|
|
|
|
|
|
|
|
|
class PaddingType(enum.Enum):
  # Padding modes accepted by `padtype_to_pads`.
  # VALID: zero padding on both sides of every dimension.
  VALID = 1
  # SAME: when the total padding is odd, the extra element goes on the high side.
  SAME = 2
  # SAME_LOWER: like SAME, but the extra element goes on the low side.
  SAME_LOWER = 3
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def padtype_to_pads(in_shape, window_shape, window_strides, padding):
  """Convert a padding mode into a list of ``(low, high)`` pad pairs.

  Args:
    in_shape: input spatial shape.
    window_shape: window size per dimension.
    window_strides: stride per dimension.
    padding: a ``PaddingType`` member, or one of the strings ``'VALID'``,
      ``'SAME'`` or ``'SAME_LOWER'`` (case-insensitive).

  Returns:
    A list with one ``(low, high)`` pair per dimension. For ``VALID`` every
    pair is ``(0, 0)``. For ``SAME`` an odd total padding puts the extra
    element on the high side; for ``SAME_LOWER`` on the low side.

  Raises:
    RuntimeError: if ``padding`` is a string that names no known mode.
    TypeError: if ``padding`` is neither a string nor a known ``PaddingType``.
  """

  if isinstance(padding, str):
    mapping = {
        'VALID': PaddingType.VALID,
        'SAME': PaddingType.SAME,
        'SAME_LOWER': PaddingType.SAME_LOWER,
    }
    try:
      padding = mapping[padding.upper()]
    except KeyError as err:
      msg = "Unrecognized padding type: expected 'VALID' or 'SAME', got {}."
      raise RuntimeError(msg.format(padding)) from err

  if padding == PaddingType.SAME or padding == PaddingType.SAME_LOWER:
    out_shape = _ceil_divide(in_shape, window_strides)
    # Total padding per dimension, clamped at zero.
    pad_sizes = (core.max_dim(d, 0)
                 for d in (out_shape - 1) * window_strides +
                          window_shape - in_shape)
    if padding == PaddingType.SAME:
      pads = [
          (pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes
      ]
    else:
      pads = [
          (pad_size - pad_size // 2, pad_size // 2) for pad_size in pad_sizes
      ]

    # Avoids verbose numpy scalars in jaxprs. Each entry of `pads` is a tuple,
    # so the conversion has to be applied to the tuple's elements; testing the
    # tuple itself against np.generic never matches.
    def _as_python_scalar(v):
      return v.item() if isinstance(v, np.generic) else v

    return [(_as_python_scalar(lo), _as_python_scalar(hi)) for lo, hi in pads]
  elif padding == PaddingType.VALID:
    return [(0, 0)] * len(in_shape)
  else:
    msg = "Unknown padding type: {}."
    raise TypeError(msg.format(padding))
|
|
|
|
|
|
|
|
|
|
|
2021-12-07 11:45:07 -08:00
|
|
|
|
# Map of lax function to equivalent jax.numpy function for use in error string below.
# Keys are lax function names; values are the jax.numpy names suggested in the
# tip appended to the dtype-mismatch error raised by `check_same_dtypes`.
_JNP_FUNCTION_EQUIVALENTS = {
  'abs': 'fabs',
  'acos': 'arccos',
  'acosh': 'arccosh',
  'add': 'add',
  'asin': 'arcsin',
  'asinh': 'arcsinh',
  'atan': 'arctan',
  'atan2': 'arctan2',
  'atanh': 'arctanh',
  'bitwise_and': 'bitwise_and',
  'bitwise_not': 'bitwise_not',
  'bitwise_or': 'bitwise_or',
  'bitwise_xor': 'bitwise_xor',
  'cbrt': 'cbrt',
  'ceil': 'ceil',
  'concatenate': 'concatenate',
  'cos': 'cos',
  'cosh': 'cosh',
  'div': 'divide',
  'eq': 'equal',
  'exp': 'exp',
  'expm1': 'expm1',
  'floor': 'floor',
  'greater': 'greater',
  'greater_equal': 'greater_equal',
  'less': 'less',
  'less_equal': 'less_equal',
  'log': 'log',
  'logical_and': 'logical_and',
  'logical_not': 'logical_not',
  'logical_or': 'logical_or',
  'logical_xor': 'logical_xor',
  'log1p': 'log1p',
  'max': 'maximum',
  'min': 'minimum',
  'mul': 'multiply',
  'ne': 'not_equal',
  'neg': 'negative',
  'nextafter': 'nextafter',
  'pow': 'float_power',
  'round': 'round',
  'select': 'where',
  'shift_left': 'left_shift',
  'shift_right_logical': 'right_shift',
  'shift_right_arithmetic': 'right_shift',
  'sign': 'sign',
  'sin': 'sin',
  'sinh': 'sinh',
  'sqrt': 'sqrt',
  'sub': 'subtract',
  'tan': 'tan',
  'tanh': 'tanh'
}
|
|
|
|
|
|
2023-04-11 13:11:41 -07:00
|
|
|
|
def check_same_dtypes(name: str, *avals: core.UnshapedArray) -> None:
  """Raise ``TypeError`` unless all ``avals`` share one canonicalized dtype.

  No check is performed when any aval has an extended dtype, or when fewer
  than two avals are given. If ``name`` has an entry in
  ``_JNP_FUNCTION_EQUIVALENTS``, the error message also suggests the
  corresponding ``jnp`` function.
  """
  has_extended = any(
      dtypes.issubdtype(aval.dtype, dtypes.extended) for aval in avals)
  if has_extended:
    return  # TODO(mattjj,frostig): do some checking, friend
  if len(avals) < 2:
    return

  first, *rest = avals
  expected = dtypes.canonicalize_dtype(first.dtype)
  all_match = all(
      dtypes.canonicalize_dtype(aval.dtype) == expected for aval in rest)
  if all_match:
    return

  msg = "lax.{} requires arguments to have the same dtypes, got {}."
  if name in _JNP_FUNCTION_EQUIVALENTS:
    equiv = _JNP_FUNCTION_EQUIVALENTS[name]
    msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)."
  got = ", ".join(str(a.dtype) for a in avals)
  raise TypeError(msg.format(name, got))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
|
|
|
|
|
"""Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints)."""
|
|
|
|
|
if not isinstance(obj, (tuple, list, np.ndarray)):
|
|
|
|
|
msg = "{} {} must be of type tuple/list/ndarray, got {}."
|
|
|
|
|
raise TypeError(msg.format(fun_name, arg_name, type(obj)))
|
|
|
|
|
# bool(obj) for an ndarray raises an error, so we check len
|
|
|
|
|
if not len(obj): # pylint: disable=g-explicit-length-test
|
|
|
|
|
return
|
2023-10-09 07:28:18 -07:00
|
|
|
|
if (config.dynamic_shapes.value and isinstance(obj, (tuple, list)) and
|
2022-10-10 18:51:04 -07:00
|
|
|
|
any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)):
|
2022-01-20 22:58:09 -08:00
|
|
|
|
return # TODO(mattjj): handle more checks in the dynamic shape case
|
2020-10-17 14:33:26 -04:00
|
|
|
|
obj_arr = np.array(obj)
|
|
|
|
|
if obj_arr.ndim != 1:
|
2022-06-09 15:03:53 -07:00
|
|
|
|
msg = "{} {} must be 1-dimensional, got {}."
|
2020-10-17 14:33:26 -04:00
|
|
|
|
raise TypeError(msg.format(obj_arr.ndim))
|
|
|
|
|
try:
|
|
|
|
|
canonicalize_shape(obj_arr)
|
|
|
|
|
except TypeError as err:
|
|
|
|
|
msg = "{} {} must have every element be an integer type, got {}."
|
|
|
|
|
raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) from err
|
|
|
|
|
lower_bound, bound_error = (
|
|
|
|
|
(1, "strictly positive") if non_zero_shape else (0, "nonnegative"))
|
2023-07-11 14:03:52 +01:00
|
|
|
|
if not all(d >= lower_bound for d in obj_arr):
|
2020-10-17 14:33:26 -04:00
|
|
|
|
msg = "{} {} must have every element be {}, got {}."
|
|
|
|
|
raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _const(example, val):
  """Build a constant holding ``val`` with the dtype of ``example``.

  When ``example`` is a Python scalar, ``val`` is first converted with
  ``dtypes.scalar_type_of(example)`` and returned as that scalar if its dtype
  already matches; in every other case the result is ``np.array(val, dtype)``.
  """
  target_dtype = _dtype(example)
  if not dtypes.is_python_scalar(example):
    return np.array(val, target_dtype)
  scalar = dtypes.scalar_type_of(example)(val)
  if target_dtype == _dtype(scalar):
    return scalar
  return np.array(scalar, target_dtype)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
# `full_like` with the fill value fixed to 0.
_zeros: Callable = partial(full_like, fill_value=0)
|
2025-01-14 08:03:08 -08:00
|
|
|
|
|
|
|
|
|
def _zero(x):
  """Return a shape-``()`` array filled with 0, built with ``full_like`` from ``x``.

  The result's sharding is the sharding of ``x``'s aval with its spec replaced
  by ``P()``.
  """
  scalar_sharding = core.get_aval(x).sharding.with_spec(P())
  return full_like(x, shape=(), fill_value=0, sharding=scalar_sharding)
|
2025-01-14 08:03:08 -08:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# `full_like` with the fill value fixed to 1.
_ones: Callable = partial(full_like, fill_value=1)
|
2024-10-25 10:34:33 -07:00
|
|
|
|
|
|
|
|
|
def _one(x):
  """Return a shape-``()`` array filled with 1, built with ``full_like`` from ``x``.

  The result's sharding is the sharding of ``x``'s aval with its spec replaced
  by ``P()``.
  """
  scalar_sharding = core.get_aval(x).sharding.with_spec(P())
  return full_like(x, shape=(), fill_value=1, sharding=scalar_sharding)
|
2024-10-25 10:34:33 -07:00
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
# `full_like` with the fill value fixed to 2.
_twos: Callable = partial(full_like, fill_value=2)
# Shape-() variant of `_twos`.
# NOTE(review): unlike `_zero` and `_one`, this does not pass an explicit
# `sharding` -- confirm that is intentional.
_two: Callable = partial(full_like, shape=(), fill_value=2)
|
|
|
|
|
|
2021-11-22 09:29:43 -08:00
|
|
|
|
# `dtypes.dtype` with canonicalization enabled. Two names are bound to
# identically-constructed partials; `_dtype` is the module-private one.
dtype: Callable = partial(dtypes.dtype, canonicalize=True)
_dtype: Callable = partial(dtypes.dtype, canonicalize=True)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
2022-09-12 09:08:13 -07:00
|
|
|
|
def _isnan(x: ArrayLike) -> Array:
  """Return ``ne(x, x)``, i.e. whether each element differs from itself."""
  differs_from_self = ne(x, x)
  return differs_from_self
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _iscomplex(x) -> bool:
  """Return whether the dtype of ``x`` is a subtype of ``np.complexfloating``."""
  x_dtype = _dtype(x)
  return dtypes.issubdtype(x_dtype, np.complexfloating)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ranges_like(*xs):
  """Yield consecutive ``range`` objects whose lengths match each of ``xs``.

  The first range starts at 0 and each later range starts where the previous
  one stopped.
  """
  offset = 0
  for seq in xs:
    stop = offset + len(seq)
    yield range(offset, stop)
    offset = stop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remaining(original, *removed_lists):
  """Return the items of ``original``, in order, absent from every ``removed_lists`` entry."""
  excluded = set().union(*removed_lists)
  return [item for item in original if item not in excluded]
|
|
|
|
|
|
|
|
|
|
|
Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
2024-10-07 13:20:24 -07:00
|
|
|
|
def canonicalize_precision(precision: PrecisionLike) -> CanonicalPrecision:
  """Turns an API precision specification into its canonical form.

  Accepted specifications and what they canonicalize to:

  * ``None``: resolved through the ``jax_default_matmul_precision`` config
    value; ``None`` is returned when that value is ``None`` as well.
  * a string found in ``_precision_strings``: a pair holding the matching
    ``Precision`` for both operands.
  * any other string: looked up by name in ``DotAlgorithmPreset``.
  * a ``Precision`` value: a pair holding that value for both operands.
  * a ``DotAlgorithm`` or ``DotAlgorithmPreset``: returned unchanged.
  * a list or tuple of two ``Precision`` values: returned unchanged.
  * a list or tuple of two strings found in ``_precision_strings``: a pair of
    the matching ``Precision`` values.

  Args:
    precision: the user-facing precision specification.

  Returns:
    ``None``, a pair of ``Precision`` values, a ``DotAlgorithm`` or a
    ``DotAlgorithmPreset``.

  Raises:
    ValueError: if ``precision`` (or the configured default) matches none of
      the accepted forms.
  """
  if precision is None:
    if config.default_matmul_precision.value is None:
      return None
    try:
      return canonicalize_precision(config.default_matmul_precision.value)
    except ValueError:
      # Re-raise with a message that points at the flag, not at the argument.
      raise ValueError(
          "jax_default_matmul_precision flag must be set to None, a value in "
          f"{list(_precision_strings)}, or the name of a lax.DotAlgorithmPreset, "
          f"but got {config.default_matmul_precision.value}"
      ) from None
  elif isinstance(precision, str):
    if precision in _precision_strings:
      return Precision(precision), Precision(precision)
    else:
      try:
        return DotAlgorithmPreset[precision]
      except KeyError:
        pass  # Not a preset name either; reported by the final raise below.
  elif isinstance(precision, Precision):
    return precision, precision
  elif isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
    return precision
  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
        all(isinstance(p, Precision) for p in precision)):
    return type_cast(tuple[Precision, Precision], precision)
  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
        all(isinstance(s, str) for s in precision)):
    s1, s2 = type_cast(tuple[str, str], precision)
    # Only per-operand precision names make sense inside a pair. A preset name
    # canonicalizes to a single DotAlgorithmPreset, which cannot be indexed, so
    # it must fall through to the ValueError below rather than be subscripted.
    if s1 in _precision_strings and s2 in _precision_strings:
      return Precision(s1), Precision(s2)
  raise ValueError(
      "Precision argument must be one of:\n"
      "- None,\n"
      f"- a string in {list(_precision_strings)},\n"
      "- a lax.Precision value,\n"
      "- a tuple of two lax.Precision values or strings,\n"
      "- a lax.DotAlgorithmPreset or the name of one of these presets, or\n"
      "- a lax.DotAlgorithm value;\n"
      f"but got {precision}.")
|
2024-09-25 06:16:22 -07:00
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
def _balanced_eq(x, z, y):
  """Returns ``select(x == z, 1, 0) / select(y == z, 2, 1)`` elementwise.

  Both comparisons go through ``_eq_meet`` and the constants are built with
  ``_ones``, ``_zeros`` and ``_twos`` applied to ``z`` (presumably arrays
  shaped and typed like ``z`` -- confirm against those helpers).
  """
  numerator = select(_eq_meet(x, z), _ones(z), _zeros(z))
  denominator = select(_eq_meet(y, z), _twos(z), _ones(z))
  return div(numerator, denominator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _eq_meet(a, b):
  """Compares ``a`` and ``b`` with ``eq`` after bringing them to one dtype.

  When the dtypes differ, the operand whose dtype equals the promoted dtype is
  converted to the other operand's dtype; otherwise ``b`` is converted to the
  dtype of ``a``.
  """
  lhs_dtype, rhs_dtype = _dtype(a), _dtype(b)
  if lhs_dtype != rhs_dtype:
    promoted = dtypes.promote_types(lhs_dtype, rhs_dtype)
    lhs_is_promoted = promoted == lhs_dtype
    if lhs_is_promoted:
      # `a` already has the promoted dtype: bring it down to the dtype of `b`.
      a = convert_element_type(a, rhs_dtype)
    else:
      b = convert_element_type(b, lhs_dtype)
  return eq(a, b)
|
|
|
|
|
|
|
|
|
|
|
2022-08-30 14:47:15 -07:00
|
|
|
|
def empty(dtype):
  """Binds the ``empty`` primitive for ``dtype`` and returns its result."""
  result = empty_p.bind(dtype=dtype)
  return result
|
2022-08-05 22:18:53 -07:00
|
|
|
|
empty_p = core.Primitive('empty')


def _empty_abstract_eval(*, dtype):
  """Abstract value of ``empty``: a rank-0 ``ShapedArray`` of ``dtype``."""
  return core.ShapedArray((), dtype)

empty_p.def_abstract_eval(_empty_abstract_eval)
|
|
|
|
|
def _empty_lower(ctx, *, dtype):
  """Lowers ``empty`` to a zero-filled constant of the physical aval.

  Returns a one-element tuple holding the constant.
  """
  # Extended dtypes are kept as they are; anything else is normalized to a
  # NumPy dtype object first.
  if not dtypes.issubdtype(dtype, dtypes.extended):
    dtype = np.dtype(dtype)
  physical = core.physical_aval(core.ShapedArray((), dtype))
  zeros = np.zeros(physical.shape, physical.dtype)
  return (mlir.ir_constant(zeros),)
|
2022-08-05 22:18:53 -07:00
|
|
|
|
# Register `_empty_lower` as the MLIR lowering rule for `empty_p`.
mlir.register_lowering(empty_p, _empty_lower)
|
2022-10-10 18:51:04 -07:00
|
|
|
|
|
|
|
|
|
|
2023-08-31 17:30:34 -07:00
|
|
|
|
# `tie` takes two operands and yields the second one unchanged.
tie_p = core.Primitive('tie')


def _tie_second(x, y):
  """Returns the second operand; used for both the impl and abstract eval."""
  return y


def _tie_lowering(ctx, x, y):
  """Lowering rule: forwards the second IR value as the only result."""
  return [y]


def _tie_jvp(primals, tangents):
  """JVP rule: re-binds the primals and passes on the last tangent."""
  return tie_p.bind(*primals), tangents[-1]


def _tie_transpose(ct, x, _):
  """Transpose rule: no cotangent for the first operand, `ct` for the second."""
  return [None, ct]

tie_p.def_impl(_tie_second)
tie_p.def_abstract_eval(_tie_second)
mlir.register_lowering(tie_p, _tie_lowering)
ad.primitive_jvps[tie_p] = _tie_jvp
ad.primitive_transposes[tie_p] = _tie_transpose
pe.def_trivial_padding(tie_p)
batching.defvectorized(tie_p)
|
|
|
|
|
|
|
|
|
|
|
2022-10-10 18:51:04 -07:00
|
|
|
|
class BIntRules:
  """Extended-dtype rules for ``core.bint``, whose elements are stored as int32."""

  # Conversions to and from the representation type are permitted.
  allow_conversion: bool = True

  @staticmethod
  def physical_element_aval(dtype) -> core.ShapedArray:
    """Returns the aval of one physical element: a rank-0 int32 array."""
    int32 = np.dtype('int32')
    return core.ShapedArray((), int32)

  @staticmethod
  def result_handler(sticky_device, aval):
    """Builds a handler that wraps a raw buffer in a ``core.DArray``."""
    def wrap_buffer(_, buf):
      # Reset the buffer's aval to a plain ShapedArray built from its own
      # shape and dtype before wrapping it.
      buf.aval = core.ShapedArray(buf.shape, buf.dtype)
      return core.DArray(aval, buf)
    return wrap_buffer

  @staticmethod
  def global_sharded_result_handler(aval, out_sharding, committed):
    """Builds a handler that wraps physical results in a ``core.DArray``.

    Raises:
      NotImplementedError: if ``out_sharding`` is not a single-device sharding.
    """
    physical_aval = core.physical_aval(aval)
    make_physical_handler = pxla.global_result_handlers[core.ShapedArray]

    if not dispatch.is_single_device_sharding(out_sharding):
      raise NotImplementedError  # TODO(mattjj)
    # With a single-device sharding the physical sharding is the same object.
    physical_handler = make_physical_handler(
        physical_aval, out_sharding, committed)

    def wrap_buffers(bufs):
      return core.DArray(aval, physical_handler(bufs))
    return wrap_buffers
|
|
|
|
|
|
2024-02-28 14:36:20 -08:00
|
|
|
|
|
2022-10-10 18:51:04 -07:00
|
|
|
|
# Attach the rules class above to `core.bint` via its `_rules` attribute.
core.bint._rules = BIntRules
|
2024-09-05 19:49:12 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def optimization_barrier(operand, /):
  """Prevents the compiler from moving operations across the barrier.

  Optimization barriers have a number of possible uses:

  * An optimization barrier ensures that all inputs are evaluated before any
    operators that depend on the barrier's outputs. This can be used to enforce
    a particular order of operations.
  * An optimization barrier prevents common subexpression elimination. This is
    used by JAX to implement rematerialization.
  * Optimization barriers prevent compiler fusions. That is, operations before
    the barrier may not be fused into the same kernel as operations after the
    barrier by the compiler.

  The underlying primitive has batching, JVP and transpose rules; each of them
  re-applies the barrier to the values it is given (batched operands, primals
  and tangents, or cotangents).

  Optimization barriers have no effect outside a compiled function.

  Args:
    operand: a pytree of JAX values.

  Returns:
    A pytree of JAX values, with the same structure and contents as ``operand``.

  Examples:
    Prevents common-subexpression elimination between the two calls to `sin`:

    >>> def f(x):
    ...   return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
    >>> jax.jit(f)(0.)
    Array(0., dtype=float32, weak_type=True)
  """
  # The primitive works on a flat list of leaves; the tree structure is put
  # back on the outputs afterwards.
  flat_args, treedef = tree_util.tree_flatten(operand)
  # TODO(yashkatariya): Enable this
  # flat_args = core.standard_insert_pbroadcast(flat_args)
  out = optimization_barrier_p.bind(*flat_args)
  return tree_util.tree_unflatten(treedef, out)
|
2024-09-05 19:49:12 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _optimization_barrier_abstract_eval(*args):
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
def _optimization_barrier_lowering_rule(ctx, *args):
  """Lowers the barrier to a single ``hlo.OptimizationBarrierOp``.

  The operands are flattened into one list of IR values for the op, and the
  op's results are regrouped to match the IR types of ``ctx.avals_in``.
  """
  result_types = map(mlir.aval_to_ir_type, ctx.avals_in)
  barrier = hlo.OptimizationBarrierOp(mlir.flatten_ir_values(args))
  return mlir.unflatten_ir_values_like_types(barrier.results, result_types)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Primitive behind `optimization_barrier`; it takes and returns a flat list of
# values, hence `multiple_results`.
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
# The impl goes through `dispatch.apply_primitive` with this primitive bound
# as its first argument.
optimization_barrier_p.def_impl(
    functools.partial(dispatch.apply_primitive, optimization_barrier_p)
)
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(
    optimization_barrier_p, _optimization_barrier_lowering_rule
)
|
2024-12-09 19:20:04 -08:00
|
|
|
|
|
|
|
|
|
def _optimization_barrier_batcher(batched_args, batch_dims, **params):
  """Batching rule: binds the barrier on the batched operands as they are.

  The batch dimensions are returned unchanged.
  """
  batched_out = optimization_barrier_p.bind(*batched_args, **params)
  return batched_out, batch_dims
batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
|
2025-03-14 22:40:41 +00:00
|
|
|
|
|
|
|
|
|
def _opt_barrier_jvp(primals, tangents):
  """JVP rule: applies a barrier to the primals and another to the tangents.

  Tangents go through ``ad.instantiate_zeros`` before being passed to the
  barrier.
  """
  tangents = [ad.instantiate_zeros(tangent) for tangent in tangents]
  primals_out = optimization_barrier(primals)
  tangents_out = optimization_barrier(tangents)
  return primals_out, tangents_out
ad.primitive_jvps[optimization_barrier_p] = _opt_barrier_jvp
|
|
|
|
|
|
|
|
|
|
def _opt_barrier_transpose(cts, *primals):
  """Transpose rule: applies a barrier to the cotangents.

  Cotangents go through ``ad.instantiate_zeros`` first; ``primals`` is unused.
  """
  dense_cts = [ad.instantiate_zeros(cotangent) for cotangent in cts]
  return optimization_barrier(dense_cts)
ad.primitive_transposes[optimization_barrier_p] = _opt_barrier_transpose
|