# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import builtins
from collections.abc import Callable, Sequence
import enum
import functools
from functools import partial
import itertools
import math
import operator
from typing import Any, TypeVar, Union, cast as type_cast, overload
import warnings

import numpy as np

from jax import tree_util
from jax.sharding import Sharding
from jax.tree_util import tree_map

from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
                           raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
  _input_dtype, dtype_to_string, standard_abstract_eval,
  standard_multi_result_abstract_eval, standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding_impls import PmapSharding, NamedSharding, PartitionSpec
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
                           split_list, NumpyComplexWarning)

xb = xla_bridge
xc = xla_client
xops = xla_client.ops
xe = xla_client._xla

_max = builtins.max
_min = builtins.min
_reduce = functools.reduce

T = TypeVar("T")

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
|
|
|
|
def _matrix_transpose(x: Array) -> Array:
|
|
assert x.ndim >= 2
|
|
return transpose(x, [*range(x.ndim - 2), x.ndim - 1, x.ndim - 2])
|
|
|
|
def _clip_int_to_valid_range(val: DimSize, dtype, where: str) -> int:
|
|
info = np.iinfo(dtype)
|
|
val = core.concrete_dim_or_error(val, where)
|
|
return core.max_dim(info.min, core.min_dim(val, info.max))
|
|
|
|
def _validate_shapes(shapes: Sequence[Shape]):
|
|
def _check_static_shape(shape: Shape):
|
|
checked = canonicalize_shape(shape)
|
|
if not all(idx >= 0 for idx in checked):
|
|
msg = f"Only non-negative indices are allowed when broadcasting" \
|
|
f" static shapes, but got shape {shape!r}."
|
|
raise TypeError(msg)
|
|
|
|
assert shapes
|
|
if config.dynamic_shapes.value:
|
|
# pass dynamic shapes through unchecked
|
|
return
|
|
else:
|
|
map(_check_static_shape, shapes)
|
|
|
|
def _try_broadcast_shapes(
|
|
shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...] | None:
|
|
if len(shapes) == 1: return shapes[0]
|
|
ranks = {len(shape) for shape in shapes}
|
|
if len(ranks) > 1: return None # must have consistent rank
|
|
rank = ranks.pop()
|
|
if not rank: return () # scalar case
|
|
result_shape = []
|
|
for ds in unsafe_zip(*shapes):
|
|
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
|
|
# if all axes are identical objects, the resulting size is the object
|
|
result_shape.append(ds[0])
|
|
else:
|
|
# if all dims are equal (or 1), the result is the non-1 size (or 1)
|
|
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
|
|
if not non_1s:
|
|
result_shape.append(1)
|
|
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
|
|
result_shape.append(non_1s[0])
|
|
else:
|
|
return None
|
|
return tuple(result_shape)
|
|
|
|
def asarray(x: ArrayLike) -> Array:
|
|
"""Lightweight conversion of ArrayLike input to Array output."""
|
|
if isinstance(x, Array):
|
|
return x
|
|
if isinstance(x, (np.ndarray, np.generic, bool, int, float, builtins.complex)):
|
|
return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x))
|
|
else:
|
|
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
|
|
|
|
@overload
|
|
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...
|
|
|
|
@overload
|
|
def broadcast_shapes(*shapes: tuple[int | core.Tracer, ...]
|
|
) -> tuple[int | core.Tracer, ...]: ...
|
|
|
|
def broadcast_shapes(*shapes):
|
|
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
|
|
# NOTE: We have both cached and uncached versions to handle Tracers in shapes.
|
|
try:
|
|
return _broadcast_shapes_cached(*shapes)
|
|
except:
|
|
return _broadcast_shapes_uncached(*shapes)
|
|
|
|
@cache()
|
|
def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
|
|
return _broadcast_shapes_uncached(*shapes)
|
|
|
|
def _broadcast_shapes_uncached(*shapes):
|
|
_validate_shapes(shapes)
|
|
fst, *rst = shapes
|
|
if not rst: return fst
|
|
|
|
# First check if we need only rank promotion (and not singleton-broadcasting).
|
|
try: return _reduce(_broadcast_ranks, rst, fst)
|
|
except ValueError: pass
|
|
|
|
# Next try singleton-broadcasting, padding out ranks using singletons.
|
|
ndim = _max(len(shape) for shape in shapes)
|
|
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
|
|
result_shape = _try_broadcast_shapes(shape_list)
|
|
if result_shape is None:
|
|
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
|
|
return result_shape
|
|
|
|
def _broadcast_ranks(s1, s2):
|
|
if len(s1) > len(s2):
|
|
s1, s2 = s2, s1
|
|
assert len(s1) <= len(s2)
|
|
s1_ = s2[len(s2) - len(s1):]
|
|
if core.definitely_equal_shape(s1_, s1): return s2
|
|
else: raise ValueError
|
|
|
|
def _identity(x): return x
|
|
|
|
def _extract_tracers_dyn_shape(
|
|
shape: Sequence[int | core.Tracer]
|
|
) -> tuple[list[core.Tracer], list[int | None]]:
|
|
# Given a sequence representing a shape, pull out Tracers, replacing with None
|
|
if config.dynamic_shapes.value:
|
|
# We must gate this behavior under a flag because otherwise the errors
|
|
# raised are different (and have worse source provenance information).
|
|
dyn_shape = [d for d in shape if isinstance(d, core.Tracer)]
|
|
static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
|
|
return dyn_shape, static_shape
|
|
else:
|
|
return [], list(shape) # type: ignore
|
|
|
|
def _merge_dyn_shape(
|
|
static_shape: Sequence[int | None],
|
|
dyn_shape: Sequence[Any],
|
|
) -> tuple[int | mlir.Value | core.Tracer, ...]:
|
|
# Replace Nones in static_shape with elements of dyn_shape, in order
|
|
dyn_shape_it = iter(dyn_shape)
|
|
shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
|
|
assert next(dyn_shape_it, None) is None
|
|
return shape
|
|
|
|
def _dyn_shape_staging_rule(trace, prim, out_aval, *args, **params):
|
|
source_info = source_info_util.current()
|
|
out_tracer = pe.DynamicJaxprTracer(trace, out_aval, source_info)
|
|
eqn = pe.new_jaxpr_eqn([trace.getvar(x) for x in args],
|
|
[trace.makevar(out_tracer)],
|
|
prim, params, core.no_effects, source_info)
|
|
trace.frame.add_eqn(eqn)
|
|
return out_tracer
|
|
|
|
|
|
### traceables
|
|
|
|
def neg(x: ArrayLike) -> Array:
|
|
r"""Elementwise negation: :math:`-x`."""
|
|
return neg_p.bind(x)
|
|
|
|
def sign(x: ArrayLike) -> Array:
|
|
r"""Elementwise sign.
|
|
|
|
For floating-point inputs, returns
|
|
:math:`\mathrm{sign}(x) = \begin{cases}
|
|
-1 & x < 0\\
|
|
-0 & x = -0\\
|
|
\mathit{NaN} & x = \mathit{NaN}\\
|
|
+0 & x = +0\\
|
|
1 & x > 0
|
|
\end{cases}`
|
|
|
|
For signed integer inputs, returns
|
|
:math:`\mathrm{sign}(x) = \begin{cases}
|
|
-1 & x < 0\\
|
|
0 & x = 0\\
|
|
1 & x > 0
|
|
\end{cases}`
|
|
|
|
For complex inputs, returns the complex phase, i.e.
|
|
:math:`\mathrm{sign}(x) = \frac{x}{|x|}`.
|
|
"""
|
|
return sign_p.bind(x)
|
|
|
|
def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
|
|
r"""Returns the next representable value after `x1` in the direction of `x2`.
|
|
|
|
Note that in some environments flush-denormal-to-zero semantics is used.
|
|
This means that, around zero, this function returns strictly non-zero
|
|
  values which appear as zero in any operation. Consider this example::
|
|
|
|
>>> jnp.nextafter(0, 1) # denormal numbers are representable
|
|
Array(1.e-45, dtype=float32, weak_type=True)
|
|
>>> jnp.nextafter(0, 1) * 1 # but are flushed to zero
|
|
Array(0., dtype=float32, weak_type=True)
|
|
|
|
For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
|
|
"""
|
|
return nextafter_p.bind(x1, x2)
|
|
|
|
def floor(x: ArrayLike) -> Array:
|
|
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
|
|
return floor_p.bind(x)
|
|
|
|
def ceil(x: ArrayLike) -> Array:
|
|
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`."""
|
|
return ceil_p.bind(x)
|
|
|
|
class RoundingMethod(enum.IntEnum):
|
|
AWAY_FROM_ZERO = 0
|
|
TO_NEAREST_EVEN = 1
|
|
|
|
def round(x: ArrayLike,
|
|
rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
|
|
) -> Array:
|
|
r"""Elementwise round.
|
|
|
|
Rounds values to the nearest integer.
|
|
|
|
Args:
|
|
x: an array or scalar value to round.
|
|
rounding_method: the method to use when rounding halfway values
|
|
(e.g., `0.5`). See ``lax.RoundingMethod`` for the list of possible
|
|
values.
|
|
|
|
Returns:
|
|
An array containing the elementwise rounding of x.
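
  Examples:
    Illustrative rounding of halfway values under both methods (``jnp``
    denotes ``jax.numpy``):

    >>> x = jnp.array([0.5, 1.5, 2.5])
    >>> jax.lax.round(x)
    Array([1., 2., 3.], dtype=float32)
    >>> jax.lax.round(x, jax.lax.RoundingMethod.TO_NEAREST_EVEN)
    Array([0., 2., 2.], dtype=float32)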
|
|
"""
|
|
rounding_method = RoundingMethod(rounding_method)
|
|
return round_p.bind(x, rounding_method=rounding_method)
|
|
|
|
def is_finite(x: ArrayLike) -> Array:
|
|
r"""Elementwise :math:`\mathrm{isfinite}`.
|
|
|
|
For each element x returns `True` if and only if x is not :math:`\pm\infty` or
|
|
:math:`\mathit{NaN}`.
|
|
"""
|
|
return is_finite_p.bind(x)
|
|
|
|
def exp(x: ArrayLike) -> Array:
|
|
r"""Elementwise exponential: :math:`e^x`."""
|
|
return exp_p.bind(x)
|
|
|
|
def exp2(x: ArrayLike) -> Array:
|
|
r"""Elementwise base-2 exponential: :math:`2^x`."""
|
|
return exp2_p.bind(x)
|
|
|
|
def expm1(x: ArrayLike) -> Array:
|
|
r"""Elementwise :math:`e^{x} - 1`."""
|
|
return expm1_p.bind(x)
|
|
|
|
def log(x: ArrayLike) -> Array:
|
|
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`."""
|
|
return log_p.bind(x)
|
|
|
|
def log1p(x: ArrayLike) -> Array:
|
|
r"""Elementwise :math:`\mathrm{log}(1 + x)`."""
|
|
return log1p_p.bind(x)
|
|
|
|
def tanh(x: ArrayLike) -> Array:
|
|
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`."""
|
|
return tanh_p.bind(x)
|
|
|
|
def logistic(x: ArrayLike) -> Array:
|
|
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
|
|
return logistic_p.bind(x)
|
|
|
|
def sin(x: ArrayLike) -> Array:
|
|
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
|
|
return sin_p.bind(x)
|
|
|
|
def cos(x: ArrayLike) -> Array:
|
|
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
|
|
return cos_p.bind(x)
|
|
|
|
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise arc tangent of two variables:
|
|
:math:`\mathrm{atan}({x \over y})`."""
|
|
return atan2_p.bind(x, y)
|
|
|
|
def real(x: ArrayLike) -> Array:
|
|
r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.
|
|
|
|
Returns the real part of a complex number.
|
|
"""
|
|
return real_p.bind(x)
|
|
|
|
def imag(x: ArrayLike) -> Array:
|
|
r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.
|
|
|
|
Returns the imaginary part of a complex number.
|
|
"""
|
|
return imag_p.bind(x)
|
|
|
|
def complex(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise make complex number: :math:`x + jy`.
|
|
|
|
Builds a complex number from real and imaginary parts.
|
|
"""
|
|
return complex_p.bind(x, y)
|
|
|
|
def conj(x: ArrayLike) -> Array:
|
|
r"""Elementwise complex conjugate function: :math:`\overline{x}`."""
|
|
# TODO(mattjj): remove input_dtype, not needed anymore
|
|
return conj_p.bind(x, input_dtype=_dtype(x))
|
|
|
|
def abs(x: ArrayLike) -> Array:
|
|
r"""Elementwise absolute value: :math:`|x|`."""
|
|
return abs_p.bind(x)
|
|
|
|
def pow(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise power: :math:`x^y`."""
|
|
return pow_p.bind(x, y)
|
|
|
|
def integer_pow(x: ArrayLike, y: int) -> Array:
|
|
r"""Elementwise power: :math:`x^y`, where :math:`y` is a fixed integer."""
|
|
return integer_pow_p.bind(x, y=y)
|
|
|
|
def sqrt(x: ArrayLike) -> Array:
|
|
r"""Elementwise square root: :math:`\sqrt{x}`."""
|
|
return sqrt_p.bind(x)
|
|
|
|
def rsqrt(x: ArrayLike) -> Array:
|
|
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
|
|
return rsqrt_p.bind(x)
|
|
|
|
def cbrt(x: ArrayLike) -> Array:
|
|
r"""Elementwise cube root: :math:`\sqrt[3]{x}`."""
|
|
return cbrt_p.bind(x)
|
|
|
|
def bitwise_not(x: ArrayLike) -> Array:
|
|
r"""Elementwise NOT: :math:`\neg x`."""
|
|
return not_p.bind(x)
|
|
|
|
def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise AND: :math:`x \wedge y`."""
|
|
return and_p.bind(x, y)
|
|
|
|
def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise OR: :math:`x \vee y`."""
|
|
return or_p.bind(x, y)
|
|
|
|
def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise exclusive OR: :math:`x \oplus y`."""
|
|
return xor_p.bind(x, y)
|
|
|
|
def population_count(x: ArrayLike) -> Array:
|
|
r"""Elementwise popcount, count the number of set bits in each element."""
|
|
return population_count_p.bind(x)
|
|
|
|
def clz(x: ArrayLike) -> Array:
|
|
r"""Elementwise count-leading-zeros."""
|
|
return clz_p.bind(x)
|
|
|
|
def add(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise addition: :math:`x + y`."""
|
|
return add_p.bind(x, y)
|
|
|
|
def sub(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise subtraction: :math:`x - y`."""
|
|
return sub_p.bind(x, y)
|
|
|
|
def mul(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise multiplication: :math:`x \times y`."""
|
|
return mul_p.bind(x, y)
|
|
|
|
def div(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise division: :math:`x \over y`.
|
|
|
|
Integer division overflow
|
|
(division by zero or signed division of INT_SMIN with -1)
|
|
  produces an implementation-defined value.
|
|
"""
|
|
return div_p.bind(x, y)
|
|
|
|
def rem(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise remainder: :math:`x \bmod y`.
|
|
|
|
The sign of the result is taken from the dividend,
|
|
and the absolute value of the result is always
|
|
less than the divisor's absolute value.
|
|
|
|
Integer division overflow
|
|
(remainder by zero or remainder of INT_SMIN with -1)
|
|
  produces an implementation-defined value.
|
|
"""
|
|
return rem_p.bind(x, y)
|
|
|
|
def max(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
|
|
|
|
For complex numbers, uses a lexicographic comparison on the
|
|
`(real, imaginary)` pairs."""
|
|
return max_p.bind(x, y)
|
|
|
|
def min(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
|
|
|
|
For complex numbers, uses a lexicographic comparison on the
|
|
`(real, imaginary)` pairs."""
|
|
return min_p.bind(x, y)
|
|
|
|
def shift_left(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise left shift: :math:`x \ll y`."""
|
|
return shift_left_p.bind(x, y)
|
|
|
|
def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise arithmetic right shift: :math:`x \gg y`."""
|
|
return shift_right_arithmetic_p.bind(x, y)
|
|
|
|
def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise logical right shift: :math:`x \gg y`."""
|
|
return shift_right_logical_p.bind(x, y)
|
|
|
|
def eq(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise equals: :math:`x = y`."""
|
|
return eq_p.bind(x, y)
|
|
|
|
def ne(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise not-equals: :math:`x \neq y`."""
|
|
return ne_p.bind(x, y)
|
|
|
|
def ge(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise greater-than-or-equals: :math:`x \geq y`."""
|
|
return ge_p.bind(x, y)
|
|
|
|
def gt(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise greater-than: :math:`x > y`."""
|
|
return gt_p.bind(x, y)
|
|
|
|
def le(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise less-than-or-equals: :math:`x \leq y`."""
|
|
return le_p.bind(x, y)
|
|
|
|
def lt(x: ArrayLike, y: ArrayLike) -> Array:
|
|
r"""Elementwise less-than: :math:`x < y`."""
|
|
return lt_p.bind(x, y)
|
|
|
|
def convert_element_type(operand: ArrayLike,
|
|
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
|
|
"""Elementwise cast.
|
|
|
|
Wraps XLA's `ConvertElementType
|
|
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
|
|
operator, which performs an elementwise conversion from one type to another.
|
|
Similar to a C++ `static_cast`.
|
|
|
|
Args:
|
|
operand: an array or scalar value to be cast.
|
|
new_dtype: a NumPy dtype representing the target type.
|
|
|
|
Returns:
|
|
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
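
  Examples:
    A small illustrative cast (``jnp`` denotes ``jax.numpy``):

    >>> x = jnp.array([1, 2, 3])
    >>> jax.lax.convert_element_type(x, jnp.float32)
    Array([1., 2., 3.], dtype=float32)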
|
|
"""
|
|
return _convert_element_type(operand, new_dtype, weak_type=False)
|
|
|
|
def _convert_element_type(
|
|
operand: ArrayLike,
|
|
new_dtype: DTypeLike | dtypes.ExtendedDType | None = None,
|
|
weak_type: bool = False,
|
|
sharding: Sharding | None = None):
|
|
if hasattr(operand, '__jax_array__'):
|
|
operand = operand.__jax_array__()
|
|
|
|
if (dtypes.issubdtype(new_dtype, dtypes.extended) or
|
|
dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)):
|
|
return convert_element_type_p.bind(
|
|
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
|
|
sharding=sharding)
|
|
|
|
new_dtype = type_cast(DTypeLike | None, new_dtype)
|
|
|
|
# Don't canonicalize old_dtype because x64 context might cause
|
|
# un-canonicalized operands to be passed in.
|
|
old_dtype = dtypes.dtype(operand, canonicalize=False)
|
|
old_weak_type = dtypes.is_weakly_typed(operand)
|
|
if new_dtype is None:
|
|
new_dtype = old_dtype
|
|
else:
|
|
new_dtype = np.dtype(new_dtype)
|
|
new_dtype = dtypes.dtype(new_dtype, canonicalize=True)
|
|
|
|
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
|
|
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
|
msg = "Casting complex values to real discards the imaginary part"
|
|
warnings.warn(msg, NumpyComplexWarning, stacklevel=2)
|
|
|
|
# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
|
|
# not be an error since the target dtype fits the value. Handle this case by
|
|
# converting to a NumPy array before calling bind. Without this step, we'd
|
|
# first canonicalize the input to a value of dtype int32 or int64, leading to
|
|
# an overflow error.
|
|
if type(operand) is int:
|
|
operand = np.asarray(operand).astype(new_dtype)
|
|
old_weak_type = False
|
|
|
|
if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
|
|
isinstance(operand, Array) and
|
|
not (isinstance(operand, core.Tracer) and
|
|
isinstance(core.get_aval(operand), core.ConcreteArray)) and
|
|
(sharding is None or getattr(operand, 'sharding', None) == sharding)):
|
|
return operand
|
|
else:
|
|
return convert_element_type_p.bind(
|
|
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
|
|
sharding=sharding)
|
|
|
|
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
|
"""Elementwise bitcast.
|
|
|
|
Wraps XLA's `BitcastConvertType
|
|
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
|
|
operator, which performs a bit cast from one type to another.
|
|
|
|
The output shape depends on the size of the input and output dtypes with
|
|
the following logic::
|
|
|
|
if new_dtype.itemsize == operand.dtype.itemsize:
|
|
output_shape = operand.shape
|
|
if new_dtype.itemsize < operand.dtype.itemsize:
|
|
output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize)
|
|
if new_dtype.itemsize > operand.dtype.itemsize:
|
|
assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize
|
|
output_shape = operand.shape[:-1]
|
|
|
|
Args:
|
|
operand: an array or scalar value to be cast
|
|
new_dtype: the new type. Should be a NumPy type.
|
|
|
|
Returns:
|
|
An array of shape `output_shape` (see above) and type `new_dtype`,
|
|
constructed from the same bits as operand.
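
  Examples:
    Illustrative output shapes for a same-size and a smaller target dtype,
    following the shape logic above (``jnp`` denotes ``jax.numpy``):

    >>> x = jnp.zeros((4,), dtype=jnp.float32)
    >>> jax.lax.bitcast_convert_type(x, jnp.int32).shape
    (4,)
    >>> jax.lax.bitcast_convert_type(x, jnp.uint8).shape
    (4, 4)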
|
|
"""
|
|
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
|
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
|
|
|
def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array:
|
|
r"""Elementwise clamp.
|
|
|
|
Returns :math:`\mathrm{clamp}(x) = \begin{cases}
|
|
\mathit{min} & \text{if } x < \mathit{min},\\
|
|
\mathit{max} & \text{if } x > \mathit{max},\\
|
|
x & \text{otherwise}
|
|
\end{cases}`.
|
|
"""
|
|
return clamp_p.bind(min, x, max)
|
|
|
|
def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
|
|
"""Concatenates a sequence of arrays along `dimension`.
|
|
|
|
Wraps XLA's `Concatenate
|
|
<https://www.tensorflow.org/xla/operation_semantics#concatenate>`_
|
|
operator.
|
|
|
|
Args:
|
|
operands: a sequence of arrays to concatenate. The arrays must have equal
|
|
shapes, except in the `dimension` axis.
|
|
dimension: the dimension along which to concatenate the arrays.
|
|
|
|
Returns:
|
|
An array containing the concatenation.
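
  Examples:
    A small illustrative example (``jnp`` denotes ``jax.numpy``):

    >>> x = jnp.array([1, 2])
    >>> y = jnp.array([3, 4, 5])
    >>> jax.lax.concatenate([x, y], dimension=0)
    Array([1, 2, 3, 4, 5], dtype=int32)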
|
|
"""
|
|
if len(operands) == 0:
|
|
raise ValueError("concatenate requires a non-empty sequences of arrays")
|
|
if len(operands) == 1:
|
|
op, = operands
|
|
if isinstance(op, Array):
|
|
return op
|
|
return concatenate_p.bind(*operands, dimension=dimension)
|
|
|
|
|
|
_precision_strings: dict[Any, Precision] = {}
|
|
|
|
class Precision(enum.Enum):
|
|
"""Precision enum for lax matrix multiply related functions.
|
|
|
|
The device-dependent `precision` argument to JAX functions generally
|
|
controls the tradeoff between speed and accuracy for array computations on
|
|
accelerator backends, (i.e. TPU and GPU). Has no impact on CPU backends.
|
|
This only has an effect on float32 computations, and does not affect the
|
|
input/output datatypes. Members are:
|
|
|
|
DEFAULT:
|
|
Fastest mode, but least accurate. On TPU: performs float32 computations in
|
|
bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100
|
|
GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases:
|
|
``'default'``, ``'fastest'``.
|
|
HIGH:
|
|
Slower but more accurate. On TPU: performs float32 computations in 3
|
|
bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise
|
|
float32. Aliases: ``'high'``..
|
|
HIGHEST:
|
|
Slowest but most accurate. On TPU: performs float32 computations in 6
|
|
bfloat16. Aliases: ``'highest'``. On GPU: uses float32.
|
|
"""
|
|
|
|
DEFAULT = 0
|
|
HIGH = 1
|
|
HIGHEST = 2
|
|
|
|
@classmethod
|
|
def _missing_(cls, value: object) -> Precision | None:
|
|
return _precision_strings.get(value)
|
|
|
|
def __repr__(self) -> str:
|
|
return f'{self.__class__.__name__}.{self.name}'
|
|
|
|
def __str__(self) -> str:
|
|
return self.name
|
|
|
|
|
|
_precision_strings['highest'] = Precision.HIGHEST
|
|
_precision_strings['float32'] = Precision.HIGHEST
|
|
_precision_strings['high'] = Precision.HIGH
|
|
_precision_strings['bfloat16_3x'] = Precision.HIGH
|
|
_precision_strings['tensorfloat32'] = Precision.HIGH
|
|
_precision_strings['default'] = Precision.DEFAULT
|
|
_precision_strings['bfloat16'] = Precision.DEFAULT
|
|
_precision_strings['fastest'] = Precision.DEFAULT
|
|
_precision_strings[None] = Precision.DEFAULT
|
|
|
|
|
|
PrecisionLike = Union[
|
|
str,
|
|
Precision,
|
|
tuple[str, str],
|
|
tuple[Precision, Precision],
|
|
None,
|
|
]
|
|
|
|
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
|
|
preferred_element_type: DTypeLike | None = None) -> Array:
|
|
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
|
|
|
|
Wraps XLA's `Dot
|
|
<https://www.tensorflow.org/xla/operation_semantics#dot>`_
|
|
operator.
|
|
|
|
For more general contraction, see the :func:`jax.lax.dot_general` operator.
|
|
|
|
Args:
|
|
lhs: an array of dimension 1 or 2.
|
|
rhs: an array of dimension 1 or 2.
|
|
precision: Optional. Either ``None``, which means the default precision for
|
|
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
|
|
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
|
|
      :class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
|
|
preferred_element_type: Optional. Either ``None``, which means the default
|
|
accumulation type for the input types, or a datatype, indicating to
|
|
accumulate results to and return a result with that datatype.
|
|
|
|
Returns:
|
|
An array containing the product.
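
  Examples:
    Illustrative matrix/vector and matrix/matrix products (``jnp`` denotes
    ``jax.numpy``):

    >>> A = jnp.ones((3, 4))
    >>> b = jnp.ones((4,))
    >>> jax.lax.dot(A, b).shape
    (3,)
    >>> jax.lax.dot(A, jnp.ones((4, 2))).shape
    (3, 2)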
|
|
"""
|
|
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]):
|
|
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
|
|
precision=precision,
|
|
preferred_element_type=preferred_element_type)
|
|
else:
|
|
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
|
|
lhs.shape, rhs.shape))
|
|
|
|
|
|
DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
|
|
tuple[Sequence[int], Sequence[int]]]
|
|
|
|
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
|
|
precision: PrecisionLike = None,
|
|
preferred_element_type: DTypeLike | None = None) -> Array:
|
|
"""General dot product/contraction operator.
|
|
|
|
Wraps XLA's `DotGeneral
|
|
<https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
|
|
operator.
|
|
|
|
The semantics of ``dot_general`` are complicated, but most users should not have to
|
|
use it directly. Instead, you can use higher-level functions like :func:`jax.numpy.dot`,
|
|
:func:`jax.numpy.matmul`, :func:`jax.numpy.tensordot`, :func:`jax.numpy.einsum`,
|
|
and others which will construct appropriate calls to ``dot_general`` under the hood.
|
|
If you really want to understand ``dot_general`` itself, we recommend reading XLA's
|
|
`DotGeneral <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
|
|
operator documentation.
|
|
|
|
Args:
|
|
lhs: an array
|
|
rhs: an array
|
|
dimension_numbers: a tuple of tuples of sequences of ints of the form
|
|
``((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))``
|
|
precision: Optional. Either ``None``, which means the default precision for
|
|
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
|
|
``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
|
|
      :class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
|
|
preferred_element_type: Optional. Either ``None``, which means the default
|
|
accumulation type for the input types, or a datatype, indicating to
|
|
accumulate results to and return a result with that datatype.
|
|
|
|
Returns:
|
|
An array whose first dimensions are the (shared) batch dimensions, followed by
|
|
the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
|
|
non-contracting/non-batch dimensions.
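
  Examples:
    An illustrative batched matrix multiply, contracting the last dimension of
    ``lhs`` with the middle dimension of ``rhs`` and batching over the leading
    dimension (``jnp`` denotes ``jax.numpy``):

    >>> lhs = jnp.ones((5, 3, 4))
    >>> rhs = jnp.ones((5, 4, 2))
    >>> dimension_numbers = (((2,), (1,)), ((0,), (0,)))
    >>> jax.lax.dot_general(lhs, rhs, dimension_numbers).shape
    (5, 3, 2)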
|
|
"""
|
|
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
|
cdims = (api_util._ensure_index_tuple(lhs_contract),
|
|
api_util._ensure_index_tuple(rhs_contract))
|
|
bdims = (api_util._ensure_index_tuple(lhs_batch),
|
|
api_util._ensure_index_tuple(rhs_batch))
|
|
preferred_element_type = (
|
|
None if preferred_element_type is None else
|
|
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
|
|
return dot_general_p.bind(lhs, rhs,
|
|
dimension_numbers=(cdims, bdims),
|
|
precision=canonicalize_precision(precision),
|
|
preferred_element_type=preferred_element_type)
|
|
|
|
|
|
def ragged_dot(
|
|
lhs: Array,
|
|
rhs: Array,
|
|
group_sizes: Array,
|
|
precision: PrecisionLike = None,
|
|
preferred_element_type: DTypeLike | None = None,
|
|
group_offset: Array | None = None,
|
|
) -> Array:
|
|
"""Ragged matrix multiplication.
|
|
|
|
  Args:
    lhs: (m, k) shaped array.
    rhs: (g, k, n) shaped array.
    group_sizes: (g,) shaped array with integer element type, where g denotes
      the number of groups. The ith element indicates the size of the ith group.
    precision: Optional. Consistent with the precision argument of
      :func:`jax.lax.dot`.
    preferred_element_type: Optional. Consistent with the
      preferred_element_type argument of :func:`jax.lax.dot`.
    group_offset: Optional. (1,) shaped array that indicates the group in
      group_sizes to start computing from. If not specified, defaults to [0].

  Returns:
    (m, n) shaped array with preferred_element_type element type.
|
|
"""
|
|
return ragged_dot_p.bind(lhs, rhs, group_sizes,
|
|
precision=canonicalize_precision(precision),
|
|
preferred_element_type=preferred_element_type, group_offset=group_offset)
|
|
|
|
|
|
def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array:
|
|
"""Broadcasts an array, adding new leading dimensions
|
|
|
|
Args:
|
|
operand: an array
|
|
sizes: a sequence of integers, giving the sizes of new leading dimensions
|
|
to add to the front of the array.
|
|
|
|
Returns:
|
|
An array containing the result.
|
|
|
|
See Also:
|
|
jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
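
  Examples:
    Adding two new leading dimensions to a rank-1 array (illustrative; ``jnp``
    denotes ``jax.numpy``):

    >>> x = jnp.arange(3)
    >>> jax.lax.broadcast(x, (2, 4)).shape
    (2, 4, 3)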
|
|
"""
|
|
if len(sizes) == 0:
|
|
return asarray(operand)
|
|
dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand)))
|
|
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
|
|
|
|
def broadcast_in_dim(operand: ArrayLike, shape: Shape,
|
|
broadcast_dimensions: Sequence[int]) -> Array:
|
|
"""Wraps XLA's `BroadcastInDim
|
|
<https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
|
|
operator.
|
|
|
|
Args:
|
|
operand: an array
|
|
shape: the shape of the target array
|
|
    broadcast_dimensions: which dimension in the target shape each dimension
|
|
of the operand shape corresponds to. That is, dimension i of the operand
|
|
becomes dimension broadcast_dimensions[i] of the result.
|
|
|
|
Returns:
|
|
An array containing the result.
|
|
|
|
See Also:
|
|
jax.lax.broadcast : simpler interface to add new leading dimensions.
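
  Examples:
    Broadcasting a shape ``(3,)`` array to shape ``(2, 3)``, mapping its only
    dimension to output dimension 1 (illustrative; ``jnp`` denotes
    ``jax.numpy``):

    >>> x = jnp.arange(3)
    >>> jax.lax.broadcast_in_dim(x, (2, 3), (1,)).shape
    (2, 3)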
|
|
"""
|
|
if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array):
|
|
return operand
|
|
if config.dynamic_shapes.value:
|
|
# We must gate this behavior under a flag because otherwise the errors
|
|
# raised are different (and have worse source provenance information).
|
|
dyn_shape, static_shape = _extract_tracers_dyn_shape(shape)
|
|
else:
|
|
dyn_shape, static_shape = [], shape # type: ignore
|
|
return broadcast_in_dim_p.bind(
|
|
operand, *dyn_shape, shape=tuple(static_shape),
|
|
broadcast_dimensions=tuple(broadcast_dimensions))
|
|
|
|
def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
|
|
"""Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
|
|
ndim = np.ndim(x)
|
|
if ndim == rank:
|
|
return asarray(x)
|
|
return broadcast(x, (1,) * (rank - ndim))
|
|
|
|
def reshape(operand: ArrayLike, new_sizes: Shape,
|
|
dimensions: Sequence[int] | None = None) -> Array:
|
|
"""Wraps XLA's `Reshape
|
|
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
|
|
operator.
|
|
|
|
For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
|
|
``lax.expand_dims``. These preserve information about axis identity that may
|
|
be useful for advanced transformation rules.
|
|
|
|
Args:
|
|
operand: array to be reshaped.
|
|
new_sizes: sequence of integers specifying the resulting shape. The size
|
|
of the final array must match the size of the input.
|
|
dimensions: optional sequence of integers specifying the permutation order of
|
|
the input shape. If specified, the length must match ``operand.shape``.
|
|
|
|
Returns:
|
|
out: reshaped array.
|
|
|
|
Examples:
|
|
Simple reshaping from one to two dimensions:
|
|
|
|
>>> x = jnp.arange(6)
|
|
>>> y = reshape(x, (2, 3))
|
|
>>> y
|
|
Array([[0, 1, 2],
|
|
[3, 4, 5]], dtype=int32)
|
|
|
|
Reshaping back to one dimension:
|
|
|
|
>>> reshape(y, (6,))
|
|
Array([0, 1, 2, 3, 4, 5], dtype=int32)
|
|
|
|
Reshaping to one dimension with permutation of dimensions:
|
|
|
|
>>> reshape(y, (6,), (1, 0))
|
|
Array([0, 3, 1, 4, 2, 5], dtype=int32)
|
|
"""
|
|
new_sizes = canonicalize_shape(new_sizes) # TODO
|
|
new_sizes = tuple(new_sizes)
|
|
same_shape = core.definitely_equal_shape(np.shape(operand), new_sizes)
|
|
if dimensions is None:
|
|
same_dims = True
|
|
dims = None
|
|
else:
|
|
dims = api_util._ensure_index_tuple(dimensions)
|
|
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
|
|
if np.shape(operand) and same_shape and same_dims and isinstance(operand, Array):
|
|
return operand
|
|
else:
|
|
dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
|
|
|
|
return reshape_p.bind(
|
|
operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
|
|
dimensions=None if dims is None or same_dims else dims)
|
|
|
|
def pad(operand: ArrayLike, padding_value: ArrayLike,
|
|
padding_config: Sequence[tuple[int, int, int]]) -> Array:
|
|
"""Applies low, high, and/or interior padding to an array.
|
|
|
|
Wraps XLA's `Pad
|
|
<https://www.tensorflow.org/xla/operation_semantics#pad>`_
|
|
operator.
|
|
|
|
Args:
|
|
operand: an array to be padded.
|
|
padding_value: the value to be inserted as padding. Must have the same dtype
|
|
as ``operand``.
|
|
padding_config: a sequence of ``(low, high, interior)`` tuples of integers,
|
|
giving the amount of low, high, and interior (dilation) padding to insert
|
|
in each dimension.
|
|
|
|
Returns:
|
|
The ``operand`` array with padding value ``padding_value`` inserted in each
|
|
dimension according to the ``padding_config``.
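
  Examples:
    Illustrative low/high padding of the first dimension and interior padding
    of the second (``jnp`` denotes ``jax.numpy``):

    >>> x = jnp.ones((3, 4))
    >>> jax.lax.pad(x, 0.0, [(1, 2, 0), (0, 0, 1)]).shape
    (6, 7)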
|
|
"""
|
|
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
|
|
|
|
def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array:
|
|
"""Wraps XLA's `Rev
|
|
<https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_
|
|
operator.
|
|
"""
|
|
return rev_p.bind(operand, dimensions=tuple(dimensions))
|
|
|
|
def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array:
|
|
"""Selects between two branches based on a boolean predicate.
|
|
|
|
Wraps XLA's `Select
|
|
<https://www.tensorflow.org/xla/operation_semantics#select>`_
|
|
operator.
|
|
|
|
In general :func:`~jax.lax.select` leads to evaluation of both branches, although
|
|
the compiler may elide computations if possible. For a similar function that
|
|
usually evaluates only a single branch, see :func:`~jax.lax.cond`.
|
|
|
|
Args:
|
|
pred: boolean array
|
|
on_true: array containing entries to return where ``pred`` is True. Must have
|
|
the same shape as ``pred``, and the same shape and dtype as ``on_false``.
|
|
on_false: array containing entries to return where ``pred`` is False. Must have
|
|
the same shape as ``pred``, and the same shape and dtype as ``on_true``.
|
|
|
|
Returns:
|
|
result: array with same shape and dtype as ``on_true`` and ``on_false``.
|
|
"""
|
|
# Caution! The select_n_p primitive has the *opposite* order of arguments to
|
|
# select(). This is because it implements `select_n`.
|
|
return select_n_p.bind(pred, on_false, on_true)
|
|
|
|
def select_n(which: ArrayLike, *cases: ArrayLike) -> Array:
|
|
"""Selects array values from multiple cases.
|
|
|
|
Generalizes XLA's `Select
|
|
<https://www.tensorflow.org/xla/operation_semantics#select>`_
|
|
operator. Unlike XLA's version, the operator is variadic and can select
|
|
from many cases using an integer `pred`.
|
|
|
|
Args:
|
|
which: determines which case should be returned. Must be an array containing
|
|
      either boolean or integer values. May either be a scalar or have
|
|
shape matching ``cases``. For each array element, the value of ``which``
|
|
determines which of ``cases`` is taken. ``which`` must be in the range
|
|
``[0 .. len(cases))``; for values outside that range the behavior is
|
|
implementation-defined.
|
|
*cases: a non-empty list of array cases. All must have equal dtypes and
|
|
equal shapes.
|
|
Returns:
|
|
An array with shape and dtype equal to the cases, whose values are chosen
|
|
according to ``which``.
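
  Examples:
    Selecting elementwise among three cases with an integer ``which``
    (illustrative; ``jnp`` denotes ``jax.numpy``):

    >>> which = jnp.array([0, 1, 2])
    >>> c0 = jnp.array([10, 20, 30])
    >>> c1 = jnp.array([11, 21, 31])
    >>> c2 = jnp.array([12, 22, 32])
    >>> jax.lax.select_n(which, c0, c1, c2)
    Array([10, 21, 32], dtype=int32)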
|
|
"""
|
|
if len(cases) == 0:
|
|
raise ValueError("select_n() must have at least one case")
|
|
return select_n_p.bind(which, *cases)
|
|
|
|
|
|
def transpose(operand: ArrayLike,
|
|
permutation: Sequence[int] | np.ndarray) -> Array:
|
|
"""Wraps XLA's `Transpose
|
|
<https://www.tensorflow.org/xla/operation_semantics#transpose>`_
|
|
operator.
|
|
"""
|
|
permutation = tuple(operator.index(d) for d in permutation)
|
|
if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array):
|
|
return operand
|
|
else:
|
|
return transpose_p.bind(operand, permutation=permutation)
|
|
|
|
def argmin(operand: ArrayLike, axis: int,
|
|
index_dtype: DTypeLike) -> Array:
|
|
"""Computes the index of the minimum element along ``axis``."""
|
|
return argmin_p.bind(operand, axes=(axis,),
|
|
index_dtype=dtypes.canonicalize_dtype(index_dtype))
|
|
|
|
def argmax(operand: ArrayLike, axis: int,
|
|
index_dtype: DTypeLike) -> Array:
|
|
"""Computes the index of the maximum element along ``axis``."""
|
|
return argmax_p.bind(operand, axes=(axis,),
|
|
index_dtype=dtypes.canonicalize_dtype(index_dtype))
|
|
|
|
def reduce(operands: Any,
|
|
init_values: Any,
|
|
computation: Callable[[Any, Any], Any],
|
|
dimensions: Sequence[int]) -> Any:
|
|
"""Wraps XLA's `Reduce
|
|
<https://www.tensorflow.org/xla/operation_semantics#reduce>`_
|
|
operator.
|
|
|
|
``init_values`` and ``computation`` together must form a `monoid
|
|
<https://en.wikipedia.org/wiki/Monoid>`_
|
|
for correctness. That is ``init_values`` must be an identity of
|
|
``computation``, and ``computation`` must be associative. XLA may exploit both
|
|
of these properties during code generation; if either is violated the result
|
|
is undefined.
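
  Examples:
    An illustrative sum-reduction over axis 1, using ``0`` as the identity of
    ``lax.add`` (``jnp`` denotes ``jax.numpy``):

    >>> x = jnp.arange(6).reshape(2, 3)
    >>> jax.lax.reduce(x, 0, jax.lax.add, (1,))
    Array([ 3, 12], dtype=int32)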
|
|
"""
|
|
flat_operands, operand_tree = tree_util.tree_flatten(operands)
|
|
flat_init_values, init_value_tree = tree_util.tree_flatten(init_values)
|
|
if operand_tree != init_value_tree:
|
|
raise ValueError('Operands must have the same tree structure as init_values:'
|
|
f' {operand_tree} vs. {init_value_tree}')
|
|
if len(flat_operands) != len(flat_init_values):
|
|
raise ValueError('Must have same total number of operands as init_values: '
|
|
f' {len(flat_operands)} vs. {len(flat_init_values)}')
|
|
monoid_reducer = _get_monoid_reducer(computation, flat_init_values)
|
|
if monoid_reducer:
|
|
# monoid reducers bypass the weak_type_rule, so we set it explicitly.
|
|
weak_type = dtypes.is_weakly_typed(*flat_operands) and dtypes.is_weakly_typed(*flat_init_values)
|
|
return _convert_element_type(monoid_reducer(*flat_operands, dimensions),
|
|
weak_type=weak_type)
|
|
else:
|
|
flat_init_avals = safe_map(_abstractify, flat_init_values)
|
|
closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
|
|
computation, tuple(flat_init_avals), init_value_tree)
|
|
out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
|
|
jaxpr=closed_jaxpr, dimensions=tuple(dimensions))
|
|
return tree_util.tree_unflatten(out_tree, out)
|
|
|
|
@cache()
|
|
def _reduction_jaxpr(computation, aval):
|
|
@lu.wrap_init
|
|
def comp(x, y):
|
|
result = computation(x, y)
|
|
if not (isinstance(result, core.Tracer) or core.valid_jaxtype(result)):
|
|
raise ValueError(
|
|
f"Invalid return type from reduction function: {type(result)}\n"
|
|
f"Reduction functions should only return an array.\n"
|
|
f"Full return value: {result}")
|
|
return (result,)
|
|
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
|
|
if any(isinstance(c, core.Tracer) for c in consts):
|
|
raise NotImplementedError(
|
|
"Reduction computations can't close over Tracers. Please open an issue "
|
|
"at https://github.com/jax-ml/jax.")
|
|
return jaxpr, tuple(consts)
|
|
|
|
@cache()
|
|
def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
|
|
avals = tree_util.tree_unflatten(aval_tree, flat_avals)
|
|
flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals))
|
|
comp = lu.wrap_init(computation)
|
|
flat_comp, out_tree = api_util.flatten_fun_nokwargs(comp, in_tree)
|
|
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
|
|
if any(isinstance(c, core.Tracer) for c in consts):
|
|
raise NotImplementedError(
|
|
"Reduction computations can't close over Tracers. Please open an issue "
|
|
"at https://github.com/jax-ml/jax.")
|
|
return core.ClosedJaxpr(jaxpr, consts), out_tree()
|
|
|
|
def _get_monoid_reducer(monoid_op: Callable,
|
|
xs: Sequence[Array]) -> Callable | None:
|
|
if len(xs) != 1:
|
|
return None
|
|
x, = xs
|
|
aval = core.get_aval(x)
|
|
dtype = _dtype(x)
|
|
if (type(aval) is ConcreteArray) and aval.shape == ():
|
|
# allow bitwise reductions for boolean and integer types
|
|
_is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer)
|
|
if monoid_op is add:
|
|
return _reduce_sum if np.equal(aval.val, 0) else None
|
|
elif monoid_op is mul:
|
|
return _reduce_prod if np.equal(aval.val, 1) else None
|
|
elif monoid_op is bitwise_or and _is_intlike:
|
|
return _reduce_or if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None
|
|
elif monoid_op is bitwise_and and _is_intlike:
|
|
return _reduce_and if np.equal(aval.val, _get_bitwise_and_identity(dtype)) else None
|
|
elif monoid_op is bitwise_xor and _is_intlike:
|
|
return _reduce_xor if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None
|
|
elif monoid_op is max:
|
|
return _reduce_max if np.equal(aval.val, _get_max_identity(dtype)) else None
|
|
elif monoid_op is min:
|
|
return _reduce_min if np.equal(aval.val, _get_min_identity(dtype)) else None
|
|
return None
|
|
|
|
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
|
|
return np.array(-1).astype(dtype)
|
|
|
|
def _get_bitwise_or_identity(dtype: DTypeLike) -> np.ndarray:
|
|
return np.array(0, dtype)
|
|
|
|
def _get_sum_identity(dtype: DTypeLike) -> np.ndarray:
|
|
return np.array(0, dtype)
|
|
|
|
def _get_prod_identity(dtype: DTypeLike) -> np.ndarray:
|
|
return np.array(1, dtype)
|
|
|
|
def _get_max_identity(dtype: DTypeLike) -> np.ndarray:
|
|
if dtypes.issubdtype(dtype, np.inexact):
|
|
return np.array(-np.inf if dtypes.supports_inf(dtype) else dtypes.finfo(dtype).min,
|
|
dtype=dtype)
|
|
elif dtypes.issubdtype(dtype, np.integer):
|
|
return np.array(dtypes.iinfo(dtype).min, dtype)
|
|
elif dtypes.issubdtype(dtype, np.bool_):
|
|
return np.array(False, np.bool_)
|
|
else:
|
|
raise ValueError(f"Unsupported dtype for max: {dtype}")
|
|
|
|
def _get_min_identity(dtype: DTypeLike) -> np.ndarray:
|
|
if dtypes.issubdtype(dtype, np.inexact):
|
|
return np.array(np.inf if dtypes.supports_inf(dtype) else dtypes.finfo(dtype).max,
|
|
dtype=dtype)
|
|
elif dtypes.issubdtype(dtype, np.integer):
|
|
return np.array(dtypes.iinfo(dtype).max, dtype)
|
|
elif dtypes.issubdtype(dtype, np.bool_):
|
|
return np.array(True, np.bool_)
|
|
else:
|
|
raise ValueError(f"Unsupported dtype for min: {dtype}")
|
|
|
|
def _reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array:
|
|
return reduce_sum_p.bind(operand, axes=tuple(axes))
|
|
|
|
def _reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array:
|
|
return reduce_prod_p.bind(operand, axes=tuple(axes))
|
|
|
|
def _reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array:
|
|
return reduce_max_p.bind(operand, axes=tuple(axes))
|
|
|
|
def _reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array:
|
|
return reduce_min_p.bind(operand, axes=tuple(axes))
|
|
|
|
def _reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array:
|
|
return reduce_or_p.bind(operand, axes=tuple(axes))
|
|
|
|
def _reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array:
|
|
return reduce_and_p.bind(operand, axes=tuple(axes))
|
|
|
|
def _reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:
|
|
return reduce_xor_p.bind(operand, axes=tuple(axes))
|
|
|
|
@overload
|
|
def sort(operand: Array, dimension: int = -1,
|
|
is_stable: bool = True, num_keys: int = 1) -> Array: ...
|
|
|
|
@overload
|
|
def sort(operand: Sequence[Array], dimension: int = -1,
|
|
is_stable: bool = True, num_keys: int = 1) -> tuple[Array, ...]: ...
|
|
|
|
def sort(operand: Array | Sequence[Array], dimension: int = -1,
|
|
is_stable: bool = True, num_keys: int = 1) -> Array | tuple[Array, ...]:
|
|
"""Wraps XLA's `Sort
|
|
<https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.
|
|
|
|
For floating point inputs, -0.0 and 0.0 are treated as equivalent, and NaN values
|
|
are sorted to the end of the array. For complex inputs, the sort order is
|
|
lexicographic over the real and imaginary parts, with the real part primary.
|
|
|
|
Args:
|
|
operand : Array or sequence of arrays
|
|
dimension : integer dimension along which to sort. Default: -1.
|
|
is_stable : boolean specifying whether to use a stable sort. Default: True.
|
|
num_keys : number of operands to treat as sort keys. Default: 1.
|
|
For num_keys > 1, the sort order will be determined lexicographically using
|
|
the first `num_keys` arrays, with the first key being primary.
|
|
The remaining operands will be returned with the same permutation.
|
|
|
|
Returns:
|
|
operand : sorted version of the input or inputs.
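
  Examples:
    Sorting a single array, and sorting a pair of arrays by the first array as
    key (illustrative; ``jnp`` denotes ``jax.numpy``):

    >>> x = jnp.array([3, 1, 2])
    >>> jax.lax.sort(x)
    Array([1, 2, 3], dtype=int32)
    >>> keys = jnp.array([2, 1, 2])
    >>> vals = jnp.array([10, 20, 30])
    >>> jax.lax.sort((keys, vals), num_keys=1)
    (Array([1, 2, 2], dtype=int32), Array([20, 10, 30], dtype=int32))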
|
|
"""
|
|
if isinstance(operand, Sequence):
|
|
if len(operand) == 0:
|
|
raise TypeError("Sort requires at least one operand")
|
|
if not (1 <= num_keys <= len(operand)):
|
|
raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}")
|
|
dimension = canonicalize_axis(dimension, len(operand[0].shape))
|
|
return tuple(sort_p.bind(*operand, dimension=dimension,
|
|
is_stable=is_stable,
|
|
num_keys=num_keys))
|
|
else:
|
|
if num_keys != 1:
|
|
raise ValueError(f"{num_keys=} must equal 1 for a single operand.")
|
|
dimension = canonicalize_axis(dimension, len(operand.shape))
|
|
return sort_p.bind(operand, dimension=dimension, is_stable=is_stable, num_keys=1)[0]
|
|
|
|
def sort_key_val(keys: Array, values: ArrayLike, dimension: int = -1,
|
|
is_stable: bool = True) -> tuple[Array, Array]:
|
|
"""Sorts ``keys`` along ``dimension`` and applies the same permutation to ``values``."""
|
|
dimension = canonicalize_axis(dimension, len(keys.shape))
|
|
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, num_keys=1)
|
|
return k, v
|
|
|
|
def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
|
|
"""Returns top ``k`` values and their indices along the last axis of ``operand``.
|
|
|
|
Args:
|
|
operand: N-dimensional array of non-complex type.
|
|
k: integer specifying the number of top entries.
|
|
|
|
Returns:
|
|
A tuple ``(values, indices)`` where
|
|
|
|
- ``values`` is an array containing the top k values along the last axis.
|
|
- ``indices`` is an array containing the indices corresponding to values.
|
|
|
|
See also:
|
|
- :func:`jax.lax.approx_max_k`
|
|
- :func:`jax.lax.approx_min_k`
|
|
|
|
Examples:
|
|
Find the largest three values, and their indices, within an array:
|
|
|
|
>>> x = jnp.array([9., 3., 6., 4., 10.])
|
|
>>> values, indices = jax.lax.top_k(x, 3)
|
|
>>> values
|
|
Array([10., 9., 6.], dtype=float32)
|
|
>>> indices
|
|
Array([4, 0, 2], dtype=int32)
|
|
"""
|
|
if core.is_constant_dim(k):
|
|
k = int(k)
|
|
if k < 0:
|
|
raise ValueError(f"k argument to top_k must be nonnegative, got {k}")
|
|
return top_k_p.bind(operand, k=k)
|
|
|
|
def tie_in(x: Any, y: T) -> T:
|
|
"""Deprecated. Ignores ``x`` and returns ``y``."""
|
|
return y
|
|
|
|
def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
|
|
sharding: Sharding | None = None) -> Array:
|
|
"""Returns an array of `shape` filled with `fill_value`.
|
|
|
|
Args:
|
|
shape: sequence of integers, describing the shape of the output array.
|
|
fill_value: the value to fill the new array with.
|
|
dtype: the type of the output array, or `None`. If not `None`, `fill_value`
|
|
will be cast to `dtype`.
|
|
    sharding: an optional sharding specification for the resulting array.
      Note that sharding is currently ignored in jitted mode; this might
      change in the future.
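
  Examples:
    A small illustrative example (``jnp`` denotes ``jax.numpy``):

    >>> jax.lax.full((3,), 2.0, dtype=jnp.float32)
    Array([2., 2., 2.], dtype=float32)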
|
|
"""
|
|
shape = canonicalize_shape(shape)
|
|
if np.shape(fill_value):
|
|
msg = "full must be called with scalar fill_value, got fill_value.shape {}."
|
|
raise TypeError(msg.format(np.shape(fill_value)))
|
|
if dtypes.issubdtype(dtype, dtypes.extended):
|
|
return dtype._rules.full(shape, fill_value, dtype) # type: ignore[union-attr]
|
|
weak_type = dtype is None and dtypes.is_weakly_typed(fill_value)
|
|
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
|
|
fill_value = _convert_element_type(fill_value, dtype, weak_type)
|
|
  # In tracing mode we can't set the sharding explicitly and PmapSharding is
  # not supported.
  # NB: Consider using with_sharding_constraint in jitted computation
  # if needed?
|
|
if (sharding is not None and not isinstance(sharding, PmapSharding) and
|
|
isinstance(fill_value, array.ArrayImpl)):
|
|
broadcast_shape = sharding.shard_shape(shape)
|
|
shard = broadcast(fill_value, broadcast_shape)
|
|
return array.make_array_from_callback(shape, sharding, lambda _: shard)
|
|
|
|
return broadcast(fill_value, shape)
|
|
|
|
def zeros_like_shaped_array(aval: ShapedArray) -> Array:
|
|
assert isinstance(aval, ShapedArray)
|
|
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
|
scalar_zero = aval.dtype._rules.zero(aval.dtype)
|
|
elif aval.dtype == dtypes.float0:
|
|
scalar_zero = np.zeros((), dtype=aval.dtype)
|
|
else:
|
|
scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
|
|
return broadcast(scalar_zero, aval.shape)
|
|
|
|
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
|
|
|
def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray:
|
|
val = ad_util.zeros_like_aval(aval.inner_aval)
|
|
return core.mutable_array(val)
|
|
|
|
ad_util.aval_zeros_likers[state.AbstractRef] = zeros_like_abstract_ref # type: ignore
|
|
|
|
def iota(dtype: DTypeLike, size: int) -> Array:
|
|
"""Wraps XLA's `Iota
|
|
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
|
|
operator.
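
  Examples:
    A small illustrative example (``jnp`` denotes ``jax.numpy``):

    >>> jax.lax.iota(jnp.int32, 4)
    Array([0, 1, 2, 3], dtype=int32)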
|
|
"""
|
|
return broadcasted_iota(dtype, (size,), 0)
|
|
|
|
def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array:
|
|
"""Convenience wrapper around ``iota``."""
|
|
dtype = dtypes.canonicalize_dtype(dtype)
|
|
shape = canonicalize_shape(shape)
|
|
dynamic_shape = [d for d in shape if isinstance(d, core.Tracer)]
|
|
static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
|
|
dimension = core.concrete_or_error(
|
|
int, dimension, "dimension argument of lax.broadcasted_iota")
|
|
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
|
|
dimension=dimension)
|
|
|
|
def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
|
|
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
|
|
offset = _clip_int_to_valid_range(offset, np.int32,
|
|
"argument `offset` of jax.numpy.eye")
|
|
dtype = dtypes.canonicalize_dtype(dtype)
|
|
bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
|
|
broadcasted_iota(np.int32, shape, 1))
|
|
return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False,
|
|
sharding=None)
|
|
|
|
def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
|
|
"""This utility function exists for creating Kronecker delta arrays."""
|
|
axes = map(int, axes)
|
|
dtype = dtypes.canonicalize_dtype(dtype)
|
|
base_shape = tuple(np.take(shape, axes))
|
|
iotas = [broadcasted_iota(np.uint32, base_shape, i)
|
|
for i in range(len(base_shape))]
|
|
eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
|
|
result = convert_element_type_p.bind(
|
|
_reduce(operator.and_, eyes), new_dtype=dtype, weak_type=False,
|
|
sharding=None)
|
|
return broadcast_in_dim(result, shape, axes)
|
|
|
|
def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
|
|
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
|
|
offset = _clip_int_to_valid_range(offset, np.int32,
|
|
"argument `offset` of jax.numpy.tri")
|
|
dtype = dtypes.canonicalize_dtype(dtype)
|
|
bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0),
|
|
asarray(core.dimension_as_value(offset)).astype(np.int32)),
|
|
broadcasted_iota(np.int32, shape, 1))
|
|
return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False,
|
|
sharding=None)
|
|
|
|
def stop_gradient(x: T) -> T:
|
|
"""Stops gradient computation.
|
|
|
|
Operationally ``stop_gradient`` is the identity function, that is, it returns
|
|
argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
|
|
gradients during forward or reverse-mode automatic differentiation. If there
|
|
are multiple nested gradient computations, ``stop_gradient`` stops gradients
|
|
for all of them.
|
|
|
|
For example:
|
|
|
|
>>> jax.grad(lambda x: x**2)(3.)
|
|
Array(6., dtype=float32, weak_type=True)
|
|
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
|
|
Array(0., dtype=float32, weak_type=True)
|
|
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
|
|
Array(2., dtype=float32, weak_type=True)
|
|
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
|
|
Array(0., dtype=float32, weak_type=True)
|
|
"""
|
|
def stop(x):
|
|
# only bind primitive on inexact dtypes, to avoid some staging
|
|
if dtypes.issubdtype(core.get_aval(x).dtype, dtypes.extended):
|
|
return x
|
|
elif (dtypes.issubdtype(_dtype(x), np.floating) or
|
|
dtypes.issubdtype(_dtype(x), np.complexfloating)):
|
|
return ad_util.stop_gradient_p.bind(x)
|
|
else:
|
|
return x
|
|
return tree_map(stop, x)
|
|
|
|
def reduce_precision(operand: float | ArrayLike,
|
|
exponent_bits: int,
|
|
mantissa_bits: int) -> Array:
|
|
"""Wraps XLA's `ReducePrecision
|
|
<https://www.tensorflow.org/xla/operation_semantics#reduceprecision>`_
|
|
operator.
|
|
"""
|
|
exponent_bits = core.concrete_or_error(
|
|
operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision")
|
|
mantissa_bits = core.concrete_or_error(
|
|
operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision")
|
|
return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits)
|
|
|
|
def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array:
|
|
"""Squeeze any number of size 1 dimensions from an array."""
|
|
ndim = np.ndim(array)
|
|
dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions))
|
|
if not dimensions and isinstance(array, Array):
|
|
return array
|
|
return squeeze_p.bind(array, dimensions=dimensions)
|
|
|
|
def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:
|
|
"""Insert any number of size 1 dimensions into an array."""
|
|
if len(set(dimensions)) != len(dimensions):
|
|
raise ValueError(f'repeated axis in lax.expand_dims: {dimensions}')
|
|
ndim_out = np.ndim(array) + len(dimensions)
|
|
dims = [canonicalize_axis(i, ndim_out) for i in dimensions]
|
|
if len(set(dims)) != len(dims): # check again after canonicalizing
|
|
raise ValueError(f'repeated axis in lax.expand_dims: {dims}')
|
|
dims_set = frozenset(dims)
|
|
result_shape = list(np.shape(array))
|
|
for i in sorted(dims_set):
|
|
result_shape.insert(i, 1)
|
|
broadcast_dims = [i for i in range(ndim_out) if i not in dims_set]
|
|
return broadcast_in_dim(array, result_shape, broadcast_dims)
|
|
|
|
|
|
### convenience wrappers around traceables
|
|
|
|
def full_like(x: ArrayLike | DuckTypedArray,
|
|
fill_value: ArrayLike, dtype: DTypeLike | None = None,
|
|
shape: Shape | None = None, sharding: Sharding | None = None) -> Array:
|
|
"""Create a full array like np.full based on the example array `x`.
|
|
|
|
Args:
|
|
x: example array-like, used for shape and dtype information.
|
|
fill_value: a scalar value to fill the entries of the output array.
|
|
dtype: optional, a dtype parameter for the output ndarray.
|
|
shape: optional, a shape parameter for the output ndarray.
|
|
sharding: an optional sharding specification for the resulting array.
|
|
If not specified, the output will have the same sharding as the input,
with a few exceptions/limitations, in particular:
1. Sharding is not available during tracing, so this relies on jit.
2. If x is weakly typed or uncommitted, the default sharding is used.
3. If shape is not None and differs from x.shape, the default sharding is used.
|
|
|
|
Returns:
|
|
An ndarray with the same shape as `x` with its entries set equal to
|
|
`fill_value`, similar to the output of np.full.
|
|
"""
|
|
fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape) # type: ignore[arg-type]
|
|
weak_type = dtype is None and dtypes.is_weakly_typed(x)
|
|
dtype = dtype or _dtype(x)
|
|
if dtypes.issubdtype(dtype, dtypes.extended):
|
|
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
|
|
|
|
# If `x` has a sharding but no `_committed` attribute
|
|
# (in case of ShapeDtypeStruct), default it to True.
|
|
use_x_sharding = (
|
|
sharding is None
|
|
# Tracers have special logic for handling sharding, and even though
# hasattr(x, 'sharding') returns False for them, that check is very slow.
# Testing for Tracer first short-circuits it.
|
|
and not isinstance(x, core.Tracer)
|
|
and hasattr(x, 'sharding')
|
|
and getattr(x, '_committed', True)
|
|
and not weak_type
|
|
and fill_shape == np.shape(x) # type: ignore[arg-type]
|
|
)
|
|
if use_x_sharding:
|
|
# TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported.
|
|
sharding = x.sharding # type: ignore
|
|
val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type),
|
|
sharding=sharding)
|
|
return val
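# A minimal usage sketch (illustrative only):
#
#   >>> x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
#   >>> y = lax.full_like(x, 7)
#   >>> y.shape, y.dtype
#   ((2, 3), dtype('float32'))
#
# When `x` is a committed, non-weakly-typed array and neither `shape` nor
# `sharding` overrides it, the result also reuses `x.sharding` (see the
# `use_x_sharding` condition above).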
|
|
|
|
|
|
def collapse(operand: Array, start_dimension: int,
|
|
stop_dimension: int | None = None) -> Array:
|
|
"""Collapses dimensions of an array into a single dimension.
|
|
|
|
For example, if ``operand`` is an array with shape ``[2, 3, 4]``,
|
|
``collapse(operand, 0, 2).shape == [6, 4]``. The elements of the collapsed
|
|
dimension are laid out major-to-minor, i.e., with the lowest-numbered
|
|
dimension as the slowest varying dimension.
|
|
|
|
Args:
|
|
operand: an input array.
|
|
start_dimension: the start of the dimensions to collapse (inclusive).
|
|
stop_dimension: the end of the dimensions to collapse (exclusive). Pass None
|
|
to collapse all the dimensions after start.
|
|
|
|
Returns:
|
|
An array where dimensions ``[start_dimension, stop_dimension)`` have been
|
|
collapsed (raveled) into a single dimension.
|
|
"""
|
|
lo, hi, _ = slice(start_dimension, stop_dimension).indices(len(operand.shape))
|
|
if hi < lo:
|
|
raise ValueError(f"Invalid dimension range passed to collapse: {operand.shape}"
|
|
f"[{start_dimension}:{stop_dimension}]")
|
|
size = math.prod(operand.shape[lo:hi])
|
|
new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:]
|
|
return reshape(operand, new_shape)
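# A minimal usage sketch (illustrative only): collapsing is just a reshape of
# the contiguous dimension range [start, stop) into its product.
#
#   >>> x = jnp.ones((2, 3, 4))
#   >>> lax.collapse(x, 0, 2).shape
#   (6, 4)
#   >>> lax.collapse(x, 1).shape   # stop_dimension=None collapses to the end
#   (2, 12)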
|
|
|
|
|
|
def batch_matmul(lhs: Array, rhs: Array,
|
|
precision: PrecisionLike = None) -> Array:
|
|
"""Batch matrix multiplication."""
|
|
if _min(lhs.ndim, rhs.ndim) < 2:
|
|
raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}'
|
|
.format(lhs.ndim, rhs.ndim))
|
|
if lhs.ndim != rhs.ndim:
|
|
raise ValueError('Arguments to batch_matmul must have same ndim, got {}, {}'
|
|
.format(lhs.ndim, rhs.ndim))
|
|
lhs_contract = (lhs.ndim - 1,)
|
|
rhs_contract = (rhs.ndim - 2,)
|
|
batch = tuple(range(lhs.ndim - 2))
|
|
return dot_general(lhs, rhs, ((lhs_contract, rhs_contract), (batch, batch)),
|
|
precision=precision)
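# A minimal usage sketch (illustrative only): leading dimensions are batch
# dimensions and the trailing two are the matrix dimensions.
#
#   >>> lhs = jnp.ones((5, 2, 3))
#   >>> rhs = jnp.ones((5, 3, 4))
#   >>> lax.batch_matmul(lhs, rhs).shape
#   (5, 2, 4)
#
# For these 3D inputs this is the dot_general call with contracting dims
# ((2,), (1,)) and batch dims ((0,), (0,)).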
|
|
|
|
|
|
# These functions also exist in the XLA client library, but we treat them
|
|
# as non-primitive to maintain a smaller set of autodiff primitives.
|
|
|
|
def square(x: ArrayLike) -> Array:
|
|
r"""Elementwise square: :math:`x^2`."""
|
|
return integer_pow(x, 2)
|
|
|
|
def reciprocal(x: ArrayLike) -> Array:
|
|
r"""Elementwise reciprocal: :math:`1 \over x`."""
|
|
return integer_pow(x, -1)
|
|
|
|
def _upcast_fp16_for_computation(f):
|
|
@functools.wraps(f)
|
|
def f_wrapped(x):
|
|
dtype = _dtype(x)
|
|
if dtype == np.float16 or dtype == dtypes.bfloat16:
|
|
return convert_element_type(
|
|
f(convert_element_type(x, np.float32)), dtype)
|
|
return f(x)
|
|
|
|
return f_wrapped
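# A minimal usage sketch (illustrative only): the wrapper runs `f` in float32
# for float16/bfloat16 inputs and casts the result back, e.g.
#
#   @_upcast_fp16_for_computation
#   def _some_impl(x):   # hypothetical helper, mirrors _tan_impl below
#     return div(sin(x), cos(x))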
|
|
|
|
def tan(x: ArrayLike) -> Array:
|
|
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
|
|
return tan_p.bind(x)
|
|
|
|
def asin(x: ArrayLike) -> Array:
|
|
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
|
|
return asin_p.bind(x)
|
|
|
|
def acos(x: ArrayLike) -> Array:
|
|
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`."""
|
|
return acos_p.bind(x)
|
|
|
|
def atan(x: ArrayLike) -> Array:
|
|
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
|
|
return atan_p.bind(x)
|
|
|
|
def sinh(x: ArrayLike) -> Array:
|
|
r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`."""
|
|
return sinh_p.bind(x)
|
|
|
|
def cosh(x: ArrayLike) -> Array:
|
|
r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`."""
|
|
return cosh_p.bind(x)
|
|
|
|
def asinh(x: ArrayLike) -> Array:
|
|
r"""Elementwise inverse hyperbolic sine: :math:`\mathrm{asinh}(x)`."""
|
|
return asinh_p.bind(x)
|
|
|
|
def acosh(x: ArrayLike) -> Array:
|
|
r"""Elementwise inverse hyperbolic cosine: :math:`\mathrm{acosh}(x)`."""
|
|
return acosh_p.bind(x)
|
|
|
|
def atanh(x: ArrayLike) -> Array:
|
|
r"""Elementwise inverse hyperbolic tangent: :math:`\mathrm{atanh}(x)`."""
|
|
return atanh_p.bind(x)
|
|
|
|
|
|
# Add some methods to ShapedArray that rely on lax primitives
|
|
|
|
ShapedArray.broadcast = core.aval_method(broadcast)
|
|
ShapedArray.transpose = core.aval_method(transpose) # clobbered by lax_numpy
|
|
ShapedArray.reshape = core.aval_method(reshape) # clobbered by lax_numpy
|
|
|
|
def _iter(tracer):
|
|
if tracer.ndim == 0:
|
|
raise TypeError("iteration over a 0-d array") # same as numpy error
|
|
else:
|
|
n = int(tracer.shape[0])
|
|
if any(isinstance(d, core.Tracer) for d in tracer.shape):
|
|
return (slicing.dynamic_index_in_dim(tracer, i, keepdims=False)
|
|
for i in range(n))
|
|
else:
|
|
return (slicing.index_in_dim(tracer, i, keepdims=False) for i in range(n))
|
|
ShapedArray._iter = staticmethod(_iter)
|
|
core.DShapedArray._iter = staticmethod(_iter)
|
|
|
|
def zeros_like_array(x: ArrayLike) -> Array:
|
|
return full_like(x, 0)
|
|
|
|
|
|
def _add_arrays(x, y):
|
|
if (isinstance(a := core.get_aval(x), ShapedArray) and
|
|
dtypes.issubdtype(a.dtype, dtypes.extended)):
|
|
return a.dtype._rules.add(a.dtype, x, y)  # pytype: disable=attribute-error
|
|
return add(x, y)
|
|
|
|
for t in itertools.chain(
|
|
dtypes.python_scalar_dtypes.keys(), array_types, [array.ArrayImpl]):
|
|
ad_util.raw_jaxval_adders[t] = _add_arrays
|
|
|
|
|
|
### primitives
|
|
|
|
|
|
_fixed_dtype = \
|
|
lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
|
|
_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype
|
|
|
|
_strip_weak_type = lambda *args, **_: False
|
|
|
|
|
|
def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
|
|
if aval.dtype == dtypes.float0:
|
|
raise TypeError(
|
|
f"Called {name} with a float0 array. "
|
|
"float0s do not support any operations by design, because they "
|
|
"are not compatible with non-trivial vector spaces. No implicit dtype "
|
|
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
|
|
"to cast a float0 array to a regular zeros array. \n"
|
|
"If you didn't expect to get a float0 you might have accidentally "
|
|
"taken a gradient with respect to an integer argument.")
|
|
if not any(dtypes.issubdtype(aval.dtype, t) for t in accepted_dtypes):
|
|
msg = '{} does not accept dtype {}. Accepted dtypes are subtypes of {}.'
|
|
typename = dtype_to_string(aval.dtype)
|
|
accepted_typenames = (t.__name__ for t in accepted_dtypes)
|
|
raise TypeError(msg.format(name, typename, ', '.join(accepted_typenames)))
|
|
return result_dtype(aval.dtype)
|
|
|
|
|
|
def unop(result_dtype, accepted_dtypes, name):
|
|
dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
|
|
prim = standard_primitive(_attrgetter('shape'), dtype_rule, name)
|
|
batching.defvectorized(prim)
|
|
pe.def_trivial_padding(prim)
|
|
return prim
|
|
standard_unop = partial(unop, _identity)
|
|
_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
|
|
|
|
|
|
def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
|
|
require_same=True, allow_extended_dtype=False, **kwargs):
|
|
del kwargs
|
|
assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes)
|
|
for i, aval in enumerate(avals):
|
|
if allow_extended_dtype and isinstance(aval.dtype, dtypes.ExtendedDType):
|
|
continue
|
|
types = accepted_dtypes[i]
|
|
if not any(dtypes.issubdtype(aval.dtype, t) for t in types):
|
|
if aval.dtype == dtypes.float0:
|
|
raise TypeError(
|
|
f"Called {name} with a float0 at position {i}. "
|
|
"float0s do not support any operations by design, because they "
|
|
"are not compatible with non-trivial vector spaces. No implicit dtype "
|
|
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
|
|
"to cast a float0 array to a regular zeros array. \n"
|
|
"If you didn't expect to get a float0 you might have accidentally "
|
|
"taken a gradient with respect to an integer argument.")
|
|
else:
|
|
msg = ('{} does not accept dtype {} at position {}. '
|
|
'Accepted dtypes at position {} are subtypes of {}.')
|
|
typename = dtype_to_string(aval.dtype)
|
|
typenames = ', '.join(t.__name__ for t in types)
|
|
raise TypeError(msg.format(name, typename, i, i, typenames))
|
|
if require_same: check_same_dtypes(name, *avals)
|
|
return result_dtype(*avals)
|
|
|
|
|
|
def broadcasting_shape_rule(name, *avals):
|
|
shapes = [aval.shape for aval in avals if aval.shape]
|
|
if not shapes:
|
|
return ()
|
|
if len({len(shape) for shape in shapes}) != 1:
|
|
msg = '{}: arrays must have same number of dimensions, got {}.'
|
|
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
|
|
# TODO(mattjj): de-duplicate with _try_broadcast_shapes
|
|
result_shape = []
|
|
for ds in zip(*shapes):
|
|
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
|
|
# if all axes are identical objects, the resulting size is the object
|
|
result_shape.append(ds[0])
|
|
else:
|
|
# if all dims are equal (or 1), the result is the non-1 size
|
|
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
|
|
if not non_1s:
|
|
result_shape.append(1)
|
|
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
|
|
result_shape.append(non_1s[0])
|
|
else:
|
|
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
|
|
f'{", ".join(map(str, map(tuple, shapes)))}.')
|
|
|
|
return tuple(result_shape)
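# For example (illustrative only), rank-matched inputs of shapes (4, 1, 3) and
# (4, 7, 1) broadcast to (4, 7, 3): per dimension, all sizes must agree or be
# 1, and the non-1 size (if any) wins. Unlike NumPy broadcasting, differing
# ranks are rejected above rather than padded with leading 1s.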
|
|
|
|
|
|
def broadcasting_sharding_rule(name, *avals):
|
|
shapes = [aval.shape for aval in avals if aval.shape]
|
|
if not shapes:
|
|
return ()
|
|
if len({len(shape) for shape in shapes}) != 1:
|
|
msg = '{}: arrays must have same number of dimensions, got {}.'
|
|
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
|
|
|
|
specs = [a.sharding.spec for a in avals if a.shape]
|
|
|
|
mesh = None
|
|
for a in avals:
|
|
if a.shape:
|
|
# Compare against the mesh recorded so far before overwriting it, so that
# mismatched meshes across inputs are actually detected.
if mesh is not None and mesh != a.sharding.mesh:
raise ValueError(
f'Mesh for all inputs should be equal. Got one mesh: {mesh} and'
f' another mesh: {a.sharding.mesh}')
mesh = a.sharding.mesh
|
|
assert mesh is not None
|
|
|
|
result_specs = []
|
|
for ss, ds in zip(zip(*specs), zip(*shapes)):
|
|
if all(s == ss[0] for s in ss[1:]):
|
|
# if all dimension shardings are the same, the resulting dimension sharding
# is that common sharding.
|
|
result_specs.append(ss[0])
|
|
else:
|
|
non_trivial_s = [s for s, d in zip(ss, ds)
|
|
if not (core.definitely_equal(d, 1) and s is None)]
|
|
if not non_trivial_s:
|
|
result_specs.append(None)
|
|
elif all(non_trivial_s[0] == s for s in non_trivial_s[1:]):
|
|
result_specs.append(non_trivial_s[0])
|
|
else:
|
|
raise TypeError(f'{name} got incompatible shardings for broadcasting: '
|
|
f'{", ".join(map(str, map(tuple, specs)))}.')
|
|
return NamedSharding(mesh, PartitionSpec(*result_specs))
|
|
|
|
|
|
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
|
|
require_same_dtypes=False):
|
|
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name,
|
|
allow_extended_dtype=allow_extended_dtype,
|
|
require_same=require_same_dtypes)
|
|
shape_rule = partial(broadcasting_shape_rule, name)
|
|
sharding_rule = partial(broadcasting_sharding_rule, name)
|
|
prim = standard_primitive(shape_rule, dtype_rule, name,
|
|
sharding_rule=sharding_rule)
|
|
batching.defbroadcasting(prim)
|
|
pe.def_trivial_padding(prim)
|
|
return prim
|
|
standard_naryop = partial(naryop, _input_dtype)
|
|
|
|
|
|
# Like autograd.numpy.numpy_vjps.unbroadcast, this utility handles transposition
|
|
# involving linear primitives with implicit broadcasting.
|
|
def _unbroadcast(aval, x):
|
|
if not isinstance(aval, (core.DShapedArray, ShapedArray)):
|
|
raise TypeError("transpose with implicit broadcasting of unshaped values")
|
|
x_shape = np.shape(x)
|
|
if core.definitely_equal_shape(aval.shape, x_shape):
|
|
return x
|
|
assert not aval.shape or len(x_shape) == len(aval.shape)
|
|
if not aval.shape:
|
|
return _reduce_sum(x, list(range(len(x_shape))))
|
|
else:
|
|
dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)]
|
|
if config.enable_checks.value: assert all(aval.shape[i] == 1 for i in dims)
|
|
return reshape(_reduce_sum(x, dims), aval.shape)
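# For example (illustrative only): if the primal aval has shape (3, 1) and the
# incoming cotangent has shape (3, 4) because of implicit broadcasting, the
# transpose sums over axis 1 and reshapes back:
#
#   >>> _unbroadcast(ShapedArray((3, 1), np.float32), jnp.ones((3, 4))).shape
#   (3, 1)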
|
|
|
|
def _maybe_broadcast(target_shape, x):
|
|
x_shape = np.shape(x)
|
|
if core.definitely_equal_shape(x_shape, target_shape):
|
|
return x
|
|
elif not x_shape:
|
|
return broadcast_in_dim(x, target_shape, ())
|
|
else:
|
|
dims = [i for i, (a, b) in enumerate(zip(x_shape, target_shape))
|
|
if core.definitely_equal(a, b)]
|
|
squeeze_shape = [x_shape[i] for i in dims]
|
|
return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims)
|
|
|
|
def broadcast_hlo(
|
|
aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray],
|
|
args: Sequence[ir.Value]) -> Sequence[ir.Value]:
|
|
"""Broadcasts HLO values with broadcast-compatible shapes to the same shape.
|
|
"""
|
|
out = []
|
|
for aval, arg in zip(avals, args):
|
|
if aval.shape != aval_out.shape:
|
|
assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out)
|
|
dims = mlir.dense_int_array(
|
|
list(range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape))))
|
|
if any(isinstance(d, ir.Value) for d in aval_out.shape):
|
|
arg = hlo.dynamic_broadcast_in_dim(
|
|
mlir.aval_to_ir_type(aval_out), arg,
|
|
mlir.shape_tensor(aval_out.shape), dims)
|
|
else:
|
|
arg = hlo.broadcast_in_dim(
|
|
mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg,
|
|
dims)
|
|
out.append(arg)
|
|
return out
|
|
|
|
def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
|
|
out = []
|
|
for op, in_aval in zip(ops, in_avals):
|
|
if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
|
|
out.append(op)
|
|
else:
|
|
# TODO(yashkatariya, dougalm): If `in_aval.sharding` contains
|
|
# CompilerShardingAxis, then specify `unspecified_dims` via
|
|
# `wrap_with_sharding_op`.
|
|
if config.use_shardy_partitioner.value:
|
|
sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim)
|
|
else:
|
|
sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
|
|
out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp))
|
|
return out
|
|
|
|
|
|
def _nary_lower_hlo(op: Callable, ctx,
|
|
*args: ir.Value,
|
|
explicit_type=False, **params) -> Sequence[ir.Value]:
|
|
"""Lowers an elementwise operator to its MLIR equivalent.
|
|
|
|
Args:
|
|
explicit_type: does the MLIR op require its output type to be provided?
|
|
"""
|
|
del params
|
|
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
|
|
args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape) # type: ignore
|
|
if config.sharding_in_types.value:
|
|
args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)
|
|
|
|
if explicit_type:
|
|
out = op(mlir.aval_to_ir_type(aval_out), *args)
|
|
else:
|
|
out = op(*args)
|
|
if config.sharding_in_types.value:
|
|
if config.use_shardy_partitioner.value:
|
|
out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim)
|
|
else:
|
|
out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
|
|
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)]
|
|
else:
|
|
return [out]
|
|
|
|
|
|
_float = {np.floating}
|
|
_complex = {np.complexfloating}
|
|
_complex_elem_types = {np.float32, np.float64}
|
|
_int = {np.integer}
|
|
_bool = {np.bool_}
|
|
_signedint = {np.signedinteger}
|
|
|
|
_num = _int | _float | _complex
|
|
_any = _int | _float | _complex | _bool
|
|
_bool_or_int = _int | _bool
|
|
_ordered = _int | _float | _bool
|
|
|
|
neg_p = standard_unop(_num, 'neg')
|
|
ad.deflinear2(neg_p, lambda t, operand: [neg(t)])
|
|
mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.negate))
|
|
|
|
sign_p = standard_unop(_num, 'sign')
|
|
ad.defjvp_zero(sign_p)
|
|
|
|
def _sign_lower_hlo(ctx, x):
|
|
x_aval, = ctx.avals_in
|
|
if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
|
|
return [hlo.select(
|
|
mlir.compare_hlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ',
|
|
'UNSIGNED'),
|
|
mlir.full_like_aval(ctx, 0, x_aval),
|
|
mlir.full_like_aval(ctx, 1, x_aval))]
|
|
return [hlo.sign(x)]
|
|
|
|
mlir.register_lowering(sign_p, _sign_lower_hlo)
|
|
|
|
nextafter_p = standard_naryop([_float, _float], 'nextafter')
|
|
mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.next_after))
|
|
|
|
floor_p = standard_unop(_float, 'floor')
|
|
ad.defjvp_zero(floor_p)
|
|
mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.floor))
|
|
|
|
ceil_p = standard_unop(_float, 'ceil')
|
|
ad.defjvp_zero(ceil_p)
|
|
mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.ceil))
|
|
|
|
round_p = standard_unop(_float, 'round')
|
|
ad.defjvp_zero(round_p)
|
|
|
|
def _round_lower(ctx, x, *, rounding_method):
|
|
if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
|
|
return [hlo.round_nearest_afz(x)]
|
|
else:
|
|
assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
|
|
return [hlo.round_nearest_even(x)]
|
|
mlir.register_lowering(round_p, _round_lower)
|
|
|
|
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
|
|
ad.defjvp_zero(is_finite_p)
|
|
mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))
|
|
|
|
exp_p = standard_unop(_float | _complex, 'exp')
|
|
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
|
|
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
|
|
|
|
exp2_p = standard_unop(_float | _complex, 'exp2')
|
|
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
|
|
def _exp2_lower(ctx, x):
|
|
x_aval, = ctx.avals_in
|
|
log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
|
|
log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
|
|
return [hlo.exponential(hlo.multiply(log2, x))]
|
|
mlir.register_lowering(exp2_p, _exp2_lower)
|
|
|
|
log_p = standard_unop(_float | _complex, 'log')
|
|
ad.defjvp(log_p, lambda g, x: div(g, x))
|
|
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log))
|
|
|
|
expm1_p = standard_unop(_float | _complex, 'expm1')
|
|
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
|
|
mlir.register_lowering(expm1_p,
|
|
partial(_nary_lower_hlo, hlo.exponential_minus_one))
|
|
|
|
log1p_p = standard_unop(_float | _complex, 'log1p')
|
|
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
|
|
mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one))
|
|
|
|
tanh_p = standard_unop(_float | _complex, 'tanh')
|
|
ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
|
|
sub(_one(x), ans)))
|
|
mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh))
|
|
|
|
logistic_p = standard_unop(_float | _complex, 'logistic')
|
|
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
|
|
# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
|
|
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic))
|
|
|
|
def logistic_impl(x):
|
|
one = _const(x, 1)
|
|
return div(one, add(one, exp(neg(x))))
|
|
|
|
mlir.register_lowering(logistic_p,
|
|
mlir.lower_fun(logistic_impl, multiple_results=False))
|
|
|
|
def _sin_complex(x):
|
|
# use expm1 instead of exp to avoid cancellation when abs(x) is small
|
|
# relies on the quality of real-valued expm1, sin, cos
|
|
# sin(x) = complex(sin(real(x)) * cosh(imag(x)), cos(real(x)) * sinh(imag(x)))
|
|
# 2 * sinh(x) = exp(x) - 1 - (exp(-x) - 1) = expm1(x) - expm1(-x)
|
|
# 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
|
|
a, b = real(x), imag(x)
|
|
a_is_zero = eq(a, _const(a, 0))
|
|
sn, cs = sin(a), cos(a)
|
|
e1m, e2m = expm1(b), expm1(-b)
|
|
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
|
|
re, im = sn * csh, cs * snh
|
|
# avoid a NaN in the real part when real(x) is zero and imag(x) is so large
# that cosh(imag(x)) overflows to inf (0 * inf would otherwise give NaN)
|
|
return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
|
|
|
|
def _sin_lowering(ctx, x):
|
|
if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
|
|
sine = mlir.lower_fun(_sin_complex, multiple_results=False)
|
|
return sine(ctx, x)
|
|
return _nary_lower_hlo(hlo.sine, ctx, x)
|
|
|
|
sin_p = standard_unop(_float | _complex, 'sin')
|
|
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
|
mlir.register_lowering(sin_p, _sin_lowering)
|
|
|
|
def _cos_complex(x):
|
|
# cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
|
|
# see also _sin_complex
|
|
a, b = real(x), imag(x)
|
|
a_is_zero = eq(a, _const(a, 0))
|
|
sn, cs = sin(a), cos(a)
|
|
e1m, e2m = expm1(b), expm1(-b)
|
|
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
|
|
re, im = cs * csh, -sn * snh
|
|
return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
|
|
|
|
def _cos_lowering(ctx, x):
|
|
if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
|
|
cosine = mlir.lower_fun(_cos_complex, multiple_results=False)
|
|
return cosine(ctx, x)
|
|
return _nary_lower_hlo(hlo.cosine, ctx, x)
|
|
|
|
cos_p = standard_unop(_float | _complex, 'cos')
|
|
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
|
mlir.register_lowering(cos_p, _cos_lowering)
|
|
|
|
@_upcast_fp16_for_computation
|
|
def _tan_impl(x):
|
|
return div(sin(x), cos(x))
|
|
|
|
tan_p = standard_unop(_float | _complex, 'tan')
|
|
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
|
|
# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this
|
|
# lowering is mostly supported, but it fails on export or with the PJRT plugin
|
|
# because those modes target an older StableHLO version, and the
|
|
# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't
|
|
# included in the 0.4.33 release.
|
|
if jaxlib_version <= (0, 4, 33):
|
|
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
|
|
else:
|
|
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
|
|
|
|
def asin_impl(x):
|
|
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
|
return mul(_const(x, -1j), asinh(mul(_const(x, 1j), x)))
|
|
else:
|
|
return mul(_const(x, 2),
|
|
atan2(x, add(_const(x, 1), sqrt(sub(_const(x, 1), square(x))))))
|
|
|
|
asin_p = standard_unop(_float | _complex, 'asin')
|
|
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
|
|
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin))
|
|
|
|
def acos_impl(x):
|
|
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
|
result = mul(_const(x, 1j), acosh(x))
|
|
# By convention, numpy chooses the branch with positive real part.
|
|
rpart = real(result)
|
|
return select(
|
|
gt(rpart, _const(rpart, 0)),
|
|
result,
|
|
neg(result)
|
|
)
|
|
else:
|
|
return select(
|
|
ne(x, _const(x, -1.0)),
|
|
mul(_const(x, 2),
|
|
atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x))),
|
|
full_like(x, np.pi))
|
|
|
|
acos_p = standard_unop(_float | _complex, 'acos')
|
|
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
|
|
mlir.register_lowering(acos_p,
|
|
mlir.lower_fun(acos_impl, multiple_results=False))
|
|
|
|
def atan_impl(x):
|
|
return atan2(x, _const(x, 1))
|
|
|
|
atan_p = standard_unop(_float | _complex, 'atan')
|
|
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
|
|
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan))
|
|
|
|
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
|
|
ad.defjvp(atan2_p,
|
|
lambda g, x, y: g * (y / (square(x) + square(y))),
|
|
lambda g, x, y: g * -x / (square(x) + square(y)))
|
|
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2))
|
|
|
|
sinh_p = standard_unop(_float | _complex, 'sinh')
|
|
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
|
|
mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.sinh))
|
|
|
|
cosh_p = standard_unop(_float | _complex, 'cosh')
|
|
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
|
|
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh))
|
|
|
|
asinh_p = standard_unop(_float | _complex, 'asinh')
|
|
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
|
|
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh))
|
|
|
|
acosh_p = standard_unop(_float | _complex, 'acosh')
|
|
ad.defjvp(acosh_p,
|
|
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
|
|
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh))
|
|
|
|
atanh_p = standard_unop(_float | _complex, 'atanh')
|
|
ad.defjvp(atanh_p,
|
|
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
|
|
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh))
|
|
|
|
real_p = unop(_complex_basetype, _complex, 'real')
|
|
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
|
|
mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.real))
|
|
|
|
imag_p = unop(_complex_basetype, _complex, 'imag')
|
|
ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])
|
|
mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.imag))
|
|
|
|
|
|
def _complex_transpose_rule(t, x, y):
|
|
assert ad.is_undefined_primal(x) or ad.is_undefined_primal(y)
|
|
if ad.is_undefined_primal(x) and ad.is_undefined_primal(y):
|
|
if type(t) is ad_util.Zero:
|
|
return [ad_util.Zero(x.aval), ad_util.Zero(y.aval)]
|
|
else:
|
|
return [_unbroadcast(x.aval, real(t)), _unbroadcast(y.aval, imag(neg(t)))]
|
|
elif ad.is_undefined_primal(x):
|
|
if type(t) is ad_util.Zero:
|
|
return [ad_util.Zero(x.aval), None]
|
|
else:
|
|
return [_unbroadcast(x.aval, real(t)), None]
|
|
else:
|
|
if type(t) is ad_util.Zero:
|
|
return [None, ad_util.Zero(y.aval)]
|
|
else:
|
|
return [None, _unbroadcast(y.aval, imag(neg(t)))]
|
|
|
|
_complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
|
|
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
|
|
'complex')
|
|
ad.deflinear2(complex_p, _complex_transpose_rule)
|
|
mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.complex))
|
|
|
|
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
|
|
|
|
def _conj_impl(x, **kw):
|
|
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
|
return complex(real(x), -imag(x))
|
|
else:
|
|
return complex(x, _zeros(x))
|
|
|
|
mlir.register_lowering(conj_p,
|
|
mlir.lower_fun(_conj_impl, multiple_results=False))
|
|
|
|
|
|
def _conj_transpose_rule(t, x, *, input_dtype):
|
|
assert ad.is_undefined_primal(x)
|
|
if type(t) is ad_util.Zero:
|
|
return [ad_util.Zero(x.aval)]
|
|
elif dtypes.issubdtype(input_dtype, np.complexfloating):
|
|
return [conj(t)]
|
|
else:
|
|
return [real(t)]
|
|
|
|
ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
|
|
ad.primitive_transposes[conj_p] = _conj_transpose_rule
|
|
|
|
abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs')
|
|
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.abs))
|
|
|
|
def _abs_jvp_rule(g, ans, x):
|
|
if _iscomplex(x):
|
|
return _maybe_real(mul(g, div(_maybe_conj(x),
|
|
_replace_zero(convert_element_type(ans, _dtype(x))))))
|
|
else:
|
|
return select(ge(x, _zero(x)), g, neg(g))
|
|
ad.defjvp2(abs_p, _abs_jvp_rule)
|
|
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
|
|
_maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
|
|
|
sqrt_p = standard_unop(_float | _complex, 'sqrt')
|
|
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
|
|
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt))
|
|
|
|
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
|
|
ad.defjvp2(rsqrt_p,
|
|
lambda g, ans, x:
|
|
mul(g, mul(_const(x, -0.5), div(ans, x))))
|
|
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt))
|
|
|
|
cbrt_p = standard_unop(_float, 'cbrt')
|
|
ad.defjvp2(cbrt_p,
|
|
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
|
|
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt))
|
|
|
|
def _pow_dtype_rule(x, y):
|
|
if (dtypes.issubdtype(x.dtype, np.inexact) and
|
|
dtypes.issubdtype(y.dtype, np.integer)):
|
|
return x.dtype
|
|
if x.dtype == y.dtype:
|
|
return x.dtype
|
|
raise TypeError("the first argument to pow must have an inexact dtype (float "
|
|
"or complex), and the second argument must have an inexact or"
|
|
" integer dtype, and two inexact dtypes must match, but got "
|
|
f"{x.dtype} and {y.dtype} respectively.")
|
|
pow_p = naryop(_pow_dtype_rule, [_float | _complex, _int | _float | _complex],
|
|
'pow', require_same_dtypes=False)
|
|
|
|
def _pow_jvp_lhs(g, ans, x, y):
|
|
y_dtype = dtypes.dtype(y)
|
|
result_dtype = dtypes.result_type(x, y)
|
|
if result_dtype == bool:
|
|
result_dtype = 'int32'
|
|
x = convert_element_type(x, result_dtype)
|
|
y = convert_element_type(y, result_dtype)
|
|
if dtypes.issubdtype(y_dtype, np.integer):
|
|
if x.shape != y.shape:
|
|
shape = broadcast_shapes(x.shape, y.shape)
|
|
x = _maybe_broadcast(shape, x)
|
|
y = _maybe_broadcast(shape, y)
|
|
jac = select(eq(y, _const(y, 0)), _zeros(y),
|
|
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
|
|
else:
|
|
jac = mul(y, pow(x, sub(y, _ones(y))))
|
|
return mul(g, jac)
|
|
|
|
def _pow_jvp_rhs(g, ans, x, y):
|
|
y_dtype = dtypes.dtype(y)
|
|
assert dtypes.issubdtype(y_dtype, np.inexact)
|
|
return convert_element_type(mul(g, mul(log(_replace_zero(x)), ans)), y_dtype)
|
|
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
|
|
|
def _pow_lower(ctx, x, y):
|
|
x_aval, y_aval = ctx.avals_in
|
|
out_aval, = ctx.avals_out
|
|
convert = mlir.lower_fun(
|
|
partial(convert_element_type, new_dtype=out_aval.dtype), False)
|
|
x_aval_ = x_aval.update(dtype=out_aval.dtype)
|
|
y_aval_ = y_aval.update(dtype=out_aval.dtype)
|
|
[x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
|
|
[y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
|
|
ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
|
|
return _nary_lower_hlo(hlo.power, ctx_, x_, y_)
|
|
mlir.register_lowering(pow_p, _pow_lower)
|
|
|
|
def _integer_pow_dtype_rule(x, *, y):
|
|
dtype = unop_dtype_rule(_identity, _int | _float | _complex, 'integer_pow', x)
|
|
if y < 0 and dtypes.issubdtype(dtype, np.integer):
|
|
raise TypeError("Integers cannot be raised to negative powers, got "
|
|
f"integer_pow({x}, {y})")
|
|
return dtype
|
|
|
|
def _integer_pow_jvp(g, x, *, y):
|
|
return _zeros(g) if y == 0 else mul(g, mul(_const(x, y), integer_pow(x, y - 1)))
|
|
|
|
integer_pow_p = standard_primitive(
|
|
_attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow')
|
|
batching.defvectorized(integer_pow_p)
|
|
ad.defjvp(integer_pow_p, _integer_pow_jvp)
|
|
pe.def_trivial_padding(integer_pow_p)
|
|
|
|
def _integer_pow(x, *, y):
|
|
# This should be kept in sync with the jax2tf translation rule.
|
|
if y == 0:
|
|
return full_like(x, 1)
|
|
is_reciprocal = y < 0
|
|
if is_reciprocal:
|
|
y = -y
|
|
acc = None
|
|
while y > 0:
|
|
if y & 1:
|
|
acc = x if acc is None else mul(acc, x)
|
|
y >>= 1
|
|
if y > 0:
|
|
# We don't call square because it calls integer_pow.
|
|
x = mul(x, x)
|
|
return div(full_like(acc, 1), acc) if is_reciprocal else acc
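# For example (illustrative only), y=5 (binary 101) proceeds by squaring:
# bit 0 sets acc = x, then x -> x**2; bit 1 is zero, then x -> x**4; bit 2
# multiplies acc -> x**5. That is 3 multiplies instead of the naive 4.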
|
|
|
|
|
|
def _integer_pow_lowering(ctx, x, *, y):
|
|
# These cases are subsumed by the general case, but it's faster to emit these
|
|
# common cases directly.
|
|
if y == 2:
|
|
return (hlo.multiply(x, x),)
|
|
elif y == 3:
|
|
return (hlo.multiply(hlo.multiply(x, x), x),)
|
|
else:
|
|
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
|
|
# TODO(b/217551391): emitting an out-of-line call leads to a large
|
|
# expansion when the MLIR is lowered to HLO, because the HLO lowering
|
|
# clones the callee. Consider unconditionally caching when the MLIR->HLO
|
|
# lowering doesn't expand the program.
|
|
lowering = mlir.cache_lowering(lowering)
|
|
return lowering(ctx, x, y=y)
|
|
|
|
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
|
|
|
|
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
|
|
|
|
not_p = standard_unop(_bool_or_int, 'not')
|
|
ad.defjvp_zero(not_p)
|
|
mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.not_))
|
|
|
|
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
|
|
ad.defjvp_zero(and_p)
|
|
mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.and_))
|
|
|
|
or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
|
|
ad.defjvp_zero(or_p)
|
|
mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.or_))
|
|
|
|
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
|
|
ad.defjvp_zero(xor_p)
|
|
mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.xor))
|
|
|
|
population_count_p = standard_unop(_int, 'population_count')
|
|
mlir.register_lowering(population_count_p, partial(_nary_lower_hlo, hlo.popcnt))
|
|
|
|
clz_p = standard_unop(_int, 'clz')
|
|
mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.count_leading_zeros))
|
|
|
|
def _add_jvp(primals, tangents):
|
|
x, y = primals
|
|
xdot, ydot = tangents
|
|
primal_out = add(x, y)
|
|
if type(xdot) is type(ydot) is ad_util.Zero:
|
|
return primal_out, ad_util.Zero.from_primal_value(primal_out)
|
|
if type(xdot) is ad_util.Zero:
|
|
return primal_out, _maybe_broadcast(primal_out.shape, ydot)
|
|
elif type(ydot) is ad_util.Zero:
|
|
return primal_out, _maybe_broadcast(primal_out.shape, xdot)
|
|
else:
|
|
return primal_out, add(xdot, ydot)
|
|
|
|
def _add_transpose(t, x, y):
|
|
# Morally the following assertion is true, but because we instantiate zeros in
|
|
# some places (e.g. in custom_jvp) it may not always hold. For example, see
|
|
# api_test.py's CustomJVPTest.test_jaxpr_zeros.
|
|
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
|
|
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
|
|
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
|
|
if type(t) is ad_util.Zero:
|
|
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
|
|
else:
|
|
return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)]
|
|
|
|
# TODO(slebedev): Why does mypy fail to infer the type here?
|
|
add_p: Primitive = standard_naryop([_num, _num], 'add')
|
|
ad.primitive_jvps[add_p] = _add_jvp
|
|
ad.primitive_transposes[add_p] = _add_transpose
|
|
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add))
|
|
|
|
def _sub_jvp(primals, tangents):
|
|
x, y = primals
|
|
xdot, ydot = tangents
|
|
primal_out = sub(x, y)
|
|
if type(xdot) is type(ydot) is ad_util.Zero:
|
|
return primal_out, ad_util.Zero.from_primal_value(primal_out)
|
|
if type(xdot) is ad_util.Zero:
|
|
return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot))
|
|
elif type(ydot) is ad_util.Zero:
|
|
return primal_out, _maybe_broadcast(primal_out.shape, xdot)
|
|
else:
|
|
return primal_out, sub(xdot, ydot)
|
|
|
|
def _sub_transpose(t, x, y):
|
|
# Morally the following assertion is true, but see the comment in add_p's
|
|
# transpose rule.
|
|
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
|
|
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
|
|
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
|
|
if type(t) is ad_util.Zero:
|
|
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
|
|
else:
|
|
return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, neg(t))]
|
|
|
|
sub_p = standard_naryop([_num, _num], 'sub')
|
|
ad.primitive_jvps[sub_p] = _sub_jvp
|
|
ad.primitive_transposes[sub_p] = _sub_transpose
|
|
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract))
|
|
|
|
|
|
def _mul_transpose(ct, x, y):
|
|
assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y)
|
|
if ad.is_undefined_primal(x):
|
|
if type(ct) is ad_util.Zero:
|
|
return [ad_util.Zero(x.aval), None]
|
|
else:
|
|
return [_unbroadcast(x.aval, mul(ct, y)), None]
|
|
else:
|
|
if type(ct) is ad_util.Zero:
|
|
return [None, ad_util.Zero(y.aval)]
|
|
else:
|
|
return [None, _unbroadcast(y.aval, mul(x, ct))]
|
|
|
|
mul_p = standard_naryop([_num, _num], 'mul')
|
|
ad.defjvp(mul_p,
|
|
lambda xdot, x, y: mul(xdot, y),
|
|
lambda ydot, x, y: mul(x, ydot))
|
|
ad.primitive_transposes[mul_p] = _mul_transpose
|
|
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply))
|
|
|
|
def _div_transpose_rule(cotangent, x, y):
|
|
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
|
|
if type(cotangent) is ad_util.Zero:
|
|
return [ad_util.Zero(x.aval), None]
|
|
else:
|
|
return [_unbroadcast(x.aval, div(cotangent, y)), None]
|
|
div_p = standard_naryop([_num, _num], 'div')
|
|
ad.defjvp(div_p,
|
|
lambda g, x, y: div(g, y),
|
|
lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)))
|
|
ad.primitive_transposes[div_p] = _div_transpose_rule
|
|
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide))
|
|
|
|
rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
|
|
ad.defjvp(
|
|
rem_p,
|
|
lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g),
|
|
lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y))))))
|
|
mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.remainder))
|
|
|
|
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
|
|
result_shape = broadcast_shapes(np.shape(x), np.shape(y))
|
|
x = _maybe_broadcast(result_shape, x)
|
|
y = _maybe_broadcast(result_shape, y)
|
|
rx = real(x)
|
|
ry = real(y)
|
|
pick_x = select(eq(rx, ry), lax_cmp_pick_x(imag(x), imag(y)),
|
|
lax_cmp_pick_x(rx, ry))
|
|
return select(pick_x, x, y)
|
|
|
|
max_p: core.Primitive = standard_naryop([_any, _any], 'max')
|
|
ad.defjvp2(max_p,
|
|
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
|
|
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
|
|
mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo))
|
|
|
|
min_p: core.Primitive = standard_naryop([_any, _any], 'min')
|
|
ad.defjvp2(min_p,
|
|
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
|
|
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
|
|
mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
|
|
|
|
shift_left_p = standard_naryop([_int, _int], 'shift_left')
|
|
ad.defjvp_zero(shift_left_p)
|
|
mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.shift_left))
|
|
|
|
shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
|
|
ad.defjvp_zero(shift_right_arithmetic_p)
|
|
mlir.register_lowering(shift_right_arithmetic_p,
|
|
partial(_nary_lower_hlo, hlo.shift_right_arithmetic))
|
|
|
|
shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
|
|
ad.defjvp_zero(shift_right_logical_p)
|
|
mlir.register_lowering(shift_right_logical_p,
|
|
partial(_nary_lower_hlo, hlo.shift_right_logical))
|
|
|
|
def _opaque_comparison_hlo(direction, reduction_op, identity, ctx,
|
|
avals_in, aval_out, x, y):
|
|
aval_x, aval_y = avals_in
|
|
base_aval_x = core.physical_aval(aval_x)
|
|
base_aval_y = core.physical_aval(aval_y)
|
|
base_aval_out = core.ShapedArray(base_aval_x.shape, aval_out.dtype)
|
|
reduce_axes = tuple(range(aval_out.ndim, base_aval_out.ndim))
|
|
res, = mlir.delegate_lowering(
|
|
ctx, partial(_compare_lower_hlo, direction, False),
|
|
x, y, avals_in=[base_aval_x, base_aval_y], avals_out=[base_aval_out])
|
|
return mlir.delegate_lowering(
|
|
ctx, partial(_unary_reduce_lower, reduction_op, identity,
|
|
axes=reduce_axes),
|
|
res, avals_in=[base_aval_out], avals_out=[aval_out])
|
|
|
|
_opaque_eq_hlo = partial(
|
|
_opaque_comparison_hlo, 'EQ', hlo.AndOp, _get_bitwise_and_identity)
|
|
_opaque_ne_hlo = partial(
|
|
_opaque_comparison_hlo, 'NE', hlo.OrOp, _get_bitwise_or_identity)
|
|
|
|
def _compare_lower_hlo_opaque(direction: str, ctx, avals_in, aval_out, x, y):
|
|
broadcast_avals_in = tuple(
|
|
core.ShapedArray(aval_out.shape, aval.dtype) for aval in avals_in)
|
|
if direction == 'EQ':
|
|
return _opaque_eq_hlo(ctx, broadcast_avals_in, aval_out, x, y)
|
|
elif direction == 'NE':
|
|
return _opaque_ne_hlo(ctx, broadcast_avals_in, aval_out, x, y)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"HLO comparison {direction} for extended dtype {avals_in[0].dtype}")
|
|
|
|
|
|
def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
|
|
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
|
|
x_dtype = avals_in[0].dtype
|
|
x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape)
|
|
if dtypes.issubdtype(x_dtype, dtypes.extended):
|
|
assert not total_order
|
|
return _compare_lower_hlo_opaque(direction, ctx, avals_in, aval_out, x, y)
|
|
if dtypes.issubdtype(x_dtype, np.inexact):
|
|
compare_type = "TOTALORDER" if total_order else "FLOAT"
|
|
elif dtypes.issubdtype(x_dtype, np.signedinteger):
|
|
compare_type = "SIGNED"
|
|
else:
|
|
compare_type = "UNSIGNED"
|
|
return [mlir.compare_hlo(x, y, direction, compare_type)]
|
|
|
|
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True)
|
|
ad.defjvp_zero(eq_p)
|
|
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False))
|
|
|
|
ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True)
|
|
ad.defjvp_zero(ne_p)
|
|
mlir.register_lowering(ne_p, partial(_compare_lower_hlo, "NE", False))
|
|
|
|
ge_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'ge')
|
|
ad.defjvp_zero(ge_p)
|
|
mlir.register_lowering(ge_p, partial(_compare_lower_hlo, "GE", False))
|
|
|
|
gt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'gt')
|
|
ad.defjvp_zero(gt_p)
|
|
mlir.register_lowering(gt_p, partial(_compare_lower_hlo, "GT", False))
|
|
|
|
le_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le')
|
|
ad.defjvp_zero(le_p)
|
|
mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE", False))
|
|
|
|
lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt')
|
|
ad.defjvp_zero(lt_p)
|
|
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False))
|
|
|
|
eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to')
|
|
ad.defjvp_zero(eq_to_p)
|
|
mlir.register_lowering(eq_to_p, partial(_compare_lower_hlo, "EQ", True))
|
|
|
|
le_to_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le_to')
|
|
ad.defjvp_zero(le_to_p)
|
|
mlir.register_lowering(le_to_p, partial(_compare_lower_hlo, "LE", True))
|
|
|
|
lt_to_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt_to')
|
|
ad.defjvp_zero(lt_to_p)
|
|
mlir.register_lowering(lt_to_p, partial(_compare_lower_hlo, "LT", True))
|
|
|
|
|
|
def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type,
|
|
sharding):
|
|
return operand.shape
|
|
|
|
def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
|
|
sharding):
|
|
return sharding
|
|
|
|
def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,
|
|
sharding):
|
|
if (operand.dtype != new_dtype and
|
|
((dtypes.issubdtype(operand.dtype, dtypes.extended) and
|
|
not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or
|
|
(dtypes.issubdtype(new_dtype, dtypes.extended) and
|
|
not new_dtype._rules.convert_to(operand.dtype, new_dtype)))):
|
|
raise ValueError(
|
|
f"Cannot convert_element_type from {dtype_to_string(operand.dtype)} "
|
|
f"to {dtype_to_string(new_dtype)}")
|
|
return new_dtype
|
|
|
|
def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type,
|
|
sharding):
|
|
return weak_type
|
|
|
|
def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type,
|
|
sharding):
|
|
assert ad.is_undefined_primal(operand)
|
|
old_dtype = operand.aval.dtype
|
|
old_weak_type = dtypes.is_weakly_typed(operand)
|
|
if type(ct) is ad_util.Zero:
|
|
return [ad_util.Zero(operand.aval)]
|
|
elif core.primal_dtype_to_tangent_dtype(old_dtype) == dtypes.float0:
|
|
return [ad_util.Zero(operand.aval.update(dtype=dtypes.float0, weak_type=False))]
|
|
else:
|
|
return [convert_element_type_p.bind(
|
|
ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)]
|
|
|
|
def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type,
|
|
sharding):
|
|
if core.primal_dtype_to_tangent_dtype(new_dtype) == dtypes.float0:
|
|
tangent_aval = core.raise_to_shaped(core.get_aval(tangent))
|
|
return ad_util.Zero(tangent_aval.update(dtype=dtypes.float0, weak_type=False))
|
|
else:
|
|
return convert_element_type_p.bind(tangent, new_dtype=new_dtype,
|
|
weak_type=weak_type, sharding=sharding)
|
|
|
|
def _convert_elt_type_folding_rule(consts, eqn):
|
|
# We constant-fold convert_element_types applied to constants if those
|
|
# constants are Python builtin numeric types or numpy.ndarrays (so as not
|
|
# to perform any device operations when constant-folding) and if the output
|
|
# type can be faithfully represented by a Python builtin numeric type or
|
|
# numpy.ndarray. If those conditions are met, we output a numpy.ndarray
|
|
# constant if the output type is not weak, and if the output type is weak then
|
|
# we output a Python builtin numeric type.
|
|
# TODO(mattjj): allow constant-folding CPU-backed JAX arrays
|
|
c, = consts
|
|
o, = eqn.outvars
|
|
if (type(c) in {np.ndarray, *dtypes.python_scalar_dtypes} and
|
|
isinstance(o.aval, core.UnshapedArray) and not np.shape(c) and
|
|
not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended)):
|
|
with warnings.catch_warnings():
|
|
warnings.simplefilter('ignore', util.NumpyComplexWarning)
|
|
out = np.array(c).astype(eqn.params['new_dtype'])
|
|
if not o.aval.weak_type:
|
|
return [out], None
|
|
out = out.item()
|
|
if core.get_aval(out).dtype is o.aval.dtype:
|
|
return [out], None
|
|
return [None], eqn
|
|
|
|
def _convert_elt_type_fwd_rule(eqn):
|
|
v, = eqn.invars
|
|
if (not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended) and
|
|
not dtypes.issubdtype(v.aval.dtype, dtypes.extended) and
|
|
v.aval.dtype == eqn.params['new_dtype'] and
|
|
v.aval.weak_type == eqn.params['weak_type']):
|
|
return [v], None
|
|
else:
|
|
return [None], eqn
|
|
|
|
def _convert_elt_type_pp_rule(eqn, context, settings):
|
|
# don't print new_dtype because the output binder shows it, don't print
|
|
# weak_type when false
|
|
params = dict(eqn.params)
|
|
if params['sharding'] is None:
|
|
del params['sharding'] # don't show trivial case
|
|
return core._pp_eqn(eqn.replace(params=params), context, settings)
|
|
|
|
convert_element_type_p = Primitive('convert_element_type')
|
|
def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding):
|
|
operand = core.Primitive.bind(convert_element_type_p, operand,
|
|
new_dtype=new_dtype, weak_type=weak_type,
|
|
sharding=sharding)
|
|
if sharding is not None:
|
|
operand = pjit.with_sharding_constraint(operand, sharding)
|
|
return operand
|
|
convert_element_type_p.def_custom_bind(_convert_element_type_bind)
|
|
convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
|
|
convert_element_type_p.def_abstract_eval(
|
|
partial(standard_abstract_eval, convert_element_type_p,
|
|
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
|
|
_convert_element_type_weak_type_rule,
|
|
_convert_element_type_sharding_rule))
|
|
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
|
|
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
|
|
batching.defvectorized(convert_element_type_p)
|
|
pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
|
|
pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
|
|
pe.def_trivial_padding(convert_element_type_p)
|
|
core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule
|
|
|
|
def _real_dtype(dtype): return np.finfo(dtype).dtype
|
|
|
|
def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
|
|
sharding):
|
|
aval_in, = ctx.avals_in
|
|
aval_out, = ctx.avals_out
|
|
if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
|
|
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
|
operand = hlo.real(operand)
|
|
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
|
|
return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
|
|
|
|
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
|
|
|
|
|
|
def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
|
|
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
|
|
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
|
|
|
if old_dtype.itemsize == new_dtype.itemsize:
|
|
return operand.shape
|
|
elif old_dtype.itemsize > new_dtype.itemsize:
|
|
return (*operand.shape, old_dtype.itemsize // new_dtype.itemsize)
|
|
else:
|
|
dim_size = operand.shape[-1] if operand.shape else 1
|
|
if dim_size * old_dtype.itemsize != new_dtype.itemsize:
|
|
raise ValueError(
|
|
f"Attempting to convert array of shape {operand.shape} "
|
|
f"from {old_dtype} of size {old_dtype.itemsize} "
|
|
f"to {new_dtype} of size {new_dtype.itemsize}, "
|
|
f"but {dim_size} * {old_dtype.itemsize} != {new_dtype.itemsize}")
|
|
return operand.shape[:-1]
|
|
|
|
def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
|
|
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
|
|
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
|
if (dtypes.issubdtype(old_dtype, np.bool_) or
|
|
dtypes.issubdtype(old_dtype, np.complexfloating) or
|
|
dtypes.issubdtype(new_dtype, np.bool_) or
|
|
dtypes.issubdtype(new_dtype, np.complexfloating)):
|
|
if old_dtype != new_dtype:
|
|
raise TypeError("lax.bitcast_convert_type does not support bool or complex values "
|
|
"unless the operand and destination types match. "
|
|
f"Got operand dtype={old_dtype}, {new_dtype=}. "
|
|
"Consider using the arr.view() method instead.")
|
|
return new_dtype
|
|
|
|
bitcast_convert_type_p = standard_primitive(
|
|
_bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule,
|
|
'bitcast_convert_type', weak_type_rule=_strip_weak_type)
|
|
ad.defjvp_zero(bitcast_convert_type_p)
|
|
batching.defvectorized(bitcast_convert_type_p)
|
|
|
|
def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
|
|
aval_out, = ctx.avals_out
|
|
return [hlo.bitcast_convert(mlir.aval_to_ir_type(aval_out), operand)]
|
|
|
|
mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
|
|
|
|
|
|
def _validate_preferred_element_type(input_dtype, preferred_element_type):
|
|
if (dtypes.issubdtype(input_dtype, np.integer) and
|
|
dtypes.issubdtype(preferred_element_type, np.floating)):
|
|
# Special-case integer->float multiply. This is allowed, and also allows
|
|
# different signedness between input and output.
|
|
pass
|
|
else:
|
|
allowed_types = (np.integer, np.floating, np.complexfloating)
|
|
if any(dtypes.issubdtype(input_dtype, t) and not
|
|
dtypes.issubdtype(preferred_element_type, t) for t in allowed_types):
|
|
raise TypeError("Input type is incompatible with "
|
|
"`preferred_element_type`. The compatible combinations "
|
|
"of (input_type, preferred_element_type) are "
|
|
"(integral, integral), (integral, floating), "
|
|
"(floating, floating), (complex, complex.")
|
|
if (dtypes.issubdtype(input_dtype, np.signedinteger) and
|
|
not dtypes.issubdtype(preferred_element_type, np.signedinteger)):
|
|
raise TypeError("`preferred_element_type` must have the same signedness "
|
|
"as the original type.")
|
|
input_bitwidth = np.dtype(input_dtype).itemsize
|
|
preferred_bitwidth = np.dtype(preferred_element_type).itemsize
|
|
if preferred_bitwidth < input_bitwidth:
|
|
raise TypeError("`preferred_element_type` must not be narrower than the "
|
|
"original type.")
|
|
|
|
|
|
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
|
|
preferred_element_type: DTypeLike | None):
|
|
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
|
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
|
|
for d in (lhs_contracting, lhs_batch)):
|
|
msg = ("dot_general requires lhs dimension numbers to be nonnegative and "
|
|
"less than the number of axes of the lhs value, got "
|
|
f"lhs_batch of {lhs_batch} and lhs_contracting of {lhs_contracting} "
|
|
f"for lhs of rank {lhs.ndim}")
|
|
raise TypeError(msg)
|
|
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, rhs.ndim))
|
|
for d in (rhs_contracting, rhs_batch)):
|
|
msg = ("dot_general requires rhs dimension numbers to be nonnegative and "
|
|
"less than the number of axes of the rhs value, got "
|
|
f"rhs_batch of {rhs_batch} and rhs_contracting of {rhs_contracting} "
|
|
f"for rhs of rank {rhs.ndim}")
|
|
raise TypeError(msg)
|
|
if len(lhs_batch) != len(rhs_batch):
|
|
msg = ("dot_general requires equal numbers of lhs_batch and rhs_batch "
|
|
"dimensions, got lhs_batch {} and rhs_batch {}.")
|
|
raise TypeError(msg.format(lhs_batch, rhs_batch))
|
|
lhs_contracting_set, lhs_batch_set = set(lhs_contracting), set(lhs_batch)
|
|
rhs_contracting_set, rhs_batch_set = set(rhs_contracting), set(rhs_batch)
|
|
if len(lhs_batch_set) != len(lhs_batch):
|
|
msg = ("dot_general requires lhs batch dimensions to be distinct, got "
|
|
f"lhs_batch {lhs_batch}.")
|
|
raise TypeError(msg)
|
|
if len(rhs_batch_set) != len(rhs_batch):
|
|
msg = ("dot_general requires rhs batch dimensions to be distinct, got "
|
|
f"rhs_batch {rhs_batch}.")
|
|
raise TypeError(msg)
|
|
if len(lhs_contracting_set) != len(lhs_contracting):
|
|
msg = ("dot_general requires lhs contracting dimensions to be distinct, "
|
|
f"got lhs_contracting {lhs_contracting}.")
|
|
raise TypeError(msg)
|
|
if len(rhs_contracting_set) != len(rhs_contracting):
|
|
msg = ("dot_general requires rhs contracting dimensions to be distinct, "
|
|
f"got rhs_contracting {rhs_contracting}.")
|
|
raise TypeError(msg)
|
|
if lhs_contracting_set & lhs_batch_set:
|
|
msg = ("dot_general requires lhs batch dimensions to be disjoint from "
|
|
"contracting dimensions, got lhs_batch {} and lhs_contracting {}.")
|
|
raise TypeError(msg.format(lhs_batch, lhs_contracting))
|
|
if rhs_contracting_set & rhs_batch_set:
|
|
msg = ("dot_general requires rhs batch dimensions to be disjoint from "
|
|
"contracting dimensions, got rhs_batch {} and rhs_contracting {}.")
|
|
raise TypeError(msg.format(rhs_batch, rhs_contracting))
|
|
lhs_batch_shape = tuple(lhs.shape[i] for i in lhs_batch)
|
|
rhs_batch_shape = tuple(rhs.shape[i] for i in rhs_batch)
|
|
if not core.definitely_equal_shape(lhs_batch_shape, rhs_batch_shape):
|
|
msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
|
|
"to have the same shape, got {} and {}.")
|
|
raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape))
|
|
lhs_contracting_shape = tuple(lhs.shape[i] for i in lhs_contracting)
|
|
rhs_contracting_shape = tuple(rhs.shape[i] for i in rhs_contracting)
|
|
if not core.definitely_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
|
|
msg = ("dot_general requires contracting dimensions to have the same "
|
|
"shape, got {} and {}.")
|
|
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
|
|
|
|
return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)
|
|
|
|
def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  batch_shape = tuple(lhs_shape[i] for i in lhs_batch)
  lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
  lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch)
  rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
  rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch)
  return batch_shape + lhs_tensored_shape + rhs_tensored_shape

def tuple_delete(tup, idx):
  idx_ = set(idx)
  return tuple(tup[i] for i in range(len(tup)) if i not in idx_)

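# Illustrative trace (comments only, not executed here) of the output-shape
# ordering _dot_general_shape_computation produces: batch dims first, then the
# surviving lhs dims, then the surviving rhs dims.
#
#   lhs_shape = (8, 3, 4)                            # (batch, m, k)
#   rhs_shape = (8, 4, 5)                            # (batch, k, n)
#   dnums = (((2,), (1,)), ((0,), (0,)))
#   _dot_general_shape_computation(lhs_shape, rhs_shape, dnums)  # -> (8, 3, 5)
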
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: DTypeLike | None):
  # We're mostly matching XLA's logic here, namely in shape_inference.cc and
  # primitive_util.h's HigherPrecisionType, e.g.
  # https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329
  def type_properties(dt):
    c = _real_dtype(dt) if dtypes.issubdtype(dt, np.complexfloating) else dt
    return (dtypes.issubdtype(dt, np.complexfloating),
            dtypes.finfo(c).maxexp if dtypes.issubdtype(c, np.floating) else -1,
            dtypes.finfo(c).nmant if dtypes.issubdtype(c, np.floating) else -1,
            _bit_width(c),
            not dtypes.issubdtype(c, np.unsignedinteger))
  lhs_prop, rhs_prop = type_properties(lhs.dtype), type_properties(rhs.dtype)
  if lhs_prop > rhs_prop:
    result_dtype = lhs.dtype
  elif rhs_prop > lhs_prop:
    result_dtype = rhs.dtype
  else:
    if lhs.dtype != rhs.dtype:
      raise TypeError(
          f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}")
    result_dtype = lhs.dtype

  return _maybe_upcast(result_dtype, preferred_element_type)

def _bit_width(d):
  if dtypes.issubdtype(d, np.inexact): return dtypes.finfo(d).bits
  elif dtypes.issubdtype(d, np.integer): return dtypes.iinfo(d).bits
  elif d == np.dtype('bool'): return 1
  else: assert False, d  # should be unreachable, open an issue!

def _maybe_upcast(result_dtype, preferred_element_type):
  # replicates the logic in shape_inference.cc's MaybeUpcast
  if (preferred_element_type is None or
      result_dtype == preferred_element_type):
    return result_dtype
  if (not dtypes.issubdtype(result_dtype, np.floating) and
      _bit_width(preferred_element_type) < _bit_width(result_dtype)):
    raise TypeError("`preferred_element_type` must not be narrower than the "
                    "original type, got preferred_element_type of "
                    f"{preferred_element_type} for result type of "
                    f"{result_dtype}.")
  return preferred_element_type

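# Illustrative note (not executed here): with int8 operands and
# preferred_element_type=np.int32, _maybe_upcast widens the result dtype to
# int32; asking for a narrower non-floating result (e.g. int8 for an int32
# dot) raises the TypeError above.
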
def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                               preferred_element_type: DTypeLike | None,
                               swap_ans=False):
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  x_ndim = x.aval.ndim
  x_kept = remaining(range(x_ndim), x_contract, x_batch)
  y_kept = remaining(range(np.ndim(y)), y_contract, y_batch)
  if swap_ans:
    ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept)
  else:
    ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
  dims = ((ans_y, y_kept), (ans_batch, y_batch))
  x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
  out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
  x_bar = transpose(dot_general(g, y, dims, precision=precision,
                                preferred_element_type=preferred_element_type),
                    tuple(out_axes))
  if x_bar.dtype != x.aval.dtype:
    x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  return x_bar

def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
                               preferred_element_type: DTypeLike | None):
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
  y_bar = _dot_general_transpose_lhs(
      g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
      preferred_element_type=preferred_element_type,
      swap_ans=True)
  if y_bar.dtype != y.aval.dtype:
    y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
  return y_bar

def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
                            precision,
                            preferred_element_type: DTypeLike | None):
  lhs, rhs = batched_args
  lbd, rbd = batch_dims
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
  right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
  new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
      (np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim),
      dimension_numbers)
  # TODO Should probably check that any ragged dimensions have corresponding
  # sizes, because otherwise the dot product is technically undefined.
  #
  # This masking is not strictly necessary for non-contraction dimensions;
  # we could micro-optimize here by avoiding computing that mask.
  if type(lbd) is RaggedAxis:
    lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd)
    lhs_shape = batching.bdim_as_shape(lbd, lhs.shape)
  else:
    lhs_shape = np.shape(lhs)
  if type(rbd) is RaggedAxis:
    rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd)
    rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
  else:
    rhs_shape = np.shape(rhs)
  batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                            precision=precision,
                            preferred_element_type=preferred_element_type)
  result_batch_dim = batching.shape_as_bdim(
      result_stack_dim,
      _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
  return batched_out, result_batch_dim

def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
  # There are three kinds of dimensions in a dot_general:
  # - contraction dimensions appear in lhs and rhs but not the result
  # - batch dimensions appear in lhs, rhs, and result
  # - tensor product dimensions appear in the result and one of lhs or rhs
  # The dimensions of the result are ordered as
  # - Batch dimensions
  #   - Q: In what order? The order of appearance in lhs, rhs, or
  #     dimension_numbers?
  # - Tensor dimensions from the LHS
  # - Tensor dimensions from the RHS
  lhs_ndim, rhs_ndim = ndims
  # lbd and rbd are "batch" dimensions in the sense of dimensions being
  # vmapped, not to be confused with "batch" dimensions in the sense of
  # explicitly present dimensions that this dot_general is zipping together.
  lbd, rbd = batch_dims
  assert lbd is not None or rbd is not None
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

  def bump_dims(dims, b):
    return tuple(np.add(dims, np.greater_equal(dims, b)))

  if type(lbd) is type(rbd) is int:
    # The vmapped dimensions become an additional batch dimension in the
    # batched dot_general, which we arbitrarily put first.
    lhs_batch = (lbd,) + bump_dims(lhs_batch, lbd)
    rhs_batch = (rbd,) + bump_dims(rhs_batch, rbd)
    lhs_contract = bump_dims(lhs_contract, lbd)
    rhs_contract = bump_dims(rhs_contract, rbd)
    result_batch_dim = 0
  elif (type(lbd) is int and rbd is None):
    # The left vmapped dimension becomes an additional tensor dimension in the
    # batched dot_general.
    lhs_tensor = [d for d in range(lhs_ndim)
                  if d not in lhs_batch and d not in lhs_contract]
    result_batch_dim = len(lhs_batch) + int(sum(np.less(lhs_tensor, lbd)))
    lhs_batch = bump_dims(lhs_batch, lbd)
    lhs_contract = bump_dims(lhs_contract, lbd)
  elif (type(rbd) is int and lbd is None):
    # The right vmapped dimension becomes an additional tensor dimension in the
    # batched dot_general.
    rhs_tensor = [d for d in range(rhs_ndim)
                  if d not in rhs_batch and d not in rhs_contract]
    result_batch_dim = (lhs_ndim - len(lhs_contract) +
                        int(sum(np.less(rhs_tensor, rbd))))
    rhs_batch = bump_dims(rhs_batch, rbd)
    rhs_contract = bump_dims(rhs_contract, rbd)
  else:
    # We wouldn't be here if we didn't have at least one vmapped dimension.
    assert False

  new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
  return new_dimension_numbers, result_batch_dim

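# Illustrative trace (comments only): vmapping a plain matmul over axis 0 of
# both operands turns the vmapped axis into an explicit dot_general batch dim,
# placed first in the result:
#
#   ndims = (3, 3)                     # operand ranks after stacking
#   batch_dims = (0, 0)                # both operands vmapped at axis 0
#   dnums = (((1,), (0,)), ((), ()))   # contract lhs dim 1 with rhs dim 0
#   _dot_general_batch_dim_nums(ndims, batch_dims, dnums)
#   # -> ((((2,), (1,)), ((0,), (0,))), 0)
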
def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
                              dimension_numbers, **params):
  lhs_aval, _ = in_avals
  (lhs_contract, _), _ = dimension_numbers
  padded_axes = [(i, lhs_aval.shape[i].val) for i in lhs_contract
                 if isinstance(lhs_aval.shape[i], pe.BoundedAxisSize)]
  lhs_ = _replace_masked_values(lhs, 0, padded_axes)
  return [dot_general(lhs_, rhs, dimension_numbers=dimension_numbers, **params)]

def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
  # * suppress printing precision or preferred_element_type when None.
  # * print dimension_numbers as list-of-lists to be shorter.
  printed_params = {k: v for k, v in eqn.params.items() if v is not None}
  (lhs_cont, rhs_cont), (lhs_batch, rhs_batch) = eqn.params['dimension_numbers']
  printed_params['dimension_numbers'] = (
      (list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch)))
  return core._pp_eqn(eqn.replace(params=printed_params), context, settings)

dot_general_p = standard_primitive(_dot_general_shape_rule,
                                   _dot_general_dtype_rule, 'dot_general')
ad.defbilinear(dot_general_p,
               _dot_general_transpose_lhs, _dot_general_transpose_rhs)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule

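# Usage sketch (illustrative, not executed here; assumes `import jax.numpy as
# jnp` and `from jax import lax`): a plain matrix multiply expressed through
# this primitive, contracting lhs dim 1 with rhs dim 0 and no batch dims:
#
#   a = jnp.ones((3, 4), jnp.float32)
#   b = jnp.ones((4, 5), jnp.float32)
#   c = lax.dot_general(a, b, (((1,), (0,)), ((), ())))
#   assert c.shape == (3, 5)
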
def precision_attr(precision: Precision) -> ir.ArrayAttr:
  if precision is None:
    full_precision = (Precision.DEFAULT, Precision.DEFAULT)
  elif not isinstance(precision, tuple):
    full_precision = (precision, precision)
  else:
    full_precision = precision
  return ir.ArrayAttr.get(
      [hlo.PrecisionAttr.get(str(p)) for p in full_precision])


def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                       precision, preferred_element_type: np.dtype | None,
                       platform: str = "default"):
  def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
    fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                  dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
    return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
  del preferred_element_type  # Implied by the output aval
  lhs_aval, rhs_aval = ctx.avals_in
  lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype
  aval_out, = ctx.avals_out
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

  # TODO(b/...): JAX's dot_general primitive accepts the same input dtype
  # combinations that are accepted in XLA's shape_inference.cc (the canonical
  # reference for the HLO type system), but actually different XLA platforms
  # fail on codegen for different accepted cases. To handle those cases, we
  # insert ConvertOps on the input, in a platform-dependent way.
  if lhs_dtype != rhs_dtype:
    if platform == "tpu":
      handled = lambda dt: (dtypes.issubdtype(dt, np.floating) or
                            dtypes.issubdtype(dt, np.integer))
      if not (handled(lhs_dtype) and handled(rhs_dtype)):
        lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
                               core.ShapedArray(lhs_aval.shape, aval_out.dtype))
        rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
                               core.ShapedArray(rhs_aval.shape, aval_out.dtype))
        lhs_dtype = rhs_dtype = aval_out.dtype
    else:  # cpu and gpu
      # Do not convert mixed fp8 types to output type.
      if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
        lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
                               core.ShapedArray(lhs_aval.shape, aval_out.dtype))
        rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
                               core.ShapedArray(rhs_aval.shape, aval_out.dtype))
        lhs_dtype = rhs_dtype = aval_out.dtype

  dot_dnums = hlo.DotDimensionNumbers.get(
      lhs_batching_dimensions=list(lhs_batch),
      rhs_batching_dimensions=list(rhs_batch),
      lhs_contracting_dimensions=list(lhs_contracting),
      rhs_contracting_dimensions=list(rhs_contracting))
  return [
      hlo.dot_general(
          mlir.aval_to_ir_type(aval_out),
          lhs,
          rhs,
          dot_dnums,
          precision_config=precision_attr(precision))
  ]

mlir.register_lowering(dot_general_p, _dot_general_lower)

for platform in ["cpu", "tpu"]:
  mlir.register_lowering(dot_general_p,
                         partial(_dot_general_lower, platform=platform),
                         platform=platform)

def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> Shape:
  m, k = lhs.shape
  group_count, rk, n = rhs.shape
  if k != rk:
    raise TypeError(f"ragged_dot requires that lhs.shape[1] == rhs.shape[1]: got {k} and {rk}.")
  num_groups = group_sizes.shape[0]
  if group_count != num_groups:
    raise TypeError(f"ragged_dot requires that rhs.shape[0] == group_sizes.shape[0]: got {group_count} and {num_groups}.")
  return (m, n)

# DotDimensionNumbers used in the dot_general call for ragged_dot().
_RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], []))

def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
  if not dtypes.issubdtype(group_sizes.dtype, np.integer):
    raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
  # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type)

def _ragged_dot_jvp_rule(
    primals, tangents, precision, preferred_element_type, group_offset
):
  # Note: we could ostensibly get this by passing the value on to ragged_dot
  # below, but this feels cleaner.
  if group_offset is not None:
    raise NotImplementedError('Unimplemented group_offset support.')
  x, y, gs = primals
  dx, dy, _ = tangents  # no tangent for the group sizes

  # primal
  primal_out = ragged_dot(
      x,
      y,
      gs,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )

  # tangent
  dx_out = (
      ragged_dot(
          dx,
          y,
          gs,
          precision=precision,
          preferred_element_type=preferred_element_type,
      )
      if type(dx) is not ad_util.Zero
      else _zeros(primal_out)
  )
  dy_out = (
      ragged_dot(
          x,
          dy,
          gs,
          precision=precision,
          preferred_element_type=preferred_element_type,
      )
      if type(dy) is not ad_util.Zero
      else _zeros(primal_out)
  )
  tangent_out = dx_out + dy_out

  return primal_out, tangent_out

def _ragged_to_dense(x, y, group_sizes):
  from jax._src.lax import control_flow  # avoid circular imports
  shape = (y.shape[0], x.shape[0], x.shape[1])
  x = broadcast_in_dim(x, shape, [1, 2])
  iota = broadcasted_iota(group_sizes.dtype, shape, 1)
  group_ends = control_flow.cumsum(group_sizes)
  group_starts = concatenate(
      [_zeros(group_sizes)[:1], group_ends[:-1]],
      dimension=0,
  )
  group_ends = broadcast_in_dim(group_ends, shape, (0,))
  group_starts = broadcast_in_dim(group_starts, shape, (0,))
  mask = bitwise_and(group_starts <= iota, iota < group_ends)
  x = select(mask, x, _zeros(x))
  return x


def _ragged_dot_transpose_rule(
    ct, *operands, precision, preferred_element_type, group_offset
):
  x, y, gs = operands
  if group_offset is not None:
    raise NotImplementedError('Unimplemented group_offset support.')

  if ad.is_undefined_primal(y):
    grad_x = None
  else:
    y_t = _matrix_transpose(y)
    grad_x = ragged_dot(
        ct,
        y_t,
        gs,
        precision=precision,
        preferred_element_type=preferred_element_type,
    )

  if ad.is_undefined_primal(x):
    grad_y = None
  else:
    y = y.aval if ad.is_undefined_primal(y) else y
    x_dense = _ragged_to_dense(x, y, group_sizes=gs)
    ct_dense = _ragged_to_dense(ct, y, group_sizes=gs)
    dimension_numbers = (([1], [1]), ([0], [0]))
    grad_y = dot_general(
        x_dense,
        ct_dense,
        dimension_numbers,
        precision=precision,
        preferred_element_type=preferred_element_type,
    )

  return grad_x, grad_y, None


ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
                                  _ragged_dot_dtype_rule, 'ragged_dot')
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule

def _ragged_dot_impl(
    lhs: Array,
    rhs: Array,
    group_sizes: Array,
    precision: PrecisionLike = None,
    preferred_element_type: DTypeLike | None = None,
    group_offset: Array | None = None,
) -> Array:
  if group_offset is not None:
    raise NotImplementedError("Unimplemented group_offset support.")
  lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes)
  return dot_general(
      lhs,
      rhs,
      dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
      precision=precision,
      preferred_element_type=preferred_element_type,
  )

mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False))

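# Usage sketch (illustrative, not executed here; assumes jnp/lax as above):
# ragged_dot multiplies consecutive row groups of lhs against the per-group
# matrices in rhs, with group_sizes giving the number of rows in each group:
#
#   lhs = jnp.ones((6, 4))
#   rhs = jnp.ones((3, 4, 5))                        # one (4, 5) matrix per group
#   group_sizes = jnp.array([2, 1, 3], dtype=jnp.int32)
#   out = lax.ragged_dot(lhs, rhs, group_sizes)      # out.shape == (6, 5)
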
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
  _check_shapelike('broadcast_in_dim', 'shape', shape)
  _check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
                   broadcast_dimensions)
  operand_ndim = np.ndim(operand)
  if operand_ndim != len(broadcast_dimensions):
    msg = ('broadcast_in_dim broadcast_dimensions must have length equal to '
           'operand ndim; got broadcast_dimensions {} for operand ndim {}.')
    raise TypeError(msg.format(broadcast_dimensions, operand_ndim))
  if len(shape) < operand_ndim:
    msg = ('broadcast_in_dim target broadcast shape must have equal or higher rank '
           'than the operand shape; got operand ndim {} and target broadcast ndim {}.')
    raise TypeError(msg.format(operand_ndim, len(shape)))
  if not set(broadcast_dimensions).issubset(set(range(len(shape)))):
    msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
           'dimensions, got {} for operand ndim {} and shape {}.')
    raise TypeError(msg.format(broadcast_dimensions, operand_ndim, shape))
  if not all(core.definitely_equal_one_of_dim(operand.shape[i],
                                              [1, shape[broadcast_dimensions[i]]])
             for i in range(operand_ndim)):
    msg = (
        "broadcast_in_dim operand dimension sizes must either be 1, or be "
        "equal to their corresponding dimensions in the target broadcast "
        "shape; got operand of shape {}, target broadcast shape {}, "
        "broadcast_dimensions {} ")
    raise TypeError(msg.format(
        tuple(core.replace_tracer_for_error_message(d) for d in operand.shape),
        shape, broadcast_dimensions))
  if (len(broadcast_dimensions) != len(set(broadcast_dimensions)) or
      tuple(broadcast_dimensions) != tuple(sorted(broadcast_dimensions))):
    msg = ("broadcast_in_dim broadcast_dimensions must be strictly increasing; "
           "got broadcast_dimensions {}")
    raise TypeError(msg.format(broadcast_dimensions))

  return shape

def _broadcast_in_dim_typecheck_rule(
    _, operand, *dyn_shape, shape, broadcast_dimensions):
  if not dyn_shape:
    out_aval, effects = broadcast_in_dim_p.abstract_eval(
        operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions)
    return [out_aval], effects
  else:
    # TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule
    out_shape = _merge_dyn_shape(shape, dyn_shape)
    out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]  # pytype: disable=attribute-error
    out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
                                 operand.aval.weak_type)
    return [out_aval], core.no_effects

def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
                                     shape, broadcast_dimensions):
  if type(ct) is ad_util.Zero:
    return [ad_util.Zero(operand.aval)]
  unit_dims = [i for i, s in enumerate(operand.aval.shape)
               if core.definitely_equal(s, 1)]
  bdims = tuple(np.delete(broadcast_dimensions, unit_dims))
  axes = tuple(np.delete(range(len(shape)), bdims))
  return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
          [None] * len(dyn_shape))

def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
                                 broadcast_dimensions):
  # `dyn_shape` is the dynamic portion of the target shape. `shape`
  # is the target shape, with `None` for dynamic sections.
  # broadcast_dimensions gives indices where dimensions of the input
  # have to go: dimension i of the input becomes dimension
  # broadcast_dimensions[i] of the output.
  operand, *dyn_shape = batched_args
  operand_bdim, *dyn_shape_bdims = batch_dims

  stacked_size = None
  if operand_bdim is not None:
    if isinstance(operand_bdim, RaggedAxis):
      stacked_axis = operand_bdim.stacked_axis
    else:
      stacked_axis = operand_bdim
    new_operand = batching.moveaxis(operand, stacked_axis, 0)
    if isinstance(operand_bdim, RaggedAxis):
      stacked_size = operand_bdim.size
    else:
      stacked_size = operand.shape[stacked_axis]
    new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
  else:
    new_operand = operand
    new_broadcast_dimensions = tuple(np.add(1, broadcast_dimensions))

  # TODO(mattjj,axch) This section assumes that the shape of the operand is
  # broadcast-compatible with the requested shape. We should tweak vmap to run
  # the abstract_eval rule so this can be checked while the raggedness
  # information is available.
  dyn_limits = []
  out_ragged_sizes = []
  for sizes, bdim in zip(dyn_shape, dyn_shape_bdims):
    if bdim is None:
      # TODO(mattjj,axch) Is this what bdim == None means?
      assert isinstance(sizes, int)
      bound = sizes
    else:
      bound = sizes.dtype.bound
      out_ragged_sizes.append(sizes)
      if stacked_size is None:
        stacked_size = len(sizes)
      else:
        msg = "All segment length arrays must be the same length"
        assert len(sizes) == stacked_size, msg
    dyn_limits.append(bound)
  new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits)
  result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions)
  out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
  out_bdim = batching.make_batch_axis(
      result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes))
  return result, out_bdim

def _broadcast_in_dim_fwd_rule(eqn):
  v, *dyn = eqn.invars
  if not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape):
    return [v], None
  else:
    return [None], eqn

def _broadcast_in_dim_staging_rule(
    trace, x, *dyn, shape, broadcast_dimensions):
  params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions)
  if not dyn:
    return trace.default_process_primitive(broadcast_in_dim_p, (x,), params)
  aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type)
  return _dyn_shape_staging_rule(trace, broadcast_in_dim_p, aval, x, *dyn,
                                 **params)

def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
                                   shape, broadcast_dimensions):
  del in_avals, dyn_shape
  out_aval, = out_avals
  new_shape = []
  new_dyn_shape = []
  for d in out_aval.shape:
    if type(d) is pe.BoundedAxisSize:
      new_shape.append(d.bound)
    elif type(d) is int:
      new_shape.append(d)
    else:
      assert isinstance(d, core.Tracer)
      new_shape.append(None)
      new_dyn_shape.append(d)
  return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape),
                                  broadcast_dimensions=broadcast_dimensions)]

def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions):
  operand, *dyn_shape = primals
  operand_dot, *_ = tangents
  y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape,
                              broadcast_dimensions=broadcast_dimensions)
  if type(operand_dot) is ad_util.Zero:
    y_dot = ad_util.Zero.from_primal_value(y)
  else:
    y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape,
                                    broadcast_dimensions=broadcast_dimensions)
  return y, y_dot

def _broadcast_in_dim_partial_eval(
    trace, operand, *dyn_shape, shape, broadcast_dimensions):
  if not dyn_shape:
    return trace.default_process_primitive(
        broadcast_in_dim_p, (operand, *dyn_shape),
        dict(shape=shape, broadcast_dimensions=broadcast_dimensions))
  assert all(t.pval.is_known() for t in dyn_shape)
  operand_tracer = trace.instantiate_const(operand)
  dyn_shape_tracers = map(trace.instantiate_const, dyn_shape)
  dyn_shape_tracers_ = iter(dyn_shape_tracers)
  shape_ = [next(dyn_shape_tracers_) if d is None else d for d in shape]
  out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type)
  out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
  eqn = pe.new_eqn_recipe(
      [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
      dict(shape=shape, broadcast_dimensions=broadcast_dimensions),
      core.no_effects, source_info_util.current())
  out_tracer.recipe = eqn
  return out_tracer

def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) -> Sequence[ir.Value]:
  aval_out, = ctx.avals_out
  if dyn_shape:
    aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))

  return [mlir.broadcast_in_dim(ctx, x, aval_out,
                                broadcast_dimensions=broadcast_dimensions)]

def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
  if (not dyn_shape and
      not any(isinstance(d, core.DArray) and
              type(core.get_aval(d).dtype) is core.bint for d in shape)):
    shape = _broadcast_in_dim_shape_rule(  # error checking
        x, shape=shape, broadcast_dimensions=broadcast_dimensions)
    return core.ShapedArray(shape, x.dtype, x.weak_type)
  # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
  # (even if x is a ShapedArray)
  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
  return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type)

broadcast_in_dim_p = standard_primitive(
    _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule
ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
pe.custom_partial_eval_rules[broadcast_in_dim_p] = _broadcast_in_dim_partial_eval
pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule
mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)

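# Usage sketch (illustrative, not executed here; assumes jnp/lax as above):
# broadcast_dimensions maps each operand axis to an output axis; the remaining
# output axes are broadcast.
#
#   x = jnp.arange(3.)                               # shape (3,)
#   y = lax.broadcast_in_dim(x, (2, 3, 4), (1,))     # operand axis 0 -> output axis 1
#   assert y.shape == (2, 3, 4)
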
def _clamp_shape_rule(min, operand, max):
  if min.shape and min.shape != operand.shape:
    raise TypeError("clamp requires min.shape == operand.shape or min.shape == "
                    f"(), got min.shape={min.shape}, {operand.shape=}.")
  if max.shape and max.shape != operand.shape:
    raise TypeError("clamp requires max.shape == operand.shape or max.shape == "
                    f"(), got max.shape={max.shape}, {operand.shape=}.")
  return operand.shape

_clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any],
                            'clamp')

def _clamp_batch_rule(batched_args, batch_dims, **params):
  min, x, max = batched_args
  min_bdim, x_bdim, max_bdim = batch_dims
  size = next(x.shape[i] for x, i in zip(batched_args, batch_dims)
              if i is not None)

  # avoid transposes and some broadcasts in special cases
  if min_bdim == x_bdim == max_bdim:
    if np.shape(min) == np.shape(x) == np.shape(max):
      return clamp_p.bind(min, x, max), x_bdim
    elif np.ndim(min) == np.ndim(max) == 0:
      return clamp_p.bind(min, x, max), x_bdim
    elif np.ndim(min) == np.ndim(max) == 1:
      min = broadcast_in_dim(min, x.shape, [min_bdim])
      max = broadcast_in_dim(max, x.shape, [max_bdim])
      return clamp_p.bind(min, x, max), x_bdim
  elif np.ndim(min) == 0 and np.ndim(max) == 0 and x_bdim is not None:
    return clamp_p.bind(min, x, max), x_bdim

  min = batching.bdim_at_front(min, min_bdim, size) if np.shape(min) else min
  max = batching.bdim_at_front(max, max_bdim, size) if np.shape(max) else max
  x = batching.bdim_at_front(x, x_bdim, size) if np.shape(x) else x
  if np.ndim(min) == 0 and np.ndim(x) > 0:
    min = broadcast(min, x.shape)
  if np.ndim(max) == 0 and np.ndim(x) > 0:
    max = broadcast(max, x.shape)
  if 0 < np.ndim(min) < np.ndim(x):
    assert np.ndim(min) == 1, np.ndim(min)
    min = broadcast_in_dim(min, x.shape, [0])
  if 0 < np.ndim(max) < np.ndim(x):
    assert np.ndim(max) == 1, np.ndim(max)
    max = broadcast_in_dim(max, x.shape, [0])
  if np.ndim(min) > np.ndim(x):
    assert np.ndim(x) == 0, np.ndim(x)
    x = broadcast(x, min.shape)
  return clamp_p.bind(min, x, max), 0

clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
ad.defjvp(clamp_p,
          lambda g, min, operand, max:
          select(bitwise_and(gt(min, operand), lt(min, max)),
                 g, _zeros(operand)),
          lambda g, min, operand, max:
          select(bitwise_and(gt(operand, min), lt(operand, max)),
                 g, _zeros(operand)),
          lambda g, min, operand, max:
          select(lt(max, operand), g, _zeros(operand)))
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
mlir.register_lowering(clamp_p, partial(_nary_lower_hlo, hlo.clamp))
pe.def_trivial_padding(clamp_p)

def _concatenate_shape_rule(*operands, **kwargs):
  dimension = kwargs.pop('dimension')
  if not operands:
    msg = "concatenate expects at least one operand, got 0."
    raise TypeError(msg)
  if not all(isinstance(operand, UnshapedArray) for operand in operands):
    msg = "All objects to concatenate must be arrays, got {}."
    op = next(op for op in operands if not isinstance(op, UnshapedArray))
    raise TypeError(msg.format(type(op)))
  if len({operand.ndim for operand in operands}) != 1:
    msg = "Cannot concatenate arrays with different numbers of dimensions: got {}."
    raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))
  if not 0 <= dimension < operands[0].ndim:
    msg = "concatenate dimension out of bounds: dimension {} for shapes {}."
    raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands])))
  shapes = [operand.shape[:dimension] + operand.shape[dimension+1:]
            for operand in operands]
  if not shapes[:-1] == shapes[1:]:
    msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
           "other than the one being concatenated: concatenating along "
           "dimension {} for shapes {}.")
    shapes = [operand.shape for operand in operands]
    raise TypeError(msg.format(dimension, ", ".join(map(str, shapes))))

  concat_size = sum(o.shape[dimension] for o in operands)
  ex_shape = operands[0].shape
  return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:]

def _concatenate_dtype_rule(*operands, **kwargs):
  check_same_dtypes('concatenate', *operands)
  return operands[0].dtype

def _concatenate_transpose_rule(t, *operands, dimension):
  operand_shapes = [o.aval.shape if ad.is_undefined_primal(o) else o.shape
                    for o in operands]
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
            for o in operands]
  else:
    limit_points = np.cumsum(
        [shape[dimension] for shape in operand_shapes]).tolist()
    starts = np.zeros((len(operands), t.ndim), dtype=int).tolist()
    limits = np.tile(t.shape, (len(operands), 1)).tolist()

    for i, s in enumerate(starts[1:]):
      s[dimension] = limit_points[:-1][i]
    for i, l in enumerate(limits):
      l[dimension] = limit_points[i]

    return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o)
            else None for o, start, limit in zip(operands, starts, limits)]

def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
  size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
              if bdim is not None)
  operands = [batching.moveaxis(op, bdim, 0) if bdim is not None
              else broadcast(op, (size,))
              for op, bdim in zip(batched_args, batch_dims)]
  return concatenate(operands, dimension + 1), 0

def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension):
  if all(isinstance(a.shape[dimension], (int, np.integer))
         for a in in_avals):
    return [concatenate(operands, dimension)]
  else:
    raise NotImplementedError  # TODO(mattjj)

concatenate_p = standard_primitive(
    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate')
ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
pe.padding_rules[concatenate_p] = _concatenate_pad_rule

def _concatenate_lower(ctx, *xs, dimension):
  return [hlo.concatenate(xs, mlir.i64_attr(dimension))]
mlir.register_lowering(concatenate_p, _concatenate_lower)

def _pad_dtype_rule(operand, padding_value, *, padding_config):
  if operand.dtype != padding_value.dtype:
    msg = "pad operand and padding_value must be same dtype: got {} and {}."
    raise TypeError(msg.format(operand.dtype, padding_value.dtype))

  return _input_dtype(operand, padding_value)

def _pad_shape_rule(operand, padding_value, *, padding_config):
  del padding_value
  op_shape = np.shape(operand)
  if not len(padding_config) == np.ndim(operand):
    raise ValueError("length of padding_config must equal the number of axes "
                     f"of operand, got padding_config {padding_config} "
                     f"for operand shape {op_shape}")
  if not all(i >= 0 for _, _, i in padding_config):
    raise ValueError("interior padding in padding_config must be nonnegative, "
                     f"got padding_config {padding_config}")
  result = tuple(l + h + core.dilate_dim(d, i + 1)
                 for (l, h, i), d in zip(padding_config, op_shape))
  if not all(d >= 0 for d in result):
    msg = (f"Dimension size after padding is not at least 0, "
           f"got result shape {result}, for padding_config {padding_config}"
           f" and operand shape {op_shape}")
    raise ValueError(msg)
  return result

def _pad_transpose(t, operand, padding_value, *, padding_config):
  if type(t) is ad_util.Zero:
    t_operand = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    t_padv = ad_util.Zero(padding_value.aval) if ad.is_undefined_primal(padding_value) else None
  else:
    lo, hi, interior = util.unzip3(padding_config)
    total = lambda x: _reduce_sum(x, list(range(t.ndim)))

    def t_op():
      unpad_config = safe_zip(np.negative(lo), np.negative(hi),
                              np.zeros_like(interior))
      unpadded = pad(t, np.array(0., t.dtype), unpad_config)
      return slicing.slice(unpadded, np.zeros_like(lo), unpadded.shape,
                           np.add(interior, 1))

    t_operand = t_op() if ad.is_undefined_primal(operand) else None
    t_padv = sub(total(t), total(t_operand)) if ad.is_undefined_primal(padding_value) else None
  return [t_operand, t_padv]

def _pad_batch_rule(batched_args, batch_dims, *, padding_config):
  operand, padding_value = batched_args
  operand_bdim, padding_value_bdim = batch_dims
  if operand_bdim is None:
    operand_bdim = 0
    operand = broadcast(operand, (padding_value.shape[padding_value_bdim],))

  padding_config = list(padding_config)
  padding_config.insert(operand_bdim, (0, 0, 0))
  if padding_value_bdim is None:
    return pad(operand, padding_value, padding_config), operand_bdim

  assert padding_value_bdim == 0, padding_value_bdim

  x = pad(operand, _zero(operand), padding_config)
  mask = pad(full_like(operand, True, np.bool_), False, padding_config)
  broadcasted_padding = broadcast_in_dim(padding_value, x.shape,
                                         (operand_bdim,))
  return select(mask, x, broadcasted_padding), operand_bdim

pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad')
ad.deflinear2(pad_p, _pad_transpose)
batching.primitive_batchers[pad_p] = _pad_batch_rule

def _pad_lower(ctx, x, padding_value, *, padding_config):
  aval_out, = ctx.avals_out
  low, high, interior = util.unzip3(padding_config)
  return [mlir.pad(ctx, aval_out, x, padding_value, low, high, interior)]

mlir.register_lowering(pad_p, _pad_lower)

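# Padding-config sketch (illustrative, not executed here): each axis gets a
# (low, high, interior) triple, so an axis of size d grows to
# low + high + d + (d - 1) * interior.
#
#   x = jnp.array([1., 2., 3.])
#   lax.pad(x, 0., [(1, 2, 1)])
#   # -> [0., 1., 0., 2., 0., 3., 0., 0.]            # shape (8,) = 1 + 2 + 3 + 2*1
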
# The squeeze primitive exists for the benefit of masking and other
# transformations that need to keep track of axis identity.
# For example, consider reshaping a 2D array with shape (1, N) into a 1D array
# with shape (N,). This results in the following JAXpr:
#   reshape[ dimension=None new_sizes=(N,) ]
# For N > 1, we can match up the output array axis with the second axis of the
# input. But for N = 1, it is not clear how axes match up: all we know from the
# JAXpr is that we are reshaping from (1, 1) to (1,).
# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.

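# Usage sketch (illustrative, not executed here): squeeze names the axes being
# dropped, so axis identity stays unambiguous even when several axes have
# size 1:
#
#   x = jnp.ones((1, 5, 1))
#   lax.squeeze(x, dimensions=(0, 2)).shape          # -> (5,)
#   lax.squeeze(x, dimensions=(0,)).shape            # -> (5, 1)
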
def _squeeze_dtype_rule(operand, *, dimensions):
  return operand.dtype

def _squeeze_shape_rule(operand, *, dimensions):
  return _compute_squeeze_shape(np.shape(operand), dimensions)

def _compute_squeeze_shape(shape, dimensions):
  dims_set = set(dimensions)
  if len(dims_set) != len(dimensions):
    raise ValueError(f"dimensions are not unique: {dimensions}")
  if not all(0 <= d < len(shape) for d in dims_set):
    raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")
  if any(not core.definitely_equal(shape[d], 1) for d in dimensions):
    raise ValueError(
        "cannot select an axis to squeeze out which has size not equal to "
        f"one, got {shape=} and {dimensions=}")
  return tuple(s for i, s in enumerate(shape) if i not in dims_set)

def _squeeze_transpose_rule(t, operand, *, dimensions):
  assert ad.is_undefined_primal(operand)
  return [expand_dims(t, dimensions)]

def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
  operand, = batched_args
  bdim, = batch_dims
  operand, bdim = batching.move_stacked_axis(operand, bdim, 0)
  dimensions = tuple(np.add(1, dimensions))
  out_stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim
  bdim_out = batching.shape_as_bdim(
      out_stack_dim,
      _compute_squeeze_shape(batching.bdim_as_shape(bdim, operand.shape), dimensions))
  return squeeze(operand, dimensions=dimensions), bdim_out

squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
                               'squeeze')
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
pe.def_trivial_padding(squeeze_p)

def _squeeze_lower(ctx, operand, *, dimensions):
  del dimensions  # Implied by the output aval.
  return [mlir.reshape(ctx, operand, ctx.avals_out[0])]

mlir.register_lowering(squeeze_p, _squeeze_lower)

def shape_as_value(shape: core.Shape):
  """Converts a shape that may contain Poly values into a JAX value."""
  if len(shape) == 0:
    return full((0,), np.array(0, np.int64))
  dims = [
      expand_dims(convert_element_type(core.dimension_as_value(d), np.int64),
                  (0,))
      for d in shape
  ]
  return concatenate(dims, dimension=0)

def _reshape_shape_rule(operand, *, new_sizes, dimensions):
  if not all(d >= 0 for d in new_sizes):
    msg = 'reshape new_sizes must all be nonnegative, got {}.'
    raise TypeError(msg.format(new_sizes))
  # TODO(necula): re-enable this check
  operand_size = math.prod(np.shape(operand))
  new_size = math.prod(new_sizes)
  if (not config.dynamic_shapes.value and
      not operand_size == new_size):
    msg = (f"reshape total size must be unchanged, got new_sizes {new_sizes} "
           f"(of total size {new_size}) for shape {np.shape(operand)} "
           f"(of total size {operand_size}).")
    raise TypeError(msg)
  if dimensions is not None:
    if set(dimensions) != set(range(np.ndim(operand))):
      msg = ('reshape dimensions must be a permutation of operand dimensions, '
             'got dimensions {} for shape {}.')
      raise TypeError(msg.format(dimensions, np.shape(operand)))
  return tuple(new_sizes)

def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions):
  if not dyn_shape:
    out_aval, effects = reshape_p.abstract_eval(
        operand.aval, new_sizes=new_sizes, dimensions=dimensions)
    return [out_aval], effects
  else:
    # TODO(mattjj, necula): perform more checks like _reshape_shape_rule
    out_shape = _merge_dyn_shape(new_sizes, dyn_shape)
    out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]  # pytype: disable=attribute-error
    out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
                                 operand.aval.weak_type)
    return [out_aval], core.no_effects


def _reshape_dtype_rule(operand, *, new_sizes, dimensions):
  return operand.dtype

def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions):
  assert ad.is_undefined_primal(operand)
  if dimensions is None:
    return [reshape(t, operand.aval.shape)]
  else:
    return [transpose(reshape(t, np.take(operand.aval.shape, dimensions)),
                      np.argsort(dimensions))]

def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
  operand, = batched_args
  bdim, = batch_dims
  operand = batching.moveaxis(operand, bdim, 0)
  if dimensions is not None:
    dimensions = (0,) + tuple(np.add(1, dimensions))
  return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0


def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
  aval_out, = ctx.avals_out
  if dimensions is not None:
    x = hlo.transpose(x, mlir.dense_int_array(dimensions))
  if dyn_shape:
    aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
  return [mlir.reshape(ctx, x, aval_out)]

def _reshape_staging_rule(
    trace, x, *dyn, new_sizes, dimensions):
  params = dict(new_sizes=new_sizes, dimensions=dimensions)
  if not dyn:
    return trace.default_process_primitive(reshape_p, (x,), params)
  av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type)
  return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params)

reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule,
                               'reshape')
ad.deflinear2(reshape_p, _reshape_transpose_rule)
batching.primitive_batchers[reshape_p] = _reshape_batch_rule
mlir.register_lowering(reshape_p, _reshape_lower)
core.custom_typechecks[reshape_p] = _reshape_typecheck_rule
pe.custom_staging_rules[reshape_p] = _reshape_staging_rule

def _rev_shape_rule(operand, *, dimensions):
  _check_shapelike('rev', 'dimensions', dimensions)
  if len(set(dimensions)) != len(dimensions):
    msg = 'rev dimensions must be unique, got {}.'
    raise TypeError(msg.format(dimensions))
  if dimensions and not _max(dimensions) < operand.ndim:
    msg = ('rev dimensions must all be less than operand ndim, got dimensions '
           '{} for operand ndim {}.')
    raise TypeError(msg.format(dimensions, operand.ndim))
  return operand.shape

def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
  operand, = batched_args
  bdim, = batch_dims
  new_dimensions = [i + 1 if i >= bdim else i for i in dimensions]
  return rev(operand, new_dimensions), bdim

rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule

def _rev_lower(ctx, x, *, dimensions):
  return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
mlir.register_lowering(rev_p, _rev_lower)


def _transpose_shape_rule(operand, *, permutation):
  if not isinstance(permutation, (tuple, list, np.ndarray)):
    msg = "transpose permutation must be a tuple/list/ndarray, got {}."
    raise TypeError(msg.format(type(permutation)))
  if tuple(sorted(permutation)) != tuple(range(operand.ndim)):
    msg = ("transpose permutation isn't a permutation of operand dimensions, "
           "got permutation {} for operand shape {}.")
    raise TypeError(msg.format(permutation, operand.shape))
  return tuple(operand.shape[old_idx] for old_idx in permutation)

def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
  operand, = batched_args
  bdim, = batch_dims
  stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim
  perm = (stack_dim,) + tuple(i if i < stack_dim else i+1 for i in permutation)
  if isinstance(bdim, RaggedAxis):
    res_bdim = batching.transpose_ragged_axes(bdim.move_stacked_axis(0), perm)
  else:
    res_bdim = 0
  return transpose(operand, perm), res_bdim

def _transpose_lower(ctx, x, *, permutation):
  aval_out, = ctx.avals_out
  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    elt_shape = aval_out.dtype._rules.physical_element_aval(
        aval_out.dtype).shape
    trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
    permutation = [*permutation, *trailing_dims]
  return [hlo.transpose(x, mlir.dense_int_array(permutation))]

transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
                                 'transpose')
ad.deflinear2(transpose_p,
              lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
mlir.register_lowering(transpose_p, _transpose_lower)
pe.def_trivial_padding(transpose_p)

def _select_shape_rule(which, *cases):
  if len(cases) == 0:
    raise TypeError("select must have at least one case")
  if any(case.shape != cases[0].shape for case in cases[1:]):
    msg = "select cases must have the same shapes, got [{}]."
    raise TypeError(msg.format(", ".join([str(c.shape) for c in cases])))
  if which.shape and which.shape != cases[0].shape:
    msg = ("select `which` must be scalar or have the same shape as cases, "
           "got `which` shape {} but case shape {}.")
    raise TypeError(msg.format(which.shape, cases[0].shape))
  return cases[0].shape

def _select_dtype_rule(which, *cases):
  check_same_dtypes("select", *cases)
  if (not dtypes.issubdtype(which.dtype, np.bool_) and
      not dtypes.issubdtype(which.dtype, np.integer)):
    raise TypeError("select `which` must be boolean or integer type, got "
                    f"{which.dtype}.")
  if dtypes.issubdtype(which.dtype, np.bool_) and len(cases) > 2:
    raise TypeError("select with boolean `which` cannot have > 2 cases.")
  return cases[0].dtype

def _select_weak_type_rule(which, *cases):
  return all(c.weak_type for c in cases)

def _select_transpose_rule(t, which, *cases):
  assert not ad.is_undefined_primal(which)
  if type(t) is ad_util.Zero:
    return [None] + [ad_util.Zero(c.aval) if ad.is_undefined_primal(c) else None
                     for c in cases]
  else:
    zeros = full_like(t, 0)
    if dtypes.dtype(which) == np.dtype(np.bool_):
      ct0 = select(which, zeros, t) if ad.is_undefined_primal(cases[0]) else None
      ct1 = select(which, t, zeros) if ad.is_undefined_primal(cases[1]) else None
      return (None, ct0, ct1)
    else:
      return [None] + [
          select(eq(which, _const(which, i)), t, zeros)
          if ad.is_undefined_primal(case) else None for i, case in enumerate(cases)
      ]

def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
  which, *cases = batched_args
  which_bdim, *case_bdims = batch_dims
  size = next(x.shape[i] for x, i in zip(batched_args, batch_dims)
              if i is not None)

  # avoid transposes and some broadcasts in special cases
  if all(which_bdim == bdim for bdim in case_bdims):
    if np.shape(which) == np.shape(cases[0]):
      return select_n(which, *cases), which_bdim
    else:
      # vmapped function had a scalar which with nonscalar args
      assert np.ndim(which) == 1
      which = broadcast_in_dim(which, cases[0].shape, [which_bdim])
      return select_n(which, *cases), which_bdim
  elif np.ndim(which) == 0 and all(bdim is not None for bdim in case_bdims):
    if all(case_bdims[0] == bdim for bdim in case_bdims[1:]):
      return select_n(which, *cases), case_bdims[0]
    elif all(np.shape(cases[0]) == np.shape(c) for c in cases):
      bdim = case_bdims[0]
      other_cases = [batching.moveaxis(c, c_bdim, bdim)
                     for c, c_bdim in zip(cases[1:], case_bdims[1:])]
      return select_n(which, cases[0], *other_cases), bdim

  which = (batching.bdim_at_front(which, which_bdim, size) if np.shape(which)
           else which)
  if not all(() == np.shape(c) for c in cases):
    cases = [batching.bdim_at_front(c, bdim, size)
             for c, bdim in zip(cases, case_bdims)]
  assert all(np.shape(cases[0]) == np.shape(c) for c in cases[1:])
  if 0 < np.ndim(which) < np.ndim(cases[0]):
    # vmapped function had a scalar which with nonscalar args
    assert np.ndim(which) == 1
    which = broadcast_in_dim(which, cases[0].shape, [0])
  if np.ndim(which) > np.ndim(cases[0]):
    assert np.ndim(cases[0]) == 0
    cases = [broadcast(c, which.shape) for c in cases]
  return select_n(which, *cases), 0

def _select_jvp(primals, tangents):
  which, *case_primals = primals
  case_tangents = tangents[1:]
  out = select_n(which, *case_primals)
  if all(type(t) is ad_util.Zero for t in case_tangents):
    out_dot = ad_util.Zero(case_tangents[0].aval)
  else:
    z = _zeros(next(t for t in case_tangents if type(t) is not ad_util.Zero))
    case_tangents = [z if type(t) is ad_util.Zero else t for t in case_tangents]
    out_dot = select_n(which, *case_tangents)
  return out, out_dot

def _select_hlo_lowering_opaque(ctx, which, *cases):
  avals_in = ctx.avals_in
  aval_out, = ctx.avals_out
  assert all(aval_case == aval_out for aval_case in avals_in[1:])
  select_lower = _select_hlo_lowering

  physical_aval_out = core.physical_aval(aval_out)
  physical_avals_cases = [physical_aval_out] * (len(avals_in) - 1)
  aval_which = avals_in[0]
  aval_which_bcast = physical_aval_out.update(dtype=aval_which.dtype)
  assert aval_which_bcast.shape[:aval_which.ndim] == aval_which.shape

  bcast_dims = list(range(aval_which.ndim))
  which_bcast = mlir.broadcast_in_dim(
      ctx, which, aval_which_bcast, broadcast_dimensions=bcast_dims)

  return mlir.delegate_lowering(
      ctx, select_lower, which_bcast, *cases,
      avals_in=[aval_which_bcast, *physical_avals_cases],
      avals_out=[physical_aval_out])[0]


def _select_hlo_lowering(ctx, which, *cases):
  which_aval = ctx.avals_in[0]
  aval_out, = ctx.avals_out

  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    return [_select_hlo_lowering_opaque(ctx, which, *cases)]

  if which_aval.dtype == np.dtype(np.bool_):
    assert len(cases) <= 2
    if len(cases) == 1: return cases
    return [hlo.select(which, cases[1], cases[0])]

  if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
    compare_type = 'SIGNED'
  else:
    compare_type = 'UNSIGNED'
  lt = 'LT'

  def _select(offset, cases):
    assert len(cases) > 0
    if len(cases) == 1:
      return cases[0]
    mid = len(cases) // 2
    pred = mlir.compare_hlo(which,
                            mlir.full_like_aval(ctx, offset + mid, which_aval),
                            lt, compare_type)
    return hlo.select(pred, _select(offset, cases[:mid]),
                      _select(offset + mid, cases[mid:]))

  return [_select(0, cases)]

select_n_p = standard_primitive(
    _select_shape_rule, _select_dtype_rule, 'select_n',
    weak_type_rule=_select_weak_type_rule)
ad.primitive_jvps[select_n_p] = _select_jvp
ad.primitive_transposes[select_n_p] = _select_transpose_rule
batching.primitive_batchers[select_n_p] = _select_batch_rule
mlir.register_lowering(select_n_p, _select_hlo_lowering)
pe.def_trivial_padding(select_n_p)

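# Usage sketch (illustrative, not executed here): select_n picks elementwise
# among N cases, using a boolean selector for N == 2 or an integer selector
# (values in [0, N)) for any N:
#
#   which = jnp.array([0, 2, 1], dtype=jnp.int32)
#   a, b, c = jnp.zeros(3), jnp.ones(3), jnp.full(3, 2.)
#   lax.select_n(which, a, b, c)                     # -> [0., 2., 1.]
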
def _reduce_shape_rule(*avals, computation, jaxpr, dimensions):
|
|
operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
|
|
if any(arg.shape != () for arg in init_val_avals):
|
|
init_val_shapes = [a.shape for a in init_val_avals]
|
|
raise ValueError(f'reduce found non-scalar initial value: {init_val_shapes}')
|
|
return [tuple(np.delete(op.shape, dimensions)) for op in operand_avals]
|
|
|
|
def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions):
|
|
operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
|
|
operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals]
|
|
init_val_dtypes = [dtypes.canonicalize_dtype(init.dtype) for init in init_val_avals]
|
|
if operand_dtypes != init_val_dtypes:
|
|
raise TypeError(
|
|
"reduce operand dtypes should match corresponding initial value dtypes, "
|
|
f"got operands={operand_avals} and initial_values={init_val_avals}")
|
|
return operand_dtypes
|
|
|
|
def _reduce_weak_type_rule(*avals, computation, jaxpr, dimensions):
|
|
operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
|
|
return [op.weak_type and init_val.weak_type
|
|
for op, init_val in safe_zip(operand_avals, init_val_avals)]
|
|
|
|
def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
|
|
dimensions):
|
|
# TODO(mattjj,frostig): use batch_jaxpr, delete computation (assumes poly??)
|
|
num_operands = len(batched_args) // 2
|
|
operands, init_values = split_list(batched_args, [num_operands])
|
|
operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands])
|
|
if all(init_value_bdim is batching.not_mapped
|
|
for init_value_bdim in init_value_bdims):
|
|
size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims)
|
|
if ax is not None)
|
|
operands = [batching.bdim_at_front(arg, bdim, size)
|
|
for arg, bdim in zip(operands, operand_bdims)]
|
|
new_dimensions = [d + 1 for d in dimensions]
|
|
new_operand_bdims = [0] * num_operands
|
|
return reduce_p.bind(*(operands + init_values),
|
|
computation=computation,
|
|
dimensions=tuple(new_dimensions),
|
|
jaxpr=jaxpr), new_operand_bdims
|
|
else:
|
|
raise NotImplementedError # loop and stack
|
|
|
|
def _reduce_jvp(reducer, init_values, primals, tangents, axes):
|
|
input_shape = np.array(primals[0].shape, dtype=int)
|
|
|
|
n = np.prod(input_shape[list(axes)])
|
|
non_axes = np.delete(np.arange(len(input_shape)), axes)
|
|
|
|
# Move the reduced axes to the front, and flatten them to 1D.
|
|
permutation = axes + tuple(non_axes)
|
|
new_shape = (n,) + tuple(input_shape[non_axes])
|
|
primals = tuple(reshape(x, new_shape, permutation) for x in primals)
|
|
tangents = tuple(reshape(t, new_shape, permutation) for t in tangents)
|
|
|
|
for d in range(len(non_axes) + 1):
|
|
reducer = api.vmap(reducer)
|
|
def _reduce_tree(*xs, axis=0):
|
|
"""Reduce by repeatedly splitting the array and multiplying."""
|
|
while xs[0].shape[axis] > 1:
|
|
n = xs[0].shape[axis]
|
|
n1 = (n + 1) // 2
|
|
n2 = n - n1
|
|
xs1 = [slicing.slice_in_dim(x, 0, n1) for x in xs]
|
|
xs2 = [slicing.slice_in_dim(x, n1, None) for x in xs]
|
|
if n2 != n1:
|
|
paddings = [(0, 0, 0)] * len(xs[0].shape)
|
|
paddings[axis] = (0, 1, 0)
|
|
xs2 = [pad(x2, i, paddings) for x2, i in zip(xs2, init_values)]
|
|
xs = reducer(*(xs1 + xs2))
|
|
if xs[0].shape[axis] == 0:
|
|
return [full(input_shape[non_axes], i) for i in init_values]
|
|
return tuple(squeeze(x, (axis,)) for x in xs)
|
|
|
|
return api.jvp(_reduce_tree, primals, tangents)
|
|
|
|
def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions):
|
|
primal_xs, init_values = split_list(primals, [len(primals) // 2])
|
|
tangent_xs, tangent_init = split_list(tangents, [len(tangents) // 2])
|
|
# This test may be too strict, if a value is actually zero but we cannot prove
|
|
# it is symbolically zero.
|
|
if any(type(t) is not ad_util.Zero for t in tangent_init):
|
|
raise NotImplementedError(
|
|
"Gradient of general lax.reduce with non-zero tangents for "
|
|
"initial values to reduction not implemented")
|
|
reducer = core.jaxpr_as_fun(jaxpr)
|
|
return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions)
|
|
|
|
reduce_p = core.Primitive('reduce')
|
|
reduce_p.multiple_results = True
|
|
reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
|
|
reduce_p.def_abstract_eval(
|
|
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
|
|
_reduce_dtype_rule, _reduce_weak_type_rule))
|
|
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
|
|
ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
|
|
|
|
def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
|
|
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
|
|
operands, init_values = util.split_list(values, [len(values) // 2])
|
|
init_value_avals = ctx.avals_in[len(values) // 2:]
|
|
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
|
operands, init_values, mlir.dense_int_array(dimensions))
|
|
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
|
|
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
|
|
with ir.InsertionPoint(reducer):
|
|
name_stack = source_info_util.new_name_stack()
|
|
if jaxpr.effects:
|
|
raise NotImplementedError('Cannot lower effectful `reduce`.')
|
|
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
|
|
name_stack, mlir.TokenSet(),
|
|
jaxpr.consts,
|
|
*reducer.arguments,
|
|
dim_var_values=ctx.dim_var_values)
|
|
hlo.return_(mlir.flatten_ir_values(out_nodes))
|
|
return op.results
|
|
|
|
mlir.register_lowering(reduce_p, _reduce_lower)
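# Usage sketch for the variadic reduce primitive registered above (via the
# public lax.reduce wrapper, assuming jax.numpy is imported as jnp):
#
#   >>> from jax import lax
#   >>> lax.reduce(jnp.arange(6.).reshape(2, 3), 0., lax.add, (0,))
#   Array([3., 5., 7.], dtype=float32)
#
# Passing tuples of operands and init values with a pair-valued computation is
# how _compute_argminmax below reduces (value, index) pairs in one pass.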
def _reduce_number_dtype_rule(name, operand, *args, **kw):
|
|
if not dtypes.issubdtype(operand.dtype, np.number):
|
|
raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
|
|
"of number.".format(name, dtype_to_string(operand.dtype)))
|
|
return dtypes.canonicalize_dtype(operand.dtype)
|
|
|
|
def _reduce_sum_shape_rule(operand, *, axes):
|
|
return _reduce_op_shape_rule(operand, axes=axes)
|
|
|
|
def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
|
|
assert ad.is_undefined_primal(operand)
|
|
input_shape = operand.aval.shape
|
|
broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes))
|
|
result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions)
|
|
assert result.shape == input_shape
|
|
return [result]
|
|
|
|
def _reducer_padding(traceable, ident, in_avals, out_avals, operand, *, axes):
|
|
del out_avals
|
|
aval, = in_avals
|
|
padded_axes = [(i, d.val) for i, d in enumerate(aval.shape)
|
|
if isinstance(d, pe.BoundedAxisSize)]
|
|
operand_ = _replace_masked_values(operand, ident(aval.dtype), padded_axes)
|
|
return [traceable(operand_, axes)]
|
|
|
|
def _replace_masked_values(x, val, padded_axes):
|
|
if not padded_axes: return x
|
|
dtype = dtypes._scalar_type_to_dtype(int)
|
|
masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
|
|
return select(_reduce(operator.and_, masks), x, full_like(x, val))
|
|
|
|
|
|
reduce_sum_p = standard_primitive(
|
|
_reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
|
|
'reduce_sum')
|
|
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
|
|
batching.defreducer(reduce_sum_p, _get_sum_identity)
|
|
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
|
|
_get_sum_identity)
|
|
|
|
|
|
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
|
|
del input_shape # Unused.
|
|
if len(axes) != len(set(axes)):
|
|
raise ValueError(f"duplicate value in 'axes' of reduction: {axes}")
|
|
if not all(0 <= a < operand.ndim for a in axes):
|
|
raise ValueError(f"reduction axes {axes} contains out-of-bounds indices for {operand}.")
|
|
axes = frozenset(axes)
|
|
return tuple(d for i, d in enumerate(operand.shape) if i not in axes)
|
|
|
|
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
|
|
reducer = lambda x, y: [mul(x, y)]
|
|
primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
|
|
primals, tangents, axes)
|
|
return primals_out[0], tangents_out[0]
|
|
|
|
reduce_prod_p = standard_primitive(
|
|
_reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
|
|
'reduce_prod')
|
|
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
|
|
batching.defreducer(reduce_prod_p, _get_prod_identity)
|
|
pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
|
|
_get_prod_identity)
|
|
|
|
|
|
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
|
|
# TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
|
|
# locations in a single pass (rather than comparing equality) and use a
|
|
# gather, and/or even push along the chosen elements of g (b/112040122)
|
|
shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
|
|
location_indicators = convert_element_type(
|
|
_eq_meet(operand, reshape(ans, shape)), g.dtype)
|
|
counts = _reduce_sum(location_indicators, axes)
|
|
return div(_reduce_sum(mul(g, location_indicators), axes), counts)
|
|
|
|
|
|
reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
|
|
'reduce_max')
|
|
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
|
|
batching.defreducer(reduce_max_p, _get_max_identity)
|
|
pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
|
|
_get_max_identity)
|
|
|
|
|
|
reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
|
|
'reduce_min')
|
|
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
|
|
batching.defreducer(reduce_min_p, _get_min_identity)
|
|
pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min,
|
|
_get_min_identity)
|
|
|
|
|
|
def _argminmax_shape_rule(operand, *, axes, index_dtype):
|
|
axis, = axes
|
|
if not (0 <= axis < len(operand.shape)):
|
|
raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}")
|
|
if operand.shape[axis] < 1:
|
|
raise ValueError("argmin and argmax require non-empty reduced dimension. "
|
|
f"operand.shape={operand.shape} {axis=}")
|
|
return tuple(np.delete(operand.shape, axis))
|
|
|
|
def _argminmax_dtype_rule(operand, *, axes, index_dtype):
|
|
if not dtypes.issubdtype(index_dtype, np.integer):
|
|
raise TypeError("index_dtype must be an integer type, but got {}"
|
|
.format(dtype_to_string(index_dtype)))
|
|
return index_dtype
|
|
|
|
class _ArgMinMaxReducer:
|
|
|
|
def __init__(self, value_comparator):
|
|
self._value_comparator = value_comparator
|
|
|
|
def __repr__(self):
|
|
# Override the repr so that the metadata attached to the lowered op does not
|
|
# contain unstable function ids. This plays more nicely with computation
|
|
# fingerprint calculation in the compilation cache.
|
|
return f'_ArgMinMaxReducer({self._value_comparator.__name__})'
|
|
|
|
def __call__(self, op_val_index, acc_val_index):
|
|
op_val, op_index = op_val_index
|
|
acc_val, acc_index = acc_val_index
|
|
# Pick op_val if Lt (for argmin) or if NaN
|
|
pick_op_val = bitwise_or(self._value_comparator(op_val, acc_val),
|
|
ne(op_val, op_val))
|
|
# If x and y are not NaN and x = y, then pick the first
|
|
pick_op_index = bitwise_or(pick_op_val,
|
|
bitwise_and(eq(op_val, acc_val),
|
|
lt(op_index, acc_index)))
|
|
return (select(pick_op_val, op_val, acc_val),
|
|
select(pick_op_index, op_index, acc_index))
|
|
|
|
def _compute_argminmax(value_comparator, get_identity,
|
|
operand, *, index_dtype, axes):
|
|
# value_comparator is either lax.lt (for argmin) or lax.gt
|
|
# get_identity(operand.dtype) is inf for argmin or -inf for argmax
|
|
axis, = axes
|
|
indices = broadcasted_iota(index_dtype, np.shape(operand), axis)
|
|
res = reduce([operand, indices],
|
|
[get_identity(operand.dtype), np.array(0, index_dtype)],
|
|
_ArgMinMaxReducer(value_comparator),
|
|
axes)
|
|
return res[1]
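# Concrete sketch of the reduction above (not executed here): for
# operand = [3., 1., 1.] along axis 0 the reduce scans (value, index) pairs,
# keeping the smaller value and breaking ties with the smaller index, so
# lax.argmin of that operand with index_dtype=np.int32 returns 1, the first
# position of the minimum.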
argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
|
|
'argmin', weak_type_rule=_strip_weak_type)
|
|
batching.defreducer(argmin_p, _get_min_identity)
|
|
ad.defjvp_zero(argmin_p)
|
|
|
|
argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
|
|
'argmax', weak_type_rule=_strip_weak_type)
|
|
batching.defreducer(argmax_p, _get_max_identity)
|
|
ad.defjvp_zero(argmax_p)
|
|
|
|
mlir.register_lowering(argmin_p, mlir.cache_lowering(mlir.lower_fun(
|
|
partial(_compute_argminmax, lt, _get_min_identity),
|
|
multiple_results=False)))
|
|
|
|
mlir.register_lowering(argmax_p, mlir.cache_lowering(mlir.lower_fun(
|
|
partial(_compute_argminmax, gt, _get_max_identity),
|
|
multiple_results=False)))
|
|
|
|
|
|
def _reduce_logical_shape_rule(operand, *, axes):
|
|
if operand.dtype != np.bool_ and not np.issubdtype(operand.dtype, np.integer):
|
|
raise TypeError(f"logical reduction requires operand dtype bool or int, got {operand.dtype}.")
|
|
return tuple(np.delete(operand.shape, axes))
|
|
|
|
reduce_or_p = standard_primitive(
|
|
_reduce_logical_shape_rule, _input_dtype, 'reduce_or',
|
|
weak_type_rule=_strip_weak_type)
|
|
batching.defreducer(reduce_or_p, _get_bitwise_or_identity)
|
|
|
|
|
|
reduce_and_p = standard_primitive(
|
|
_reduce_logical_shape_rule, _input_dtype, 'reduce_and',
|
|
weak_type_rule=_strip_weak_type)
|
|
batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
|
|
|
|
|
|
reduce_xor_p = standard_primitive(
|
|
_reduce_logical_shape_rule, _input_dtype, 'reduce_xor',
|
|
weak_type_rule=_strip_weak_type)
|
|
batching.defreducer(reduce_xor_p, _get_bitwise_or_identity)
|
|
|
|
|
|
def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
|
|
aval_out, = ctx.avals_out
|
|
dtype = aval_out.dtype
|
|
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
|
|
[mlir.ir_constant(unit_factory(aval_out.dtype))],
|
|
mlir.dense_int_array(axes))
|
|
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
|
|
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
|
|
with ir.InsertionPoint(reducer_region):
|
|
hlo.return_([reducer(*reducer_region.arguments)])
|
|
return op.results
|
|
|
|
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
|
|
_get_sum_identity))
|
|
mlir.register_lowering(reduce_prod_p, partial(_unary_reduce_lower, hlo.MulOp,
|
|
_get_prod_identity))
|
|
mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, hlo.OrOp,
|
|
_get_bitwise_or_identity))
|
|
mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, hlo.AndOp,
|
|
_get_bitwise_and_identity))
|
|
mlir.register_lowering(reduce_xor_p, partial(_unary_reduce_lower, hlo.XorOp,
|
|
_get_bitwise_or_identity))
|
|
mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_hlo,
|
|
_get_min_identity))
|
|
mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_hlo,
|
|
_get_max_identity))
|
|
|
|
|
|
def _reduce_precision_shape_rule(operand, *, exponent_bits, mantissa_bits):
|
|
exponent_bits = operator.index(exponent_bits)
|
|
mantissa_bits = operator.index(mantissa_bits)
|
|
if exponent_bits < 1:
|
|
raise ValueError(f"reduce_precision: exponent_bits must be positive; got {exponent_bits}")
|
|
if mantissa_bits < 0:
|
|
raise ValueError(f"reduce_precision: mantissa_bits must be non-negative; got {mantissa_bits}")
|
|
return operand.shape
|
|
|
|
|
|
reduce_precision_p = standard_primitive(
|
|
_reduce_precision_shape_rule,
|
|
partial(unop_dtype_rule, _identity, _float, 'reduce_precision'),
|
|
name='reduce_precision')
|
|
ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)])
|
|
batching.defvectorized(reduce_precision_p)
|
|
|
|
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
|
|
aval_out, = ctx.avals_out
|
|
return [hlo.reduce_precision(operand, mlir.i32_attr(exponent_bits),
|
|
mlir.i32_attr(mantissa_bits))]
|
|
|
|
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
|
|
|
|
|
|
_UINT_DTYPES = {
  16: np.dtype(np.uint16),
  32: np.dtype(np.uint32),
  64: np.dtype(np.uint64),
}

_INT_DTYPES = {
  16: np.dtype(np.int16),
  32: np.dtype(np.int32),
  64: np.dtype(np.int64),
}
|
|
|
|
|
|
def _sort_abstract_eval(*args, **kwargs):
|
|
args = tuple(raise_to_shaped(arg) for arg in args)
|
|
if any(arg.shape != args[0].shape for arg in args[1:]):
|
|
shapes = " ".join(str(a.shape) for a in args)
|
|
raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
|
|
return args
|
|
|
|
|
|
def _canonicalize_float_for_sort(x):
  # In the sort comparator, we are going to use a comparison operator where -0
  # would be before 0, and -NaN and NaN appear at the beginning and end of the
  # ordering. This causes issues for stable sorts, so we avoid this by
  # standardizing the representation of zeros and NaNs in the output.

result = select(eq(x, _zero(x)), _zeros(x), x)
|
|
with config.debug_nans(False):
|
|
result = select(_isnan(x), full_like(result, np.nan), result)
|
|
|
|
return result
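# Illustrative sketch (assuming jnp in scope): canonicalization collapses the
# distinct encodings a stable sort must not distinguish, e.g.
# _canonicalize_float_for_sort(jnp.array([-0.0, 0.0])) gives
# Array([0., 0.], dtype=float32), and every NaN payload is replaced by the
# same canonical NaN.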
# Default comparator that sorts the operands lexicographically on the
|
|
# first `num_keys` arguments.
|
|
# For floating point types, a total order is created where
|
|
# -infinity < ... < 0 < ... < infinity < NaN.
|
|
# 0.0 and -0.0 are treated as equivalent, as are all NaN representations.
|
|
# For complex types, the (real, imag) pairs are sorted lexicographically
|
|
# (following NumPy's semantics).
|
|
# This code adds complex-number support and lexicographic ordering to the algorithm from:
|
|
# https://github.com/tensorflow/tensorflow/blob/ba43780830f09da72081fe5061c436f1c6203a92/tensorflow/compiler/xla/client/lib/comparators.h#L33
|
|
def _sort_lt_comparator(*operands, num_keys=1):
|
|
x_keys, y_keys = _operands_to_keys(*operands, num_keys=num_keys)
|
|
p = None
|
|
for xk, yk in zip(x_keys[::-1], y_keys[::-1]):
|
|
p = (bitwise_or(lt_to_p.bind(xk, yk), bitwise_and(eq_to_p.bind(xk, yk), p)) if p is not None
|
|
else lt_to_p.bind(xk, yk))
|
|
return p
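# Lexicographic-ordering sketch (illustrative, assuming jnp in scope): with
# num_keys=2 the comparator above keys on the first operand and breaks ties
# with the second, so
#
#   >>> from jax import lax
#   >>> lax.sort((jnp.array([1, 0, 1, 0]), jnp.array([3, 2, 1, 0])), num_keys=2)
#   (Array([0, 0, 1, 1], dtype=int32), Array([0, 2, 1, 3], dtype=int32))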
# Similar to sort_lt_comparator, but implements less than or equal. Used by
|
|
# the searchsorted() implementation.
|
|
def _sort_le_comparator(*operands, num_keys=1):
|
|
x_keys, y_keys = _operands_to_keys(*operands, num_keys=num_keys)
|
|
p = None
|
|
for xk, yk in zip(x_keys[::-1], y_keys[::-1]):
|
|
p = (bitwise_or(lt_to_p.bind(xk, yk), bitwise_and(eq_to_p.bind(xk, yk), p)) if p is not None
|
|
else le_to_p.bind(xk, yk))
|
|
return p
|
|
|
|
def _operands_to_keys(*operands, num_keys=1):
|
|
assert len(operands) >= 2 and len(operands) % 2 == 0, operands
|
|
assert len(operands) // 2 >= num_keys, (operands, num_keys)
|
|
x_keys, y_keys = [], []
|
|
for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]):
|
|
assert x.dtype == y.dtype, (x.dtype, y.dtype)
|
|
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
|
x_keys.extend([_canonicalize_float_for_sort(real(x)), _canonicalize_float_for_sort(imag(x))])
|
|
y_keys.extend([_canonicalize_float_for_sort(real(y)), _canonicalize_float_for_sort(imag(y))])
|
|
elif dtypes.issubdtype(x.dtype, np.floating):
|
|
x_keys.append(_canonicalize_float_for_sort(x))
|
|
y_keys.append(_canonicalize_float_for_sort(y))
|
|
else:
|
|
x_keys.append(x)
|
|
y_keys.append(y)
|
|
return x_keys, y_keys
|
|
|
|
|
|
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
|
|
shape = primals[0].shape
|
|
iotas = []
|
|
for dim, size in enumerate(shape):
|
|
dtype = np.int32 if size < np.iinfo(np.int32).max else np.int64
|
|
iotas.append(broadcasted_iota(dtype, shape, dim))
|
|
primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension,
|
|
is_stable=is_stable, num_keys=num_keys)
|
|
idx = tuple(primals[-1] if i == dimension else iotas[i]
|
|
for i in range(len(shape)))
|
|
tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
|
|
return tuple(primals[:-1]), tangents_out
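# Sketch of the JVP construction above: the appended iota operand records
# where each sorted slot came from, e.g. sorting x = [3., 1., 2.] yields
# permutation indices [1, 2, 0], and the output tangent is just the input
# tangent gathered with those indices.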
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
|
|
prototype_arg, new_bdim = next(
|
|
(a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
|
|
new_args = []
|
|
for arg, bdim in zip(batched_args, batch_dims):
|
|
if bdim is None:
|
|
dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
|
|
new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims))
|
|
else:
|
|
new_args.append(batching.moveaxis(arg, bdim, new_bdim))
|
|
new_dimension = dimension + (new_bdim <= dimension)
|
|
bdims = (new_bdim,) * len(new_args)
|
|
return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, num_keys=num_keys),
|
|
bdims)
|
|
|
|
|
|
sort_p = Primitive('sort')
|
|
sort_p.multiple_results = True
|
|
sort_p.def_impl(partial(dispatch.apply_primitive, sort_p))
|
|
sort_p.def_abstract_eval(_sort_abstract_eval)
|
|
ad.primitive_jvps[sort_p] = _sort_jvp
|
|
batching.primitive_batchers[sort_p] = _sort_batch_rule
|
|
|
|
|
|
def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
|
|
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
|
|
sort = hlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
|
mlir.flatten_ir_values(operands),
|
|
dimension=mlir.i64_attr(dimension),
|
|
is_stable=ir.BoolAttr.get(is_stable))
|
|
scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
|
|
scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
|
|
comparator = sort.comparator.blocks.append(
|
|
*util.flatten(zip(scalar_types, scalar_types)))
|
|
with ir.InsertionPoint(comparator):
|
|
lower_comparator = mlir.lower_fun(partial(_sort_lt_comparator),
|
|
multiple_results=False)
|
|
sub_ctx = ctx.replace(primitive=None,
|
|
avals_in=util.flatten(zip(scalar_avals, scalar_avals)),
|
|
avals_out=[core.ShapedArray((), np.bool_)])
|
|
|
|
out = lower_comparator(sub_ctx, *comparator.arguments, num_keys=num_keys)
|
|
hlo.return_(mlir.flatten_ir_values(out))
|
|
return sort.results
|
|
|
|
mlir.register_lowering(sort_p, _sort_lower)
|
|
|
|
|
|
def _top_k_abstract_eval(operand, *, k):
|
|
if dtypes.issubdtype(operand.dtype, np.complexfloating):
|
|
raise ValueError("top_k is not compatible with complex inputs.")
|
|
if k < 0:
|
|
raise ValueError(f"k argument to top_k must be nonnegative, got {k}")
|
|
if len(operand.shape) == 0:
|
|
raise TypeError("top_k operand must have >= 1 dimension, got {}"
|
|
.format(operand.shape))
|
|
shape = list(operand.shape)
|
|
if shape[-1] < k:
|
|
msg = "k argument to top_k must be no larger than minor dimension; {} vs {}"
|
|
raise ValueError(msg.format(k, shape))
|
|
shape[-1] = k
|
|
return (operand.update(shape=shape, dtype=operand.dtype,
|
|
weak_type=operand.weak_type),
|
|
operand.update(shape=shape, dtype=np.dtype(np.int32)))
|
|
|
|
def _top_k_jvp(primals, tangents, *, k):
|
|
operand, = primals
|
|
tangent, = tangents
|
|
primals_out = top_k(operand, k)
|
|
if type(tangent) is ad_util.Zero:
|
|
tangent_out = ad_util.Zero.from_primal_value(primals_out[0])
|
|
else:
|
|
_, k_idxs = primals_out
|
|
idx_shape = k_idxs.shape
|
|
rank = len(idx_shape)
|
|
gather_index_shape = idx_shape + (1,)
|
|
gather_indices = []
|
|
for i in range(rank-1):
|
|
_iota = iota(k_idxs.dtype, idx_shape[i])
|
|
_iota = broadcast_in_dim(_iota, gather_index_shape, (i,))
|
|
gather_indices.append(_iota)
|
|
gather_indices.append(reshape(k_idxs, gather_index_shape))
|
|
gather_indices = concatenate(gather_indices, dimension=rank)
|
|
slice_sizes = (1,) * rank
|
|
dnums = slicing.GatherDimensionNumbers(
|
|
offset_dims=(),
|
|
collapsed_slice_dims=tuple(range(rank)),
|
|
start_index_map=tuple(range(rank)))
|
|
tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes)
|
|
return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1]))
|
|
|
|
def _top_k_batch_rule(batched_args, batch_dims, *, k):
|
|
operand, = batched_args
|
|
bdim, = batch_dims
|
|
if bdim == operand.ndim-1:
|
|
perm = np.arange(operand.ndim)
|
|
perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1]
|
|
top_k_v, top_k_i = top_k(transpose(operand, perm), k=k)
|
|
return (transpose(top_k_v, perm),
|
|
transpose(top_k_i, perm)), (bdim, bdim)
|
|
else:
|
|
return top_k(operand, k=k), (bdim, bdim)
|
|
|
|
top_k_p = Primitive('top_k')
|
|
top_k_p.multiple_results = True
|
|
top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
|
|
top_k_p.def_abstract_eval(_top_k_abstract_eval)
|
|
def _top_k_lower(ctx, operand, k):
|
|
if core.is_constant_dim(k):
|
|
return chlo.TopKOp(operand, mlir.i64_attr(k)).results
|
|
k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,))
|
|
out_values_aval, out_indices_aval, = ctx.avals_out
|
|
return mlir.custom_call(
|
|
"stablehlo.dynamic_top_k",
|
|
result_types=[mlir.aval_to_ir_type(out_values_aval),
|
|
mlir.aval_to_ir_type(out_indices_aval)],
|
|
operands=[operand, k_value]).results
|
|
|
|
mlir.register_lowering(top_k_p, _top_k_lower)
|
|
ad.primitive_jvps[top_k_p] = _top_k_jvp
|
|
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
|
|
|
|
def _stop_gradient_jvp_rule(primals, tangents):
|
|
# if we don't call stop_gradient here, we'd only peel off one autodiff tracer
|
|
x, = primals
|
|
return stop_gradient(x), ad_util.Zero.from_primal_value(x)
|
|
|
|
def _stop_gradient_batch_rule(batched_args, batch_dims):
|
|
x, = batched_args
|
|
dim, = batch_dims
|
|
return stop_gradient(x), dim
|
|
|
|
ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule
|
|
batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule
|
|
pe.def_trivial_padding(ad_util.stop_gradient_p)
|
|
|
|
|
|
def create_token(_=None):
|
|
"""Creates an XLA token value with no preconditions for sequencing effects.
|
|
|
|
Experimental.
|
|
|
|
The argument is ignored. It exists for backward compatibility.
|
|
"""
|
|
return create_token_p.bind()
|
|
|
|
create_token_p = Primitive("create_token")
|
|
create_token_p.def_impl(partial(dispatch.apply_primitive, create_token_p))
|
|
create_token_p.def_abstract_eval(lambda *_: abstract_token)
|
|
|
|
def _create_token_lowering(ctx, *operands):
|
|
aval_out, = ctx.avals_out
|
|
return [hlo.create_token()]
|
|
mlir.register_lowering(create_token_p, _create_token_lowering)
|
|
|
|
|
|
def after_all(*operands):
|
|
"""Merges one or more XLA token values. Experimental.
|
|
|
|
Wraps the XLA AfterAll operator."""
|
|
return after_all_p.bind(*operands)
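# Token-plumbing sketch (illustrative): t = after_all(create_token(),
# create_token()) produces a single token ordered after both inputs, which can
# then be threaded through the infeed/outfeed primitives defined below.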
def _after_all_abstract_eval(*operands):
|
|
if any(x is not abstract_token for x in operands):
|
|
raise TypeError("Arguments to after_all must be tokens")
|
|
return abstract_token
|
|
|
|
|
|
after_all_p = Primitive("after_all")
|
|
after_all_p.def_impl(partial(dispatch.apply_primitive, after_all_p))
|
|
after_all_p.def_abstract_eval(_after_all_abstract_eval)
|
|
|
|
def _after_all_lowering(ctx, *operands):
|
|
aval_out, = ctx.avals_out
|
|
return [hlo.after_all(operands)]
|
|
mlir.register_lowering(after_all_p, _after_all_lowering)
|
|
|
|
|
|
class InOutFeedEffect(effects.Effect):
|
|
pass
|
|
infeed_effect = InOutFeedEffect()
|
|
outfeed_effect = InOutFeedEffect()
|
|
|
|
|
|
def infeed(token, shape=None, partitions=None):
|
|
"""Consumes an infeed value of `shape` from the host. Experimental.
|
|
|
|
`token` is used to sequence infeed and outfeed effects.
|
|
`partitions` may be specified inside a `sharded_jit` function.
|
|
"""
|
|
flat_shapes, treedef = tree_util.tree_flatten(shape)
|
|
for shape in flat_shapes:
|
|
if not isinstance(shape, ShapedArray):
|
|
raise TypeError("shape argument to infeed must be a pytree of "
|
|
"ShapedArray values, got {}".format(shape))
|
|
if partitions is not None:
|
|
# Always replicate token.
|
|
# We specifically use type() to raise an error for PartitionSpecs.
|
|
if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck
|
|
raise ValueError(f"'partitions' argument to infeed should be a tuple, "
|
|
f"got {partitions}")
|
|
partitions = partitions + (None,)
|
|
xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes),
|
|
partitions=partitions)
|
|
return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1])
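# Usage sketch (inside a jitted computation, assuming the host separately
# enqueues a matching value on the device infeed queue):
#   (x,), token = infeed(create_token(),
#                        shape=(core.ShapedArray((3,), np.float32),))
# consumes one float32[3] value and returns a fresh token for sequencing.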
def _infeed_abstract_eval(token, *, shapes, partitions):
|
|
if token is not abstract_token:
|
|
raise TypeError("First argument to infeed must be a token")
|
|
return (*shapes, abstract_token), {infeed_effect}
|
|
|
|
|
|
infeed_p = Primitive("infeed")
|
|
infeed_p.multiple_results = True
|
|
infeed_p.def_impl(partial(dispatch.apply_primitive, infeed_p))
|
|
infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
|
|
mlir.lowerable_effects.add_type(InOutFeedEffect)
|
|
|
|
|
|
def _infeed_lowering(ctx, token, *, shapes, partitions):
|
|
output_types = safe_map(mlir.aval_to_ir_type, ctx.avals_out[:-1])
|
|
flat_output_types = mlir.flatten_ir_types(output_types)
|
|
# TODO(phawkins): verify `shapes` have a major-to-minor layout.
|
|
layouts = ir.ArrayAttr.get([
|
|
ir.ArrayAttr.get(
|
|
[mlir.i64_attr(i)
|
|
for i in range(len(aval.shape) - 1, -1, -1)])
|
|
for aval in shapes
|
|
])
|
|
infeed = hlo.InfeedOp(
|
|
flat_output_types + [hlo.TokenType.get()],
|
|
token,
|
|
infeed_config=ir.StringAttr.get(''),
|
|
layout=layouts)
|
|
if partitions is not None:
|
|
mlir.set_sharding(infeed, xla.sharding_to_proto(partitions))
|
|
token = infeed.results[-1]
|
|
outs = infeed.results[:-1]
|
|
return mlir.unflatten_ir_values_like_types(outs, output_types) + [
|
|
token,
|
|
]
|
|
|
|
mlir.register_lowering(infeed_p, _infeed_lowering)
|
|
|
|
|
|
def outfeed(token, xs, partitions = None):
|
|
"""Outfeeds value `xs` to the host. Experimental.
|
|
|
|
`token` is used to sequence infeed and outfeed effects.
|
|
`partitions` may be specified inside a `sharded_jit` or `pjit` function.
|
|
"""
|
|
if partitions is not None:
|
|
# We specifically use type() to raise an error for PartitionSpecs.
|
|
if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck
|
|
raise ValueError(f"'partitions' argument to outfeed should be a tuple, "
|
|
f"got {partitions}")
|
|
flat_xs, _ = tree_util.tree_flatten(xs)
|
|
return outfeed_p.bind(token, *flat_xs, partitions=partitions)
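# Mirror-image sketch for outfeed (illustrative): token = outfeed(token, x)
# flattens the pytree `x`, enqueues its leaves on the device outfeed queue,
# and returns a token; the host is expected to dequeue a matching tuple.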
def _outfeed_abstract_eval(token, *xs, partitions):
|
|
if token is not abstract_token:
|
|
raise TypeError("First argument to outfeed must be a token")
|
|
return abstract_token, {outfeed_effect}
|
|
|
|
outfeed_p = Primitive("outfeed")
|
|
outfeed_p.def_impl(partial(dispatch.apply_primitive, outfeed_p))
|
|
outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
|
|
mlir.lowerable_effects.add_type(InOutFeedEffect)
|
|
|
|
|
|
def _outfeed_lowering(ctx, token, *xs, partitions):
|
|
outfeed = hlo.OutfeedOp(
|
|
mlir.flatten_ir_values(xs),
|
|
token,
|
|
outfeed_config=ir.StringAttr.get(''))
|
|
if partitions is not None:
|
|
mlir.set_sharding(outfeed, xla.sharding_to_proto(partitions))
|
|
return outfeed.results
|
|
|
|
mlir.register_lowering(outfeed_p, _outfeed_lowering)
|
|
|
|
|
|
def rng_uniform(a, b, shape):
|
|
"""Stateful PRNG generator. Experimental and its use is discouraged.
|
|
|
|
Returns uniformly distributed random numbers in the range [a, b). If
|
|
b <= a, then the result is undefined, and different implementations may
|
|
return different results.
|
|
|
|
You should use jax.random for most purposes; this function exists only for
|
|
niche use cases with special performance requirements.
|
|
|
|
This API may be removed at any time.
|
|
"""
|
|
return rng_uniform_p.bind(a, b, shape=tuple(shape))
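# Usage sketch (discouraged outside niche cases, per the docstring): both
# bounds must be same-dtype scalars, e.g.
# rng_uniform(np.float32(0), np.float32(1), (2, 2)) produces a float32[2, 2]
# array of draws in [0, 1), with values chosen statefully by the backend.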
def _rng_uniform_abstract_eval(a, b, *, shape):
|
|
if a.dtype != b.dtype:
|
|
raise ValueError(
|
|
"Arguments to rng_uniform must have identical dtypes, got {} "
|
|
"and {}.".format(a.dtype, b.dtype))
|
|
if a.shape != () or b.shape != ():
|
|
raise ValueError(
|
|
"Arguments to rng_uniform must be scalars; got shapes {} and {}."
|
|
.format(a.shape, b.shape))
|
|
return a.update(shape=shape, dtype=a.dtype,
|
|
weak_type=(a.weak_type and b.weak_type))
|
|
|
|
rng_uniform_p = Primitive("rng_uniform")
|
|
rng_uniform_p.def_impl(partial(dispatch.apply_primitive, rng_uniform_p))
|
|
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
|
|
|
|
def _rng_uniform_lowering(ctx, a, b, *, shape):
|
|
aval_out, = ctx.avals_out
|
|
shape = mlir.ir_constant(np.array(aval_out.shape, np.int64))
|
|
return [hlo.rng(a, b, shape, hlo.RngDistributionAttr.get('UNIFORM'))]
|
|
|
|
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
|
|
|
|
|
|
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
|
|
del dtype, algorithm
|
|
return (key.shape, tuple(shape))
|
|
|
|
def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
|
|
del shape, algorithm
|
|
return (key.dtype, dtype)
|
|
|
|
def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
|
|
del shape, dtype, algorithm
|
|
return (key.weak_type, False)
|
|
|
|
RandomAlgorithm = xops.RandomAlgorithm
|
|
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name # type: ignore[method-assign]
|
|
|
|
def _rng_algorithm(algorithm: RandomAlgorithm):
|
|
if algorithm == RandomAlgorithm.RNG_THREE_FRY:
|
|
return hlo.RngAlgorithmAttr.get("THREE_FRY")
|
|
elif algorithm == RandomAlgorithm.RNG_PHILOX:
|
|
return hlo.RngAlgorithmAttr.get("PHILOX")
|
|
elif algorithm == RandomAlgorithm.RNG_DEFAULT:
|
|
return hlo.RngAlgorithmAttr.get("DEFAULT")
|
|
else:
|
|
assert False
|
|
|
|
def _rng_bit_generator_lowering(
|
|
ctx, key, *, shape, dtype, algorithm):
|
|
key_type = ir.RankedTensorType(key.type)
|
|
key_shape, key_etype = key_type.shape, key_type.element_type
|
|
# While the RngBitGenerator HLO accepts a u64[2] key on all backends, we
|
|
# typically represent the key argument to this primitive as a u32[4] so as to
|
|
# sidestep issues with the jax_enable_x64=False configuration. As a result, we
|
|
# need to convert u32[4] -> u64[2] here in the translation rule. However, we
|
|
# also polymorphically allow a u64[2] for backward compatibility.
|
|
#
|
|
# Separately, xops.RngBitGenerator doesn't support generating u8 or
|
|
# u16, so we request u32 and truncate in that case.
|
|
u32_type = ir.IntegerType.get_unsigned(32)
|
|
u64_type = ir.IntegerType.get_unsigned(64)
|
|
assert ((key_shape == [4] and key_etype == u32_type) or
|
|
(key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype)
|
|
dtype = np.dtype(dtype)
|
|
etype = mlir.dtype_to_ir_type(dtype)
|
|
if dtype in (np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
|
|
np.dtype('uint64')):
|
|
rbg_etype = etype
|
|
rbg_dtype = dtype
|
|
else:
|
|
rbg_etype = u32_type
|
|
rbg_dtype = np.uint32
|
|
if key_etype == u32_type:
|
|
key = hlo.bitcast_convert(
|
|
ir.RankedTensorType.get([2], u64_type),
|
|
hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key))
|
|
algorithm_attr = _rng_algorithm(algorithm)
|
|
_, out_vals_aval = ctx.avals_out
|
|
if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out):
|
|
output_shape = mlir.shape_tensor(
|
|
mlir.eval_dynamic_shape(ctx, out_vals_aval.shape))
|
|
out_key, out_vals = mlir.custom_call(
|
|
"stablehlo.dynamic_rng_bit_generator",
|
|
result_types=[key.type,
|
|
mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))],
|
|
operands=[key, output_shape],
|
|
extra_attributes=dict(rng_algorithm=algorithm_attr)).results
|
|
else:
|
|
out_key, out_vals = hlo.RngBitGeneratorOp(
|
|
key.type,
|
|
ir.RankedTensorType.get(shape, rbg_etype),
|
|
algorithm_attr, key).results
|
|
if key_etype == u32_type:
|
|
out_key = hlo.reshape(
|
|
ir.RankedTensorType.get([4], u32_type),
|
|
hlo.bitcast_convert(
|
|
ir.RankedTensorType.get([2, 2], u32_type), out_key))
|
|
if rbg_etype != etype:
|
|
out_vals = hlo.convert(
|
|
ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype),
|
|
out_vals)
|
|
return [out_key, out_vals]
|
|
|
|
|
|
rng_bit_generator_p = Primitive("rng_bit_generator")
|
|
rng_bit_generator_p.multiple_results = True
|
|
rng_bit_generator_p.def_impl(
|
|
partial(dispatch.apply_primitive, rng_bit_generator_p))
|
|
rng_bit_generator_p.def_abstract_eval(
|
|
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
|
|
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
|
|
_rng_bit_generator_weak_type_rule))
|
|
mlir.register_lowering(rng_bit_generator_p,
|
|
_rng_bit_generator_lowering)
|
|
|
|
|
|
def _array_copy(arr: ArrayLike) -> Array:
|
|
return copy_p.bind(arr)
|
|
|
|
|
|
def _which_dim_sharded(s: PmapSharding) -> int | None:
|
|
sharded_dim = None
|
|
for i, s in enumerate(s.sharding_spec.sharding):
|
|
if isinstance(s, (pxla.Unstacked, pxla.Chunked)):
|
|
sharded_dim = i
|
|
break
|
|
return sharded_dim
|
|
|
|
|
|
def _identity_fn(x): return x
|
|
|
|
|
|
def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
|
|
axis_name, static_broadcasted_tuple, donate_tuple = api._shared_code_pmap(
|
|
_identity_fn, None, (), (), sharded_dim, sharded_dim)
|
|
p = api._prepare_pmap(
|
|
_identity_fn, sharded_dim, sharded_dim, static_broadcasted_tuple,
|
|
donate_tuple, None, None, None, args, kwargs)
|
|
out_flat = pxla.xla_pmap_impl(
|
|
p.flat_fun, *p.flat_args, backend=None, axis_name=axis_name,
|
|
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
|
|
devices=p.devices, in_axes=p.in_axes_flat,
|
|
out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__,
|
|
donated_invars=p.donated_invars,
|
|
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
|
|
)
|
|
return tree_util.tree_unflatten(p.out_tree(), out_flat)
|
|
|
|
|
|
# TODO(https://github.com/jax-ml/jax/issues/13552): Look into making this a
|
|
# method on jax.Array so that we can bypass the XLA compilation here.
|
|
def _copy_impl(prim, *args, **kwargs):
|
|
a, = args
|
|
if isinstance(a, Array) and isinstance(a.sharding, PmapSharding):
|
|
sharded_dim = _which_dim_sharded(a.sharding)
|
|
if sharded_dim is None:
|
|
return dispatch.apply_primitive(prim, *args, **kwargs)
|
|
return _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs)
|
|
return dispatch.apply_primitive(prim, *args, **kwargs)
|
|
|
|
# The copy_p primitive exists to express making copies of runtime arrays.
# For that reason we don't simplify it out of jaxprs (e.g. for jit invariance).
# It's used in jnp.array(x, copy=True), which is the user-facing API.
copy_p = core.Primitive('copy')
|
|
copy_p.def_impl(partial(_copy_impl, copy_p))
|
|
copy_p.def_abstract_eval(lambda x: x)
|
|
mlir.register_lowering(copy_p, lambda ctx, x: [x])
|
|
ad.deflinear(copy_p, lambda t: [copy_p.bind(t)])
|
|
pe.def_trivial_padding(copy_p)
|
|
batching.defvectorized(copy_p)
|
|
def _propagate_mem_kind_copy(in_mem_kind):
|
|
return in_mem_kind
|
|
pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy
|
|
|
|
def rng_bit_generator(key, shape, dtype=np.uint32,
|
|
algorithm=RandomAlgorithm.RNG_DEFAULT):
|
|
"""Stateless PRNG bit generator. Experimental and its use is discouraged.
|
|
|
|
  Returns uniformly distributed random bits with the specified shape and dtype
  (which must be an unsigned integer type) using the platform-specific default
  algorithm, or the one specified by ``algorithm``.
|
|
|
|
It provides direct access to the RngBitGenerator primitive exposed by XLA
|
|
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
|
|
level API access.
|
|
|
|
Most users should use `jax.random` instead for a stable and more user
|
|
friendly API.
|
|
"""
|
|
shape = core.canonicalize_shape(shape)
|
|
dtype = dtypes.canonicalize_dtype(dtype)
|
|
if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'),
|
|
np.dtype('uint32'), np.dtype('uint64')}:
|
|
raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}')
|
|
return tuple(
|
|
rng_bit_generator_p.bind(
|
|
key, shape=shape, dtype=dtype, algorithm=algorithm))
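# Minimal usage sketch (illustrative): with a u32[4] key such as
# jnp.zeros(4, jnp.uint32), rng_bit_generator(key, (2, 3)) returns a pair
# (new_key, bits) where new_key matches the key's shape and dtype and bits is
# a uint32[2, 3] array of random bits.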
def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension):
|
|
if not dyn_shape:
|
|
# TODO(mattjj) Generalize shape_like checking to permit dynamic shapes
|
|
_check_shapelike("iota", "shape", shape)
|
|
if not any(dtypes.issubdtype(dtype, t) for t in _num):
|
|
msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
|
|
typename = dtype_to_string(dtype)
|
|
accepted_typenames = (t.__name__ for t in _num)
|
|
raise TypeError(msg.format(typename, ', '.join(accepted_typenames)))
|
|
if not 0 <= dimension < len(shape):
|
|
raise ValueError("iota dimension must be between 0 and len(shape), got "
|
|
f"{dimension=} for {shape=}")
|
|
if (not dyn_shape and
|
|
not any(isinstance(d, core.DArray) and
|
|
type(core.get_aval(d).dtype) is core.bint for d in shape)):
|
|
return ShapedArray(shape, dtype)
|
|
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
|
|
return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)
|
|
|
|
iota_p = Primitive('iota')
|
|
iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
|
|
iota_p.def_abstract_eval(_iota_abstract_eval)
|
|
|
|
def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension):
|
|
params = dict(dtype=dtype, shape=shape, dimension=dimension)
|
|
if not dyn_shape:
|
|
return trace.default_process_primitive(iota_p, (), params)
|
|
aval = core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)
|
|
return _dyn_shape_staging_rule(trace, iota_p, aval, *dyn_shape, **params)
|
|
pe.custom_staging_rules[iota_p] = _iota_staging_rule
|
|
|
|
def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension):
|
|
if not dyn_shape:
|
|
out_aval, effects = iota_p.abstract_eval(
|
|
dtype=dtype, shape=shape, dimension=dimension)
|
|
return [out_aval], effects
|
|
else:
|
|
out_shape = _merge_dyn_shape(shape, dyn_shape)
|
|
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
|
|
out_aval = core.DShapedArray(tuple(out_shape), dtype, False)
|
|
return [out_aval], core.no_effects
|
|
core.custom_typechecks[iota_p] = _iota_typecheck_rule
|
|
|
|
def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
|
|
del dtype
|
|
aval_out, = ctx.avals_out
|
|
if dyn_shape:
|
|
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
|
|
return [mlir.iota(ctx, aval_out, dimension=dimension)]
|
|
mlir.register_lowering(iota_p, _iota_lower)
|
|
|
|
def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension):
|
|
(segment_lengths,), (ax,) = in_vals, in_dims
|
|
assert ax == 0
|
|
bound = segment_lengths.dtype.bound
|
|
ragged_axis, = (i for i, dim in enumerate(shape) if dim is None)
|
|
shape = (len(segment_lengths),) + _merge_dyn_shape(shape, (bound,))
|
|
iota = broadcasted_iota(dtype, shape, dimension+1)
|
|
return iota, batching.RaggedAxis(ax, ((ragged_axis+1, segment_lengths),))
|
|
batching.primitive_batchers[iota_p] = _iota_batching_rule
|
|
|
|
def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension):
|
|
out_aval, = out_avals
|
|
new_shape = []
|
|
new_dyn_shape = []
|
|
for d in out_aval.shape:
|
|
if type(d) is pe.BoundedAxisSize:
|
|
new_shape.append(d.bound)
|
|
elif type(d) is int:
|
|
new_shape.append(d)
|
|
else:
|
|
assert isinstance(d, core.Tracer)
|
|
new_shape.append(None)
|
|
new_dyn_shape.append(d)
|
|
return [iota_p.bind(*new_dyn_shape, shape=tuple(new_shape),
|
|
dtype=dtype, dimension=dimension)]
|
|
pe.padding_rules[iota_p] = _iota_padding_rule
|
|
|
|
|
|
### util
|
|
|
|
_ndim = np.ndim
|
|
|
|
|
|
def _dilate_shape(shape, dilation):
|
|
"""Utility function for computing the shape resulting from a dilation."""
|
|
if not np.all(np.greater(dilation, 0)):
|
|
msg = "All dilations must be positive, got {}."
|
|
raise TypeError(msg.format(dilation))
|
|
dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation)
|
|
return tuple(map(core.dilate_dim, shape, dilation))
|
|
|
|
def _ceil_divide(x1, x2):
|
|
return -np.floor_divide(np.negative(x1), x2)
|
|
|
|
|
|
class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2
  SAME_LOWER = 3
def padtype_to_pads(in_shape, window_shape, window_strides, padding):
|
|
"""Convert padding string to list of pairs of pad values."""
|
|
|
|
if isinstance(padding, str):
|
|
mapping = {
|
|
'VALID': PaddingType.VALID,
|
|
'SAME': PaddingType.SAME,
|
|
'SAME_LOWER': PaddingType.SAME_LOWER,
|
|
}
|
|
try:
|
|
padding = mapping[padding.upper()]
|
|
except KeyError as err:
msg = "Unrecognized padding type: expected 'VALID' or 'SAME', got {}."
|
|
raise RuntimeError(msg.format(padding)) from err
|
|
|
|
if padding == PaddingType.SAME or padding == PaddingType.SAME_LOWER:
|
|
out_shape = _ceil_divide(in_shape, window_strides)
|
|
pad_sizes = (core.max_dim(d, 0)
|
|
for d in (out_shape - 1) * window_strides +
|
|
window_shape - in_shape)
|
|
if padding == PaddingType.SAME:
|
|
pads = [
|
|
(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes
|
|
]
|
|
else:
|
|
pads = [
|
|
(pad_size - pad_size // 2, pad_size // 2) for pad_size in pad_sizes
|
|
]
|
|
# Avoids verbose numpy scalars in jaxprs.
|
|
return [p.item() if isinstance(p, np.generic) else p for p in pads]
|
|
elif padding == PaddingType.VALID:
|
|
return [(0, 0)] * len(in_shape)
|
|
else:
|
|
msg = "Unknown padding type: {}."
|
|
raise TypeError(msg.format(padding))
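# Worked example of the 'SAME' arithmetic above (illustrative): for an input
# of length 10, window 3 and stride 2 the output length is ceil(10 / 2) = 5,
# the total pad is (5 - 1) * 2 + 3 - 10 = 1, and
# padtype_to_pads((10,), (3,), (2,), 'SAME') evaluates to [(0, 1)].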
# Map of lax function to equivalent jax.numpy function for use in error string below.
|
|
_JNP_FUNCTION_EQUIVALENTS = {
|
|
'abs': 'fabs',
|
|
'acos': 'arccos',
|
|
'acosh': 'arccosh',
|
|
'add': 'add',
|
|
'asin': 'arcsin',
|
|
'asinh': 'arcsinh',
|
|
'atan': 'arctan',
|
|
'atan2': 'arctan2',
|
|
'atanh': 'arctanh',
|
|
'bitwise_and': 'bitwise_and',
|
|
'bitwise_not': 'bitwise_not',
|
|
'bitwise_or': 'bitwise_or',
|
|
'bitwise_xor': 'bitwise_xor',
|
|
'cbrt': 'cbrt',
|
|
'ceil': 'ceil',
|
|
'concatenate': 'concatenate',
|
|
'cos': 'cos',
|
|
'cosh': 'cosh',
|
|
'div': 'divide',
|
|
'eq': 'equal',
|
|
'exp': 'exp',
|
|
'expm1': 'expm1',
|
|
'floor': 'floor',
|
|
'greater': 'greater',
|
|
'greater_equal': 'greater_equal',
|
|
'less': 'less',
|
|
'less_equal': 'less_equal',
|
|
'log': 'log',
|
|
'logical_and': 'logical_and',
|
|
'logical_not': 'logical_not',
|
|
'logical_or': 'logical_or',
|
|
'logical_xor': 'logical_xor',
|
|
'log1p': 'log1p',
|
|
'max': 'maximum',
|
|
'min': 'minimum',
|
|
'mul': 'multiply',
|
|
'ne': 'not_equal',
|
|
'neg': 'negative',
|
|
'nextafter': 'nextafter',
|
|
'pow': 'float_power',
|
|
'round': 'round',
|
|
'select': 'where',
|
|
'shift_left': 'left_shift',
|
|
'shift_right_logical': 'right_shift',
|
|
'shift_right_arithmetic': 'right_shift',
|
|
'sign': 'sign',
|
|
'sin': 'sin',
|
|
'sinh': 'sinh',
|
|
'sqrt': 'sqrt',
|
|
'sub': 'subtract',
|
|
'tan': 'tan',
|
|
'tanh': 'tanh'
|
|
}
|
|
|
|
def check_same_dtypes(name: str, *avals: core.UnshapedArray) -> None:
|
|
"""Check that dtypes agree, possibly ignoring float precision."""
|
|
# the `ignore_fp_precision` flag exists because the XLA shape inference logic
|
|
# allows mixed floating point precision, but the HLO verifier often rejects it
|
|
if any(dtypes.issubdtype(aval.dtype, dtypes.extended) for aval in avals):
|
|
return # TODO(mattjj,frostig): do some checking, friend
|
|
if len(avals) < 2:
|
|
return
|
|
|
|
dtype = dtypes.canonicalize_dtype(avals[0].dtype)
|
|
if any(dtypes.canonicalize_dtype(aval.dtype) != dtype for aval in avals[1:]):
|
|
msg = "lax.{} requires arguments to have the same dtypes, got {}."
|
|
if name in _JNP_FUNCTION_EQUIVALENTS:
|
|
equiv = _JNP_FUNCTION_EQUIVALENTS[name]
|
|
msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)."
|
|
raise TypeError(msg.format(name, ", ".join(str(a.dtype) for a in avals)))
|
|
|
|
|
|
def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
|
|
"""Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints)."""
|
|
if not isinstance(obj, (tuple, list, np.ndarray)):
|
|
msg = "{} {} must be of type tuple/list/ndarray, got {}."
|
|
raise TypeError(msg.format(fun_name, arg_name, type(obj)))
|
|
# bool(obj) for an ndarray raises an error, so we check len
|
|
if not len(obj): # pylint: disable=g-explicit-length-test
|
|
return
|
|
if (config.dynamic_shapes.value and isinstance(obj, (tuple, list)) and
|
|
any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)):
|
|
return # TODO(mattjj): handle more checks in the dynamic shape case
|
|
obj_arr = np.array(obj)
|
|
if obj_arr.ndim != 1:
msg = "{} {} must be 1-dimensional, got {}."
|
|
raise TypeError(msg.format(obj_arr.ndim))
try:
|
|
canonicalize_shape(obj_arr)
|
|
except TypeError as err:
|
|
msg = "{} {} must have every element be an integer type, got {}."
|
|
raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) from err
|
|
lower_bound, bound_error = (
|
|
(1, "strictly positive") if non_zero_shape else (0, "nonnegative"))
|
|
if not all(d >= lower_bound for d in obj_arr):
|
|
msg = "{} {} must have every element be {}, got {}."
|
|
raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
|
|
|
|
|
|
def _const(example, val):
|
|
dtype = _dtype(example)
|
|
if dtypes.is_python_scalar(example):
|
|
val = dtypes.scalar_type_of(example)(val)
|
|
return val if dtype == _dtype(val) else np.array(val, dtype)
|
|
return np.array(val, dtype)
|
|
|
|
_zeros: Callable = partial(full_like, fill_value=0)
|
|
_zero: Callable = partial(full_like, shape=(), fill_value=0)
|
|
_ones: Callable = partial(full_like, fill_value=1)
|
|
_one: Callable = partial(full_like, shape=(), fill_value=1)
|
|
_twos: Callable = partial(full_like, fill_value=2)
|
|
_two: Callable = partial(full_like, shape=(), fill_value=2)
|
|
|
|
dtype: Callable = partial(dtypes.dtype, canonicalize=True)
|
|
_dtype: Callable = partial(dtypes.dtype, canonicalize=True)
|
|
|
|
def _isnan(x: ArrayLike) -> Array:
|
|
return ne(x, x)
|
|
|
|
def _iscomplex(x) -> bool:
|
|
return dtypes.issubdtype(_dtype(x), np.complexfloating)
|
|
|
|
|
|
def ranges_like(*xs):
|
|
start = 0
|
|
for x in xs:
|
|
x_len = len(x)
|
|
yield range(start, start + x_len)
|
|
start += x_len
|
|
|
|
|
|
def remaining(original, *removed_lists):
|
|
removed = set(itertools.chain(*removed_lists))
|
|
return [i for i in original if i not in removed]
|
|
|
|
|
|
def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None:
|
|
"""Turns an API precision specification into a pair of enumeration values.
|
|
|
|
The API can take the precision as a string, or int, and either as a single
|
|
value to apply to both operands, or as a sequence of two values.
|
|
"""
|
|
if precision is None:
|
|
if config.default_matmul_precision.value is None:
|
|
return None
|
|
try:
|
|
return (
|
|
Precision(config.default_matmul_precision.value),
|
|
Precision(config.default_matmul_precision.value),
|
|
)
|
|
except TypeError:
|
|
raise ValueError(
|
|
"jax_default_matmul_precision flag must be set to None or a value in "
|
|
f"{list(_precision_strings)}, but got {config.default_matmul_precision.value}"
|
|
) from None
|
|
elif isinstance(precision, str) and precision in _precision_strings:
|
|
return Precision(precision), Precision(precision)
|
|
elif isinstance(precision, Precision):
|
|
return precision, precision
|
|
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
|
|
all(isinstance(p, Precision) for p in precision)):
|
|
return type_cast(tuple[Precision, Precision], precision)
|
|
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
|
|
all(isinstance(s, str) for s in precision)):
|
|
s1, s2 = precision
|
|
p1 = type_cast(tuple[Precision, Precision], canonicalize_precision(s1))[0]
|
|
p2 = type_cast(tuple[Precision, Precision], canonicalize_precision(s2))[0]
|
|
return (p1, p2)
|
|
else:
|
|
raise ValueError(
|
|
f"Precision argument must be None, a string in {list(_precision_strings)}, "
|
|
"a lax.Precision value or a tuple of two lax.Precision values or "
|
|
f"strings; got {precision}.")
def _balanced_eq(x, z, y):
|
|
return div(select(_eq_meet(x, z), _ones(z), _zeros(z)),
|
|
select(_eq_meet(y, z), _twos(z), _ones(z)))
|
|
|
|
|
|
def _eq_meet(a, b):
|
|
a_dtype, b_dtype = _dtype(a), _dtype(b)
|
|
if a_dtype != b_dtype:
|
|
higher_dtype = dtypes.promote_types(a_dtype, b_dtype)
|
|
if higher_dtype == a_dtype:
|
|
a = convert_element_type(a, b_dtype)
|
|
else:
|
|
b = convert_element_type(b, a_dtype)
|
|
return eq(a, b)
|
|
|
|
|
|
def _abstractify(x):
|
|
return raise_to_shaped(core.get_aval(x))
|
|
|
|
|
|
def empty(dtype):
|
|
return empty_p.bind(dtype=dtype)
|
|
empty_p = core.Primitive('empty')
|
|
empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype))
|
|
def _empty_lower(ctx, *, dtype):
|
|
dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype)
|
|
phys_aval = core.physical_aval(core.ShapedArray((), dtype))
|
|
return mlir.ir_constant(np.zeros(phys_aval.shape, phys_aval.dtype)),
|
|
mlir.register_lowering(empty_p, _empty_lower)
|
|
|
|
|
|
tie_p = core.Primitive('tie')
|
|
tie_p.def_impl(lambda x, y: y)
|
|
tie_p.def_abstract_eval(lambda x, y: y)
|
|
mlir.register_lowering(tie_p, lambda ctx, x, y: [y])
|
|
ad.primitive_jvps[tie_p] = \
|
|
lambda primals, tangents: (tie_p.bind(*primals), tangents[-1])
|
|
ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct]
|
|
pe.def_trivial_padding(tie_p)
|
|
batching.defvectorized(tie_p)
|
|
|
|
|
|
class BIntRules:
|
|
@staticmethod
|
|
def physical_element_aval(dtype) -> core.ShapedArray:
|
|
return core.ShapedArray((), np.dtype('int32'))
|
|
|
|
@staticmethod
|
|
def result_handler(sticky_device, aval):
|
|
def handler(_, buf):
|
|
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
|
|
return core.DArray(aval, buf)
|
|
return handler
|
|
|
|
@staticmethod
|
|
def global_sharded_result_handler(aval, out_sharding, committed):
|
|
phys_aval = core.physical_aval(aval)
|
|
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
|
|
|
|
if not dispatch.is_single_device_sharding(out_sharding):
|
|
raise NotImplementedError # TODO(mattjj)
|
|
else:
|
|
phys_sharding = out_sharding
|
|
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
|
|
|
|
def handler(bufs):
|
|
return core.DArray(aval, phys_handler(bufs))
|
|
return handler
|
|
|
|
@staticmethod
|
|
def convert_from(bint_dtype, other_dtype) -> bool:
|
|
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
|
|
|
|
@staticmethod
|
|
def convert_to(other_dtype, bint_dtype) -> bool:
|
|
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
|
|
|
|
|
|
core.bint._rules = BIntRules
|
|
|
|
|
|
def optimization_barrier(operand, /):
|
|
"""Prevents the compiler from moving operations across the barrier.
|
|
|
|
Optimization barriers have a number of possible uses:
|
|
|
|
* An optimization barrier ensures that all inputs are evaluated before any
|
|
operators that depend on the barrier's outputs. This can be used to enforce
|
|
a particular order of operations.
|
|
* An optimization barrier prevents common subexpression elimination. This is
|
|
used by JAX to implement rematerialization.
|
|
* Optimization barriers prevent compiler fusions. That is, operations before
|
|
the barrier may not be fused into the same kernel as operations after the
|
|
barrier by the compiler.
|
|
|
|
JAX does not define derivative or batching rules for an optimization barrier.
|
|
|
|
Optimization barriers have no effect outside a compiled function.
|
|
|
|
Args:
|
|
operand: a pytree of JAX values.
|
|
|
|
Returns:
|
|
A pytree of JAX values, with the same structure and contents as ``operand``.
|
|
|
|
Examples:
|
|
Prevents common-subexpression elimination between the two calls to `sin`:
|
|
|
|
>>> def f(x):
|
|
... return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
|
|
>>> jax.jit(f)(0.)
|
|
Array(0., dtype=float32, weak_type=True)
|
|
"""
|
|
flat_args, treedef = tree_util.tree_flatten(operand)
|
|
return tree_util.tree_unflatten(
|
|
treedef, optimization_barrier_p.bind(*flat_args))
|
|
|
|
|
|
def _optimization_barrier_abstract_eval(*args):
|
|
return args
|
|
|
|
def _optimization_barrier_lowering_rule(ctx, *args):
|
|
barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
|
|
flat_args = mlir.flatten_ir_values(args)
|
|
barrier_op = hlo.OptimizationBarrierOp(flat_args)
|
|
return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
|
|
|
|
|
|
optimization_barrier_p = core.Primitive('optimization_barrier')
|
|
optimization_barrier_p.multiple_results = True
|
|
optimization_barrier_p.def_impl(
|
|
partial(dispatch.apply_primitive, optimization_barrier_p))
|
|
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
|
|
mlir.register_lowering(optimization_barrier_p,
|
|
_optimization_barrier_lowering_rule)
|