2020-10-17 14:33:26 -04:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import builtins
|
2021-08-11 17:32:36 -04:00
|
|
|
import enum
|
2020-10-17 14:33:26 -04:00
|
|
|
import functools
|
2021-09-13 17:24:44 -04:00
|
|
|
from functools import partial
|
2020-10-17 14:33:26 -04:00
|
|
|
import itertools
|
|
|
|
import operator
|
2021-12-09 22:42:04 -05:00
|
|
|
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union
|
2020-10-17 14:33:26 -04:00
|
|
|
import warnings
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
import jax
|
|
|
|
from jax import core
|
2021-06-07 14:51:04 -07:00
|
|
|
from jax._src import ad_util
|
2021-04-13 09:42:54 -07:00
|
|
|
from jax._src import api
|
2021-09-08 09:00:23 -07:00
|
|
|
from jax._src import api_util
|
2021-11-22 08:22:10 -08:00
|
|
|
from jax._src import device_array
|
|
|
|
from jax._src import dispatch
|
2020-10-17 14:33:26 -04:00
|
|
|
from jax import linear_util as lu
|
2021-04-07 19:35:17 -07:00
|
|
|
from jax._src import dtypes
|
2020-11-10 15:57:19 -08:00
|
|
|
from jax import tree_util
|
2021-04-19 08:52:48 -07:00
|
|
|
from jax._src.config import config
|
2021-11-23 16:34:33 -08:00
|
|
|
from jax.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
|
|
|
|
raise_to_shaped, abstract_token, canonicalize_shape)
|
2021-06-02 06:24:42 -07:00
|
|
|
from jax._src.abstract_arrays import array_types
|
2020-10-17 14:33:26 -04:00
|
|
|
from jax.interpreters import partial_eval as pe
|
2021-11-23 18:57:45 -08:00
|
|
|
from jax.interpreters import mlir
|
2020-10-17 14:33:26 -04:00
|
|
|
from jax.interpreters import xla
|
|
|
|
from jax.interpreters import pxla
|
|
|
|
from jax.interpreters import ad
|
|
|
|
from jax.interpreters import invertible_ad as iad
|
|
|
|
from jax.interpreters import batching
|
|
|
|
from jax.interpreters import masking
|
2021-11-23 18:57:45 -08:00
|
|
|
from jax._src import util
|
2021-09-13 17:24:44 -04:00
|
|
|
from jax._src.util import (cache, safe_zip, prod, safe_map, canonicalize_axis,
|
|
|
|
split_list)
|
2020-10-17 14:33:26 -04:00
|
|
|
from jax.tree_util import tree_map
|
2021-09-23 06:33:25 -07:00
|
|
|
import jax._src.lib
|
|
|
|
from jax._src.lib import pytree
|
|
|
|
from jax._src.lib import xla_bridge
|
|
|
|
from jax._src.lib import xla_client
|
2021-11-23 18:57:45 -08:00
|
|
|
from jax._src.lib.mlir import ir
|
|
|
|
from jax._src.lib.mlir.dialects import chlo
|
|
|
|
from jax._src.lib.mlir.dialects import mhlo
|
2021-11-23 16:34:33 -08:00
|
|
|
from jax._src.lax.utils import (
|
|
|
|
_argnum_weak_type,
|
|
|
|
_input_dtype,
|
|
|
|
standard_abstract_eval,
|
|
|
|
standard_multi_result_abstract_eval,
|
|
|
|
standard_named_shape_rule,
|
|
|
|
standard_primitive,
|
|
|
|
)
|
|
|
|
from jax._src.lax import slicing
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Short module aliases for the XLA bridge/client imported above.
xb = xla_bridge
xc = xla_client
xops = xla_client.ops
xe = xla_client._xla

# Handles on the Python builtins: this module defines its own `max` and `min`
# functions further down, which shadow the builtin names at module scope.
_max = builtins.max
_min = builtins.min
# NOTE(review): presumably aliased for the same shadowing reason -- confirm
# against the rest of the file.
_reduce = functools.reduce

# Loose type aliases used in signatures throughout this module.
Array = Any
DType = Any
Shape = core.Shape

T = TypeVar("T")
|
|
|
|
|
2021-11-24 20:27:10 -08:00
|
|
|
@functools.partial(jax.jit, inline=True)
def _array_copy(arr):
  """Produce an on-device copy of ``arr``.

  Private helper; the public route to it is ``jnp.array(x, copy=True)``.

  Copies matter even in a (mostly) purely functional language: because of
  `donate_argnums`, buffers can be consumed, so callers sometimes need a
  guarantee that a distinct buffer exists on device.
  """
  # TODO(jakevdp): XLA has no copy operation, so we lean on an implementation
  # detail: a no-op such as adding zero is optimized away by XLA but still
  # yields a copied buffer. Eventually this should use a more direct method
  # that does not put a spurious add_p/or_p into the jaxpr.
  if arr.dtype != bool:
    return add(arr, _const(arr, 0))
  # Booleans take the bitwise-or-with-False path instead of add.
  return bitwise_or(arr, _const(arr, False))
|
|
|
|
|
2021-12-09 22:42:04 -05:00
|
|
|
def _try_broadcast_shapes(
|
|
|
|
shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]:
|
2020-11-13 14:55:04 -08:00
|
|
|
assert shapes
|
|
|
|
if len(shapes) == 1: return shapes[0]
|
|
|
|
rank, *others = {len(shape) for shape in shapes}
|
|
|
|
if others: return None # must have consistent rank
|
|
|
|
if not rank: return () # scalar case
|
2021-12-09 22:42:04 -05:00
|
|
|
result_shape = [-1] * rank
|
2020-11-13 14:55:04 -08:00
|
|
|
for i, sizes in enumerate(zip(*shapes)):
|
2021-12-10 23:22:11 -08:00
|
|
|
non_1s = {d for d in sizes if not core.symbolic_equal_dim(d, 1)}
|
2021-04-05 12:27:53 +03:00
|
|
|
if len(non_1s) > 1:
|
|
|
|
return None # must have equal sizes other than 1-sized axes
|
|
|
|
result_shape[i] = next(iter(non_1s), 1)
|
2021-04-01 15:37:01 +03:00
|
|
|
|
2020-11-13 14:55:04 -08:00
|
|
|
return tuple(result_shape)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
@cache()
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
  """Returns the shape that results from NumPy broadcasting of `shapes`.

  Shorter shapes are left-padded with 1s up to the largest rank before being
  handed to ``_try_broadcast_shapes``.

  Raises:
    ValueError: if the padded shapes cannot be broadcast together.
  """
  if len(shapes) == 1:
    return shapes[0]
  max_rank = _max(len(shape) for shape in shapes)
  padded = [(1,) * (max_rank - len(shape)) + shape for shape in shapes]
  broadcasted = _try_broadcast_shapes(padded)
  if broadcasted is not None:
    return broadcasted
  raise ValueError("Incompatible shapes for broadcasting: {}"
                   .format(tuple(padded)))
|
|
|
|
|
|
|
|
def _identity(x): return x
|
|
|
|
|
|
|
|
### traceables
|
|
|
|
|
|
|
|
def neg(x: Array) -> Array:
  r"""Elementwise negation: :math:`-x`.

  Thin wrapper: binds ``neg_p`` to ``x`` and returns the result.
  """
  return neg_p.bind(x)
|
|
|
|
|
|
|
|
def sign(x: Array) -> Array:
  r"""Elementwise sign.

  For floating-point inputs, returns
  :math:`\mathrm{sign}(x) = \begin{cases}
  -1 & x < 0\\
  -0 & x = -0\\
  \mathit{NaN} & x = \mathit{NaN}\\
  +0 & x = +0\\
  1 & x > 0
  \end{cases}`

  For signed integer inputs, returns
  :math:`\mathrm{sign}(x) = \begin{cases}
  -1 & x < 0\\
  0 & x = 0\\
  1 & x > 0
  \end{cases}`

  For complex inputs, returns the complex phase, i.e.
  :math:`\mathrm{sign}(x) = \frac{x}{|x|}`.

  Thin wrapper: binds ``sign_p`` to ``x`` and returns the result.
  """
  return sign_p.bind(x)
|
|
|
|
|
|
|
|
def nextafter(x1: Array, x2: Array) -> Array:
  r"""Returns the next representable value after `x1` in the direction of `x2`.

  Note that in some environments flush-denormal-to-zero semantics is used.
  This means that, around zero, this function returns strictly non-zero
  values which appear as zero in any operations. Consider this example::

    >>> jnp.nextafter(0, 1)  # denormal numbers are representable
    DeviceArray(1.e-45, dtype=float32)
    >>> jnp.nextafter(0, 1) * 1  # but are flushed to zero
    DeviceArray(0., dtype=float32)

  For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.

  Thin wrapper: binds ``nextafter_p`` to ``(x1, x2)`` and returns the result.
  """
  return nextafter_p.bind(x1, x2)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def floor(x: Array) -> Array:
  r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.

  Thin wrapper: binds ``floor_p`` to ``x`` and returns the result.
  """
  return floor_p.bind(x)
|
|
|
|
|
|
|
|
def ceil(x: Array) -> Array:
  r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.

  Thin wrapper: binds ``ceil_p`` to ``x`` and returns the result.
  """
  return ceil_p.bind(x)
|
|
|
|
|
2021-11-23 16:34:33 -08:00
|
|
|
class RoundingMethod(enum.IntEnum):
  """Integer-valued choices for the ``rounding_method`` argument of ``round``.

  ``round`` documents this argument as the method used for halfway values
  (e.g. `0.5`).
  """
  # Default used by `round`; the name suggests ties move away from zero.
  AWAY_FROM_ZERO = 0
  # The name suggests ties go to the nearest even integer.
  TO_NEAREST_EVEN = 1
|
|
|
|
|
|
|
|
def round(x: Array,
          rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
          ) -> Array:
  r"""Elementwise round.

  Rounds values to the nearest integer.

  Args:
    x: an array or scalar value to round.
    rounding_method: how halfway values (e.g., `0.5`) are rounded. Anything
      accepted by the ``lax.RoundingMethod`` constructor may be passed; see
      that enum for the available choices.

  Returns:
    An array containing the elementwise rounding of x.
  """
  # Coerce to the enum first so an invalid value fails before binding.
  return round_p.bind(x, rounding_method=RoundingMethod(rounding_method))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def is_finite(x: Array) -> Array:
  r"""Elementwise :math:`\mathrm{isfinite}`.

  For each element x returns `True` if and only if x is not :math:`\pm\infty` or
  :math:`\mathit{NaN}`.

  Thin wrapper: binds ``is_finite_p`` to ``x`` and returns the result.
  """
  return is_finite_p.bind(x)
|
|
|
|
|
|
|
|
def exp(x: Array) -> Array:
  r"""Elementwise exponential: :math:`e^x`.

  Thin wrapper: binds ``exp_p`` to ``x`` and returns the result.
  """
  return exp_p.bind(x)
|
|
|
|
|
|
|
|
def expm1(x: Array) -> Array:
  r"""Elementwise :math:`e^{x} - 1`.

  Thin wrapper: binds ``expm1_p`` to ``x`` and returns the result.
  """
  return expm1_p.bind(x)
|
|
|
|
|
|
|
|
def log(x: Array) -> Array:
  r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  Thin wrapper: binds ``log_p`` to ``x`` and returns the result.
  """
  return log_p.bind(x)
|
|
|
|
|
|
|
|
def log1p(x: Array) -> Array:
  r"""Elementwise :math:`\mathrm{log}(1 + x)`.

  Thin wrapper: binds ``log1p_p`` to ``x`` and returns the result.
  """
  return log1p_p.bind(x)
|
|
|
|
|
|
|
|
def tanh(x: Array) -> Array:
  r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.

  Thin wrapper: binds ``tanh_p`` to ``x`` and returns the result.
  """
  return tanh_p.bind(x)
|
|
|
|
|
|
|
|
def sin(x: Array) -> Array:
  r"""Elementwise sine: :math:`\mathrm{sin}(x)`.

  Thin wrapper: binds ``sin_p`` to ``x`` and returns the result.
  """
  return sin_p.bind(x)
|
|
|
|
|
|
|
|
def cos(x: Array) -> Array:
  r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.

  Thin wrapper: binds ``cos_p`` to ``x`` and returns the result.
  """
  return cos_p.bind(x)
|
|
|
|
|
|
|
|
def atan2(x: Array, y: Array) -> Array:
  r"""Elementwise arc tangent of two variables:
  :math:`\mathrm{atan}({x \over y})`.

  Thin wrapper: binds ``atan2_p`` to ``(x, y)``, in that order, and returns
  the result.
  """
  return atan2_p.bind(x, y)
|
|
|
|
|
|
|
|
def betainc(a: Array, b: Array, x: Array) -> Array:
  r"""Elementwise regularized incomplete beta integral.

  Thin wrapper: binds ``regularized_incomplete_beta_p`` to ``(a, b, x)`` and
  returns the result.
  """
  return regularized_incomplete_beta_p.bind(a, b, x)
|
|
|
|
|
|
|
|
def lgamma(x: Array) -> Array:
  r"""Elementwise log gamma: :math:`\mathrm{log}(\Gamma(x))`.

  Thin wrapper: binds ``lgamma_p`` to ``x`` and returns the result.
  """
  return lgamma_p.bind(x)
|
|
|
|
|
|
|
|
def digamma(x: Array) -> Array:
  r"""Elementwise digamma: :math:`\psi(x)`.

  Thin wrapper: binds ``digamma_p`` to ``x`` and returns the result.
  """
  return digamma_p.bind(x)
|
|
|
|
|
|
|
|
def igamma(a: Array, x: Array) -> Array:
  r"""Elementwise regularized incomplete gamma function.

  Thin wrapper: binds ``igamma_p`` to ``(a, x)`` and returns the result.
  """
  return igamma_p.bind(a, x)
|
|
|
|
|
|
|
|
def igammac(a: Array, x: Array) -> Array:
  r"""Elementwise complementary regularized incomplete gamma function.

  Thin wrapper: binds ``igammac_p`` to ``(a, x)`` and returns the result.
  """
  return igammac_p.bind(a, x)
|
|
|
|
|
|
|
|
def igamma_grad_a(a: Array, x: Array) -> Array:
  r"""Elementwise derivative of the regularized incomplete gamma function.

  Thin wrapper: binds ``igamma_grad_a_p`` to ``(a, x)`` and returns the result.
  """
  return igamma_grad_a_p.bind(a, x)
|
|
|
|
|
|
|
|
def random_gamma_grad(a: Array, x: Array) -> Array:
  r"""Elementwise derivative of samples from `Gamma(a, 1)`.

  Thin wrapper: binds ``random_gamma_grad_p`` to ``(a, x)`` and returns the
  result.
  """
  return random_gamma_grad_p.bind(a, x)
|
|
|
|
|
|
|
|
def bessel_i0e(x: Array) -> Array:
  r"""Exponentially scaled modified Bessel function of order 0:
  :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)`

  Thin wrapper: binds ``bessel_i0e_p`` to ``x`` and returns the result.
  """
  return bessel_i0e_p.bind(x)
|
|
|
|
|
|
|
|
def bessel_i1e(x: Array) -> Array:
  r"""Exponentially scaled modified Bessel function of order 1:
  :math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)`

  Thin wrapper: binds ``bessel_i1e_p`` to ``x`` and returns the result.
  """
  return bessel_i1e_p.bind(x)
|
|
|
|
|
|
|
|
def erf(x: Array) -> Array:
  r"""Elementwise error function: :math:`\mathrm{erf}(x)`.

  Thin wrapper: binds ``erf_p`` to ``x`` and returns the result.
  """
  return erf_p.bind(x)
|
|
|
|
|
|
|
|
def erfc(x: Array) -> Array:
  r"""Elementwise complementary error function:
  :math:`\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)`.

  Thin wrapper: binds ``erfc_p`` to ``x`` and returns the result.
  """
  return erfc_p.bind(x)
|
|
|
|
|
|
|
|
def erf_inv(x: Array) -> Array:
  r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`.

  Thin wrapper: binds ``erf_inv_p`` to ``x`` and returns the result.
  """
  return erf_inv_p.bind(x)
|
|
|
|
|
|
|
|
def real(x: Array) -> Array:
  r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.

  Returns the real part of a complex number.

  Thin wrapper: binds ``real_p`` to ``x`` and returns the result.
  """
  return real_p.bind(x)
|
|
|
|
|
|
|
|
def imag(x: Array) -> Array:
  r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.

  Returns the imaginary part of a complex number.

  Thin wrapper: binds ``imag_p`` to ``x`` and returns the result.
  """
  return imag_p.bind(x)
|
|
|
|
|
|
|
|
def complex(x: Array, y: Array) -> Array:
  r"""Elementwise make complex number: :math:`x + jy`.

  Builds a complex number from real and imaginary parts.

  Args:
    x: the real part.
    y: the imaginary part.

  Returns:
    The result of binding ``complex_p`` to ``(x, y)``.
  """
  return complex_p.bind(x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def conj(x: Array) -> Array:
  r"""Elementwise complex conjugate function: :math:`\overline{x}`.

  Binds ``conj_p`` to ``x`` and returns the result. The operand's dtype, as
  reported by ``_dtype(x)``, is forwarded as the ``input_dtype`` parameter.
  """
  return conj_p.bind(x, input_dtype=_dtype(x))
|
|
|
|
|
|
|
|
def abs(x: Array) -> Array:
  r"""Elementwise absolute value: :math:`|x|`.

  Thin wrapper: binds ``abs_p`` to ``x`` and returns the result.
  """
  return abs_p.bind(x)
|
|
|
|
|
|
|
|
def pow(x: Array, y: Array) -> Array:
  r"""Elementwise power: :math:`x^y`.

  Thin wrapper: binds ``pow_p`` to ``(x, y)`` and returns the result.
  """
  return pow_p.bind(x, y)
|
|
|
|
|
|
|
|
def integer_pow(x: Array, y: int) -> Array:
  r"""Elementwise power: :math:`x^y`, where :math:`y` is a fixed integer.

  Binds ``integer_pow_p`` to ``x`` and returns the result. Unlike ``pow``,
  the exponent ``y`` is passed as a keyword parameter of the primitive rather
  than as a second operand.
  """
  return integer_pow_p.bind(x, y=y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def sqrt(x: Array) -> Array:
  r"""Elementwise square root: :math:`\sqrt{x}`.

  Thin wrapper: binds ``sqrt_p`` to ``x`` and returns the result.
  """
  return sqrt_p.bind(x)
|
|
|
|
|
|
|
|
def rsqrt(x: Array) -> Array:
  r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.

  Thin wrapper: binds ``rsqrt_p`` to ``x`` and returns the result.
  """
  return rsqrt_p.bind(x)
|
|
|
|
|
2021-07-22 14:00:52 -07:00
|
|
|
def cbrt(x: Array) -> Array:
  r"""Elementwise cube root: :math:`\sqrt[3]{x}`.

  Thin wrapper: binds ``cbrt_p`` to ``x`` and returns the result.
  """
  return cbrt_p.bind(x)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def bitwise_not(x: Array) -> Array:
  r"""Elementwise NOT: :math:`\neg x`.

  Thin wrapper: binds ``not_p`` to ``x`` and returns the result.
  """
  return not_p.bind(x)
|
|
|
|
|
|
|
|
def bitwise_and(x: Array, y: Array) -> Array:
  r"""Elementwise AND: :math:`x \wedge y`.

  Thin wrapper: binds ``and_p`` to ``(x, y)`` and returns the result.
  """
  return and_p.bind(x, y)
|
|
|
|
|
|
|
|
def bitwise_or(x: Array, y: Array) -> Array:
  r"""Elementwise OR: :math:`x \vee y`.

  Thin wrapper: binds ``or_p`` to ``(x, y)`` and returns the result.
  """
  return or_p.bind(x, y)
|
|
|
|
|
|
|
|
def bitwise_xor(x: Array, y: Array) -> Array:
  r"""Elementwise exclusive OR: :math:`x \oplus y`.

  Thin wrapper: binds ``xor_p`` to ``(x, y)`` and returns the result.
  """
  return xor_p.bind(x, y)
|
|
|
|
|
|
|
|
def population_count(x: Array) -> Array:
  r"""Elementwise popcount, count the number of set bits in each element.

  Thin wrapper: binds ``population_count_p`` to ``x`` and returns the result.
  """
  return population_count_p.bind(x)
|
|
|
|
|
2021-03-19 22:35:31 -07:00
|
|
|
def clz(x: Array) -> Array:
  r"""Elementwise count-leading-zeros.

  Thin wrapper: binds ``clz_p`` to ``x`` and returns the result.
  """
  return clz_p.bind(x)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def add(x: Array, y: Array) -> Array:
  r"""Elementwise addition: :math:`x + y`.

  Thin wrapper: binds ``add_p`` to ``(x, y)`` and returns the result.
  """
  return add_p.bind(x, y)
|
|
|
|
|
|
|
|
def sub(x: Array, y: Array) -> Array:
  r"""Elementwise subtraction: :math:`x - y`.

  Thin wrapper: binds ``sub_p`` to ``(x, y)`` and returns the result.
  """
  return sub_p.bind(x, y)
|
|
|
|
|
|
|
|
def mul(x: Array, y: Array) -> Array:
  r"""Elementwise multiplication: :math:`x \times y`.

  Thin wrapper: binds ``mul_p`` to ``(x, y)`` and returns the result.
  """
  return mul_p.bind(x, y)
|
|
|
|
|
|
|
|
def div(x: Array, y: Array) -> Array:
  r"""Elementwise division: :math:`x \over y`.

  Thin wrapper: binds ``div_p`` to ``(x, y)`` and returns the result.
  """
  return div_p.bind(x, y)
|
|
|
|
|
|
|
|
def rem(x: Array, y: Array) -> Array:
  r"""Elementwise remainder: :math:`x \bmod y`.

  Thin wrapper: binds ``rem_p`` to ``(x, y)`` and returns the result.
  """
  return rem_p.bind(x, y)
|
|
|
|
|
|
|
|
def max(x: Array, y: Array) -> Array:
  r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`

  For complex numbers, uses a lexicographic comparison on the
  `(real, imaginary)` pairs.

  Thin wrapper: binds ``max_p`` to ``(x, y)`` and returns the result. Note that
  this definition shadows the Python builtin ``max`` at module scope.
  """
  return max_p.bind(x, y)
|
|
|
|
|
|
|
|
def min(x: Array, y: Array) -> Array:
  r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`

  For complex numbers, uses a lexicographic comparison on the
  `(real, imaginary)` pairs.

  Thin wrapper: binds ``min_p`` to ``(x, y)`` and returns the result. Note that
  this definition shadows the Python builtin ``min`` at module scope.
  """
  return min_p.bind(x, y)
|
|
|
|
|
|
|
|
def shift_left(x: Array, y: Array) -> Array:
  r"""Elementwise left shift: :math:`x \ll y`.

  Thin wrapper: binds ``shift_left_p`` to ``(x, y)`` and returns the result.
  """
  return shift_left_p.bind(x, y)
|
|
|
|
|
|
|
|
def shift_right_arithmetic(x: Array, y: Array) -> Array:
  r"""Elementwise arithmetic right shift: :math:`x \gg y`.

  Thin wrapper: binds ``shift_right_arithmetic_p`` to ``(x, y)`` and returns
  the result.
  """
  return shift_right_arithmetic_p.bind(x, y)
|
|
|
|
|
|
|
|
def shift_right_logical(x: Array, y: Array) -> Array:
  r"""Elementwise logical right shift: :math:`x \gg y`.

  Thin wrapper: binds ``shift_right_logical_p`` to ``(x, y)`` and returns the
  result. The docstring notation matches ``shift_right_arithmetic``; the two
  functions differ in the primitive they bind.
  """
  return shift_right_logical_p.bind(x, y)
|
|
|
|
|
|
|
|
def eq(x: Array, y: Array) -> Array:
  r"""Elementwise equals: :math:`x = y`.

  Thin wrapper: binds ``eq_p`` to ``(x, y)`` and returns the result.
  """
  return eq_p.bind(x, y)
|
|
|
|
|
|
|
|
def ne(x: Array, y: Array) -> Array:
  r"""Elementwise not-equals: :math:`x \neq y`.

  Thin wrapper: binds ``ne_p`` to ``(x, y)`` and returns the result.
  """
  return ne_p.bind(x, y)
|
|
|
|
|
|
|
|
def ge(x: Array, y: Array) -> Array:
  r"""Elementwise greater-than-or-equals: :math:`x \geq y`.

  Thin wrapper: binds ``ge_p`` to ``(x, y)`` and returns the result.
  """
  return ge_p.bind(x, y)
|
|
|
|
|
|
|
|
def gt(x: Array, y: Array) -> Array:
  r"""Elementwise greater-than: :math:`x > y`.

  Thin wrapper: binds ``gt_p`` to ``(x, y)`` and returns the result.
  """
  return gt_p.bind(x, y)
|
|
|
|
|
|
|
|
def le(x: Array, y: Array) -> Array:
  r"""Elementwise less-than-or-equals: :math:`x \leq y`.

  Thin wrapper: binds ``le_p`` to ``(x, y)`` and returns the result.
  """
  return le_p.bind(x, y)
|
|
|
|
|
|
|
|
def lt(x: Array, y: Array) -> Array:
  r"""Elementwise less-than: :math:`x < y`.

  Thin wrapper: binds ``lt_p`` to ``(x, y)`` and returns the result.
  """
  return lt_p.bind(x, y)
|
|
|
|
|
2021-03-28 10:32:02 -07:00
|
|
|
def convert_element_type(operand: Array, new_dtype: DType) -> Array:
  """Elementwise cast.

  Wraps XLA's `ConvertElementType
  <https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
  operator, which performs an elementwise conversion from one type to another.
  Similar to a C++ `static_cast`.

  Args:
    operand: an array or scalar value to be cast. Objects exposing a
      ``__jax_array__`` method are first replaced by that method's result.
    new_dtype: a NumPy dtype representing the target type.

  Returns:
    An array with the same shape as `operand`, cast elementwise to `new_dtype`.
  """
  # Unwrap array-like objects that opt in via the `__jax_array__` hook.
  if hasattr(operand, '__jax_array__'):
    operand = operand.__jax_array__()
  # The public entry point always requests a non-weak result type.
  return _convert_element_type(operand, new_dtype, weak_type=False)
|
2021-02-08 13:37:25 -08:00
|
|
|
|
2021-03-28 10:32:02 -07:00
|
|
|
def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
                          weak_type: bool = False):
  """Cast `operand` to `new_dtype`, with control over the result's weak type.

  Args:
    operand: the value to cast.
    new_dtype: the target dtype; ``None`` means "keep the operand's dtype".
      The target is canonicalized with ``dtypes.dtype(..., canonicalize=True)``.
    weak_type: coerced with ``bool`` and forwarded as the ``weak_type``
      parameter of ``convert_element_type_p``.

  Returns:
    `operand` itself when it is already a ``core.Tracer`` or ``DeviceArray``
    whose (dtype, weak_type) pair equals the requested one; otherwise the
    result of binding ``convert_element_type_p``.

  Emits ``np.ComplexWarning`` when a complex operand is cast to a non-complex
  dtype.
  """
  # Don't canonicalize old_dtype because x64 context might cause
  # un-canonicalized operands to be passed in.
  old_dtype = dtypes.dtype(operand, canonicalize=False)
  old_weak_type = dtypes.is_weakly_typed(operand)

  if new_dtype is None:
    new_dtype = old_dtype
  else:
    new_dtype = np.dtype(new_dtype)
  new_dtype = dtypes.dtype(new_dtype, canonicalize=True)
  new_weak_type = bool(weak_type)

  if (dtypes.issubdtype(old_dtype, np.complexfloating) and
      not dtypes.issubdtype(new_dtype, np.complexfloating)):
    msg = "Casting complex values to real discards the imaginary part"
    warnings.warn(msg, np.ComplexWarning, stacklevel=2)

  # Python has big integers, but convert_element_type(2 ** 100, np.float32) need
  # not be an error since the target dtype fits the value. Handle this case by
  # converting to a NumPy array before calling bind. Without this step, we'd
  # first canonicalize the input to a value of dtype int32 or int64, leading to
  # an overflow error.
  if type(operand) is int:
    operand = np.asarray(operand, new_dtype)

  # Nothing to do when the operand already has the requested dtype/weak-type;
  # the shortcut only applies to tracers and device arrays, so every other
  # input still goes through bind.
  if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
    return operand
  else:
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
                                       weak_type=new_weak_type)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def bitcast_convert_type(operand: Array, new_dtype: DType) -> Array:
  """Elementwise bitcast.

  Wraps XLA's `BitcastConvertType
  <https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
  operator, which reinterprets the bits of one type as another. The source
  and destination types must have the same bitwidth.

  Args:
    operand: an array or scalar value to be cast
    new_dtype: the new type. Should be a NumPy type; it is canonicalized
      before being handed to the primitive.

  Returns:
    An array with the same shape as `operand`, bitcast elementwise to
    `new_dtype`.
  """
  canonical_dtype = dtypes.canonicalize_dtype(new_dtype)
  return bitcast_convert_type_p.bind(operand, new_dtype=canonical_dtype)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def clamp(min: Array, x: Array, max: Array) -> Array:
  r"""Elementwise clamp.

  Returns :math:`\mathrm{clamp}(x) = \begin{cases}
  \mathit{min} & \text{if } x < \mathit{min},\\
  \mathit{max} & \text{if } x > \mathit{max},\\
  x & \text{otherwise}
  \end{cases}`.

  Thin wrapper: binds ``clamp_p`` to ``(min, x, max)``, in that order, and
  returns the result. The parameter names ``min`` and ``max`` shadow this
  module's functions of the same name inside the body.
  """
  return clamp_p.bind(min, x, max)
|
|
|
|
|
|
|
|
def concatenate(operands: Sequence[Array], dimension: int) -> Array:
  """Concatenates a sequence of arrays along `dimension`.

  Wraps XLA's `Concatenate
  <https://www.tensorflow.org/xla/operation_semantics#concatenate>`_
  operator.

  Args:
    operands: a sequence of arrays to concatenate. The arrays must have equal
      shapes, except in the `dimension` axis.
    dimension: the dimension along which to concatenate the arrays.

  Returns:
    An array containing the concatenation.

  Raises:
    ValueError: if `operands` is empty.
  """
  # `len` rather than truthiness: `operands` may be an array-like whose truth
  # value is ambiguous.
  if len(operands) == 0:
    raise ValueError("concatenate requires a non-empty sequence of arrays")
  return concatenate_p.bind(*operands, dimension=dimension)
|
|
|
|
|
2021-11-08 14:15:31 -08:00
|
|
|
|
|
|
|
class _enum_descriptor(object):
|
|
|
|
def __init__(self, val):
|
|
|
|
self.val = val
|
|
|
|
def __get__(self, _, owner):
|
|
|
|
return owner(self.val)
|
|
|
|
|
|
|
|
|
|
|
|
class Precision(xla_client.PrecisionConfig.Precision):  # type: ignore
  """Precision enum for lax functions

  The `precision` argument to JAX functions generally controls the tradeoff
  between speed and accuracy for array computations on accelerator backends,
  (i.e. TPU and GPU). Members are:

  DEFAULT:
    Fastest mode, but least accurate. Performs computations in bfloat16.
    Aliases: ``'default'``, ``'fastest'``, ``'bfloat16'``.
  HIGH:
    Slower but more accurate. Performs float32 computations in 3 bfloat16
    passes, or using tensorfloat32 where available. Aliases: ``'high'``,
    ``'bfloat16_3x'``, ``'tensorfloat32'``.
  HIGHEST:
    Slowest but most accurate. Performs computations in float32 or float64
    as applicable. Aliases: ``'highest'``, ``'float32'``.
  """
  # Wrap enum values with this class.
  DEFAULT = _enum_descriptor('default')
  HIGH = _enum_descriptor('high')
  HIGHEST = _enum_descriptor('highest')

  # Accepted aliases (strings and None) mapped to the underlying XLA enum
  # value; consulted by __init__.
  _strings = {
      'highest':       xla_client.PrecisionConfig.Precision.HIGHEST,
      'float32':       xla_client.PrecisionConfig.Precision.HIGHEST,
      'high':          xla_client.PrecisionConfig.Precision.HIGH,
      'bfloat16_3x':   xla_client.PrecisionConfig.Precision.HIGH,
      'tensorfloat32': xla_client.PrecisionConfig.Precision.HIGH,
      'default':       xla_client.PrecisionConfig.Precision.DEFAULT,
      'bfloat16':      xla_client.PrecisionConfig.Precision.DEFAULT,
      'fastest':       xla_client.PrecisionConfig.Precision.DEFAULT,
      None:            xla_client.PrecisionConfig.Precision.DEFAULT,
  }
  def __init__(self, arg0):
    # Translate a known alias; any other value is passed to the base class
    # constructor unchanged.
    arg0 = self._strings.get(arg0, arg0)
    super().__init__(arg0)

  def __str__(self) -> str:
    # Render as the bare member name.
    return self.name
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
# Loose alias used where a precision enum value is expected.
PrecisionType = Any
# Everything accepted as a `precision` argument: None, a string alias, an enum
# value, or a pair of either kind.
PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
                      Tuple[PrecisionType, PrecisionType]]
|
2021-11-08 14:15:31 -08:00
|
|
|
|
2020-12-10 02:29:40 +00:00
|
|
|
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
        preferred_element_type: Optional[DType] = None) -> Array:
  """Vector/vector, matrix/vector, and matrix/matrix multiplication.

  Wraps XLA's `Dot
  <https://www.tensorflow.org/xla/operation_semantics#dot>`_
  operator.

  For more general contraction, see the `dot_general` operator.

  Args:
    lhs: an array of rank 1 or 2.
    rhs: an array of rank 1 or 2.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      :class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    An array containing the product.

  Raises:
    TypeError: if either operand is not of rank 1 or 2, or the last dimension
      of ``lhs`` does not equal the first dimension of ``rhs``.
  """
  ranks_ok = 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2
  # Sizes are only compared once both ranks are valid, so rank-0 operands
  # never get indexed.
  if not (ranks_ok and core.symbolic_equal_dim(lhs.shape[-1], rhs.shape[0])):
    raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        lhs.shape, rhs.shape))
  # Contract the last axis of lhs with the first axis of rhs; no batch axes.
  contracting = ((lhs.ndim - 1,), (0,))
  batch = ((), ())
  return dot_general(lhs, rhs, (contracting, batch),
                     precision=precision,
                     preferred_element_type=preferred_element_type)
|
|
|
|
|
|
|
|
|
|
|
|
# Layout unpacked by `dot_general`:
#   ((lhs_contracting_dims, rhs_contracting_dims),
#    (lhs_batch_dims, rhs_batch_dims))
DotDimensionNumbers = Tuple[Tuple[Sequence[int], Sequence[int]],
                            Tuple[Sequence[int], Sequence[int]]]
|
|
|
|
|
|
|
|
def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
                precision: PrecisionLike = None,
                preferred_element_type: Optional[DType] = None) -> Array:
  """More general contraction operator.

  Wraps XLA's `DotGeneral
  <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
  operator.

  Args:
    lhs: an array
    rhs: an array
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      :class:`~jax.lax.Precision` enums indicating precision of ``lhs`` and ``rhs``.
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    An array containing the result.
  """
  contracting, batch = dimension_numbers
  lhs_contract, rhs_contract = contracting
  lhs_batch, rhs_batch = batch
  as_index_tuple = api_util._ensure_index_tuple
  # Normalize every group of dimension indices to a tuple before binding.
  cdims = (as_index_tuple(lhs_contract), as_index_tuple(rhs_contract))
  bdims = (as_index_tuple(lhs_batch), as_index_tuple(rhs_batch))
  if preferred_element_type is not None:
    preferred_element_type = dtypes.canonicalize_dtype(
        np.dtype(preferred_element_type))
  canonical_precision = canonicalize_precision(precision)
  return dot_general_p.bind(lhs, rhs,
                            dimension_numbers=(cdims, bdims),
                            precision=canonical_precision,
                            preferred_element_type=preferred_element_type)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def broadcast(operand: Array, sizes: Sequence[int]) -> Array:
  """Broadcasts an array, adding new leading dimensions

  Args:
    operand: an array
    sizes: a sequence of integers, giving the sizes of new leading dimensions
      to add to the front of the array.

  Returns:
    An array containing the result.

  See Also:
    jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
  """
  n_leading = len(sizes)
  # The operand's own axes land right after the newly added leading axes.
  operand_axes = tuple(range(n_leading, n_leading + np.ndim(operand)))
  target_shape = tuple(sizes) + np.shape(operand)
  return broadcast_in_dim(operand, target_shape, operand_axes)
|
|
|
|
|
|
|
|
def broadcast_in_dim(operand: Array, shape: Shape,
                     broadcast_dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `BroadcastInDim
  <https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
  operator.

  Args:
    operand: an array
    shape: the shape of the target array
    broadcast_dimensions: which dimension in the target shape each dimension
      of the operand shape corresponds to

  Returns:
    An array containing the result.

  See Also:
    jax.lax.broadcast : simpler interface to add new leading dimensions.
  """
  # Run the shape rule eagerly; its result replaces `shape` for the rest of
  # the function.
  shape = _broadcast_in_dim_shape_rule(
      operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
  # Shortcut: hand the operand back untouched when the ranks already agree,
  # no broadcast dimensions were given, and the operand is a DeviceArray or
  # Tracer.
  if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
      and isinstance(operand, (device_array.DeviceArray, core.Tracer))):
    return operand
  return broadcast_in_dim_p.bind(
      operand, shape=tuple(shape),
      broadcast_dimensions=tuple(broadcast_dimensions))
|
|
|
|
|
|
|
|
def broadcast_to_rank(x: Array, rank: int) -> Array:
  """Prepends size-``1`` dimensions to ``x`` so that it has rank ``rank``."""
  n_missing = rank - x.ndim
  return broadcast(x, (1,) * n_missing)
|
|
|
|
|
|
|
|
def reshape(operand: Array, new_sizes: Shape,
            dimensions: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `Reshape
  <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
  operator.

  For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
  ``lax.expand_dims``. These preserve information about axis identity that may
  be useful for advanced transformation rules.

  Args:
    operand: array to be reshaped.
    new_sizes: sequence of integers specifying the resulting shape. The size
      of the final array must match the size of the input.
    dimensions: optional sequence of integers specifying the permutation order of
      the input shape. If specified, the length must match ``operand.shape``.

  Returns:
    out: reshaped array.

  Examples:
    Simple reshaping from one to two dimensions:

    >>> x = jnp.arange(6)
    >>> y = reshape(x, (2, 3))
    >>> y
    DeviceArray([[0, 1, 2],
                 [3, 4, 5]], dtype=int32)

    Reshaping back to one dimension:

    >>> reshape(y, (6,))
    DeviceArray([0, 1, 2, 3, 4, 5], dtype=int32)

    Reshaping to one dimension with permutation of dimensions:

    >>> reshape(y, (6,), (1, 0))
    DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
  """
  new_sizes = tuple(canonicalize_shape(new_sizes))  # TODO
  same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)

  if dimensions is None:
    dims, same_dims = None, True
  else:
    dims = api_util._ensure_index_tuple(dimensions)
    # An identity permutation is equivalent to passing no permutation at all.
    same_dims = tuple(dims) == tuple(range(np.ndim(operand)))

  # Shortcut: a non-scalar JAX value whose shape and axis order already match
  # is returned untouched instead of binding the primitive.
  if (np.shape(operand) and same_shape and same_dims
      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
    return operand

  # `dims` is None exactly when `dimensions` was None, in which case
  # `same_dims` is True, so one test covers both "no permutation" cases.
  return reshape_p.bind(
      operand, new_sizes=new_sizes,
      dimensions=None if same_dims else dims)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def pad(operand: Array, padding_value: Array,
        padding_config: Sequence[Tuple[int, int, int]]) -> Array:
  """Applies low, high, and/or interior padding to an array.

  Wraps XLA's `Pad
  <https://www.tensorflow.org/xla/operation_semantics#pad>`_
  operator.

  Args:
    operand: an array to be padded.
    padding_value: the value to be inserted as padding. Must have the same dtype
      as ``operand``.
    padding_config: a sequence of ``(low, high, interior)`` tuples of integers,
      one per dimension, giving the amount of low, high, and interior
      (dilation) padding to insert in that dimension.

  Returns:
    The ``operand`` array with padding value ``padding_value`` inserted in each
    dimension according to the ``padding_config``.
  """
  config = tuple(padding_config)
  return pad_p.bind(operand, padding_value, padding_config=config)
|
|
|
|
|
|
|
|
def rev(operand: Array, dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `Rev
  <https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_
  operator.

  ``dimensions`` is converted to a tuple before being bound to the primitive.
  """
  dims = tuple(dimensions)
  return rev_p.bind(operand, dimensions=dims)
|
|
|
|
|
|
|
|
def select(pred: Array, on_true: Array, on_false: Array) -> Array:
  """Wraps XLA's `Select
  <https://www.tensorflow.org/xla/operation_semantics#select>`_
  operator.

  The three arguments are bound to ``select_p`` in the order
  ``(pred, on_true, on_false)``.
  """
  operands = (pred, on_true, on_false)
  return select_p.bind(*operands)
|
|
|
|
|
|
|
|
def transpose(operand: Array, permutation: Sequence[int]) -> Array:
  """Wraps XLA's `Transpose
  <https://www.tensorflow.org/xla/operation_semantics#transpose>`_
  operator.

  An identity permutation applied to a tracer or device array returns the
  operand itself without binding the primitive.
  """
  permutation = tuple(operator.index(d) for d in permutation)
  is_identity = permutation == tuple(range(np.ndim(operand)))
  if is_identity and isinstance(operand,
                                (core.Tracer, device_array.DeviceArray)):
    return operand
  return transpose_p.bind(operand, permutation=permutation)
|
|
|
|
|
|
|
|
def argmin(operand: Array, axis: int,
           index_dtype: DType) -> Array:
  """Computes the index of the minimum element along ``axis``.

  Args:
    operand: the array to search.
    axis: the single axis along which to look for the minimum.
    index_dtype: dtype of the returned indices; it is canonicalized before
      being passed to the primitive.

  Returns:
    The result of binding ``argmin_p``: one array of indices (not a tuple).
  """
  canonical_index_dtype = dtypes.canonicalize_dtype(index_dtype)
  return argmin_p.bind(operand, axes=(axis,),
                       index_dtype=canonical_index_dtype)
|
|
|
|
|
|
|
|
def argmax(operand: Array, axis: int,
           index_dtype: DType) -> Array:
  """Computes the index of the maximum element along ``axis``.

  Args:
    operand: the array to search.
    axis: the single axis along which to look for the maximum.
    index_dtype: dtype of the returned indices; it is canonicalized before
      being passed to the primitive.

  Returns:
    The result of binding ``argmax_p``: one array of indices (not a tuple).
  """
  canonical_index_dtype = dtypes.canonicalize_dtype(index_dtype)
  return argmax_p.bind(operand, axes=(axis,),
                       index_dtype=canonical_index_dtype)
|
|
|
|
|
2021-12-09 22:42:04 -05:00
|
|
|
def reduce(operands: Any,
           init_values: Any,
           computation: Callable[[Any, Any], Any],
           dimensions: Sequence[int]) -> Any:
  """Wraps XLA's `Reduce
  <https://www.tensorflow.org/xla/operation_semantics#reduce>`_
  operator.

  ``init_values`` and ``computation`` together must form a `monoid
  <https://en.wikipedia.org/wiki/Monoid>`_
  for correctness. That is ``init_values`` must be an identity of
  ``computation``, and ``computation`` must be associative. XLA may exploit both
  of these properties during code generation; if either is violated the result
  is undefined.

  ``operands`` and ``init_values`` may be pytrees, but they must share the same
  tree structure; a ``ValueError`` is raised otherwise.
  """
  flat_operands, operand_tree = tree_util.tree_flatten(operands)
  flat_init_values, init_value_tree = tree_util.tree_flatten(init_values)
  if operand_tree != init_value_tree:
    raise ValueError('Operands must have the same tree structure as init_values:'
                     f' {operand_tree} vs. {init_value_tree}')
  if len(flat_operands) != len(flat_init_values):
    raise ValueError('Must have same total number of operands as init_values: '
                     f' {len(flat_operands)} vs. {len(flat_init_values)}')

  monoid_reducer = _get_monoid_reducer(computation, flat_init_values)
  if not monoid_reducer:
    # General path: trace `computation` to a jaxpr and bind the generic
    # variadic reduction primitive.
    flat_init_avals = safe_map(_abstractify, flat_init_values)
    jaxpr, consts, out_tree = _variadic_reduction_jaxpr(
        computation, tuple(flat_init_avals), init_value_tree)
    out = reduce_p.bind(*flat_operands, *flat_init_values,
                        computation=computation, jaxpr=jaxpr, consts=consts,
                        dimensions=tuple(dimensions))
    return tree_util.tree_unflatten(out_tree, out)

  # monoid reducers bypass the weak_type_rule, so we set it explicitly.
  weak_type = (dtypes.is_weakly_typed(*flat_operands)
               and dtypes.is_weakly_typed(*flat_init_values))
  reduced = monoid_reducer(*flat_operands, dimensions)
  return _convert_element_type(reduced, weak_type=weak_type)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
@cache()
def _reduction_jaxpr(computation, aval):
  """Traces the binary ``computation`` to a jaxpr.

  Both arguments are traced as unknown values of abstract type ``aval``.
  Returns the ``(jaxpr, consts)`` pair; raises ``ValueError`` if
  ``computation`` returns something that is neither a tracer nor a valid JAX
  type.
  """
  unknown = pe.PartialVal.unknown(aval)

  def comp(x, y):
    result = computation(x, y)
    is_array_like = (isinstance(result, core.Tracer)
                     or core.valid_jaxtype(result))
    if not is_array_like:
      raise ValueError(
          f"Invalid return type from reduction function: {type(result)}\n"
          f"Reduction functions should only return an array.\n"
          f"Full return value: {result}")
    # The tracer expects a sequence of outputs.
    return (result,)

  wrapped = lu.wrap_init(comp)
  jaxpr, _, consts = pe.trace_to_jaxpr(wrapped, (unknown, unknown),
                                       instantiate=False)
  return jaxpr, consts
|
|
|
|
|
2020-11-10 15:57:19 -08:00
|
|
|
@cache()
def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
  """Traces a pytree-valued binary ``computation`` to a jaxpr.

  ``flat_avals``/``aval_tree`` describe one argument; the computation is
  traced on two arguments of that structure. Returns
  ``(jaxpr, consts, out_tree)`` with ``consts`` as a tuple.
  """
  arg_avals = tree_util.tree_unflatten(aval_tree, flat_avals)
  # The computation takes two arguments with identical structure.
  flat_in_avals, in_tree = tree_util.tree_flatten((arg_avals, arg_avals))
  wrapped = lu.wrap_init(computation)
  flat_comp, out_tree = api_util.flatten_fun_nokwargs(wrapped, in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
  return jaxpr, tuple(consts), out_tree()
|
2020-11-10 15:57:19 -08:00
|
|
|
|
2021-11-22 13:20:55 -08:00
|
|
|
def _get_monoid_reducer(monoid_op: Callable,
                        xs: Sequence[Array]) -> Optional[Callable]:
  """Returns a specialized reducer for a recognized (op, identity) pair.

  Args:
    monoid_op: the binary reduction function passed by the caller.
    xs: the flattened initial values of the reduction.

  Returns:
    One of the ``_reduce_*`` helpers when there is exactly one initial value,
    that value is a concrete scalar, ``monoid_op`` is one of the known lax
    operations and the initial value equals that operation's identity.
    Otherwise ``None``.
  """
  if len(xs) != 1:
    return None
  x, = xs
  aval = core.get_aval(x)
  dtype = _dtype(x)
  # Only a concrete scalar can be compared against the identity element.
  if type(aval) is not ConcreteArray or aval.shape != ():
    return None

  if monoid_op is add:
    identity, reducer = 0, _reduce_sum
  elif monoid_op is mul:
    identity, reducer = 1, _reduce_prod
  elif monoid_op is bitwise_or and dtype == np.bool_:
    identity, reducer = _get_max_identity(dtype), _reduce_or
  elif monoid_op is bitwise_and and dtype == np.bool_:
    identity, reducer = _get_min_identity(dtype), _reduce_and
  elif monoid_op is max:
    identity, reducer = _get_max_identity(dtype), _reduce_max
  elif monoid_op is min:
    identity, reducer = _get_min_identity(dtype), _reduce_min
  else:
    return None

  # Return None (never a falsy NumPy bool) when the init value is not the
  # identity, so the result matches the Optional[Callable] annotation.
  return reducer if np.equal(aval.val, identity) else None
|
|
|
|
|
|
|
|
def _get_max_identity(dtype: DType) -> Array:
  """Returns the identity element of ``max`` for ``dtype`` as a 0-d array.

  That is ``-inf`` for inexact dtypes, the smallest representable value for
  integer dtypes and ``False`` for booleans. Any other dtype yields ``None``.
  """
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(-np.inf, dtype)
  if dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).min, dtype)
  if dtypes.issubdtype(dtype, np.bool_):
    return np.array(False, np.bool_)
  return None
|
|
|
|
|
|
|
|
def _get_min_identity(dtype: DType) -> Array:
  """Returns the identity element of ``min`` for ``dtype`` as a 0-d array.

  That is ``inf`` for inexact dtypes, the largest representable value for
  integer dtypes and ``True`` for booleans. Any other dtype yields ``None``.
  """
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(np.inf, dtype)
  if dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).max, dtype)
  if dtypes.issubdtype(dtype, np.bool_):
    return np.array(True, np.bool_)
  return None
|
|
|
|
|
|
|
|
def _reduce_sum(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_sum_p`` on ``operand`` over ``axes`` (passed as a tuple)."""
  reduction_axes = tuple(axes)
  return reduce_sum_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
|
|
def _reduce_prod(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_prod_p`` on ``operand`` over ``axes`` (passed as a tuple)."""
  reduction_axes = tuple(axes)
  return reduce_prod_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
|
|
def _reduce_max(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_max_p`` on ``operand`` over ``axes`` (passed as a tuple)."""
  reduction_axes = tuple(axes)
  return reduce_max_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
|
|
def _reduce_min(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_min_p`` on ``operand`` over ``axes`` (passed as a tuple)."""
  reduction_axes = tuple(axes)
  return reduce_min_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
|
|
def _reduce_or(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_or_p`` on ``operand`` over ``axes`` (passed as a tuple)."""
  reduction_axes = tuple(axes)
  return reduce_or_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
|
|
def _reduce_and(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_and_p`` on ``operand`` over ``axes`` (passed as a tuple)."""
  reduction_axes = tuple(axes)
  return reduce_and_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
|
|
|
|
|
|
|
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
  """Wraps XLA's `Sort
  <https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator.

  For floating point inputs, -0.0 and 0.0 are treated as equivalent, and NaN values
  are sorted to the end of the array. For complex inputs, the sort order is
  lexicographic over the real and imaginary parts, with the real part primary.

  Args:
    operand : Array or sequence of arrays
    dimension : integer dimension along which to sort. Default: -1.
    is_stable : boolean specifying whether to use a stable sort. Default: True.
    num_keys : number of operands to treat as sort keys. Default: 1.
      For num_keys > 1, the sort order will be determined lexicographically using
      the first `num_keys` arrays, with the first key being primary.
      The remaining operands will be returned with the same permutation.

  Returns:
    operand : sorted version of the input or inputs. A sequence input yields
    a tuple of arrays; a single array input yields a single array.
  """
  if not isinstance(operand, Sequence):
    # Single-array case: only one key makes sense.
    if num_keys != 1:
      raise ValueError(f"num_keys={num_keys} must equal 1 for a single operand.")
    dimension = canonicalize_axis(dimension, len(operand.shape))
    (sorted_operand,) = sort_p.bind(operand, dimension=dimension,
                                    is_stable=is_stable, num_keys=1)[:1]
    return sorted_operand

  if len(operand) == 0:
    raise TypeError("Sort requires at least one operand")
  if not (1 <= num_keys <= len(operand)):
    raise ValueError(f"num_keys={num_keys} must be between 1 and len(operand)={len(operand)}")
  # The rank of the first operand is used to normalize the sort dimension.
  dimension = canonicalize_axis(dimension, len(operand[0].shape))
  results = sort_p.bind(*operand, dimension=dimension,
                        is_stable=is_stable,
                        num_keys=num_keys)
  return tuple(results)
|
|
|
|
|
|
|
|
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
                 is_stable: bool = True) -> Tuple[Array, Array]:
  """Sorts ``keys`` along ``dimension`` and applies the same permutation to ``values``."""
  axis = canonicalize_axis(dimension, len(keys.shape))
  # Only `keys` acts as a sort key; `values` is carried along.
  sorted_keys, permuted_values = sort_p.bind(
      keys, values, dimension=axis, is_stable=is_stable, num_keys=1)
  return sorted_keys, permuted_values
|
|
|
|
|
|
|
|
def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
  """Returns top ``k`` values and their indices along the last axis of ``operand``.

  ``k`` is converted with ``int`` and must be nonnegative, otherwise a
  ``ValueError`` is raised.
  """
  k = int(k)
  if not k < 0:
    return top_k_p.bind(operand, k=k)
  raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
|
|
|
|
|
|
|
|
def tie_in(x: Array, y: Array) -> Array:
  """Deprecated. Ignores ``x`` and returns ``y``."""
  del x  # intentionally unused
  return y
|
|
|
|
|
|
|
|
def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array:
  """Returns an array of `shape` filled with `fill_value`.

  Args:
    shape: sequence of integers, describing the shape of the output array.
    fill_value: the scalar value to fill the new array with. A non-scalar
      value raises ``TypeError``.
    dtype: the type of the output array, or `None`. If not `None`, `fill_value`
      will be cast to `dtype`.
  """
  shape = canonicalize_shape(shape)
  value_shape = np.shape(fill_value)
  if value_shape:
    msg = "full must be called with scalar fill_value, got fill_value.shape {}."
    raise TypeError(msg.format(value_shape))
  # Weak typing is only kept when the caller did not request a dtype.
  weak_type = dtype is None and dtypes.is_weakly_typed(fill_value)
  dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
  scalar = _convert_element_type(fill_value, dtype, weak_type)
  return broadcast(scalar, shape)
|
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
def _device_put_raw(x, weak_type=None):
  """Returns ``x`` as a device array.

  A ``DeviceArray`` is returned unchanged. Anything else is put on device via
  ``dispatch.device_put`` and wrapped by the array result handler built from
  its shaped abstract value (with ``weak_type`` applied).
  """
  if not isinstance(x, device_array.DeviceArray):
    aval = raise_to_shaped(core.get_aval(x), weak_type=weak_type)
    handler = dispatch.array_result_handler(None, aval)
    return handler(*dispatch.device_put(x))
  return x
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-12-09 22:42:04 -05:00
|
|
|
def zeros_like_shaped_array(aval: Array) -> Array:
  """Builds an array of zeros with the shape and dtype of the ShapedArray ``aval``."""
  assert isinstance(aval, ShapedArray)
  if aval.dtype != dtypes.float0:
    scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
  else:
    # float0 zeros are created directly with NumPy rather than via
    # _convert_element_type.
    scalar_zero = np.zeros((), dtype=aval.dtype)
  return broadcast(scalar_zero, aval.shape)

# Register the function above as the zeros-like constructor for ShapedArray avals.
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def iota(dtype: DType, size: int) -> Array:
  """Wraps XLA's `Iota
  <https://www.tensorflow.org/xla/operation_semantics#iota>`_
  operator.

  Binds a one-dimensional iota of shape ``(size,)`` along dimension 0, after
  canonicalizing ``dtype`` and ``size``.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  (size,) = canonicalize_shape((size,))
  return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
  """Convenience wrapper around ``iota``.

  Binds ``iota_p`` with the full ``shape`` and the given ``dimension``;
  ``dimension`` must be a concrete integer.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = canonicalize_shape(shape)
  error_context = "dimension argument of lax.broadcasted_iota"
  dimension = core.concrete_or_error(int, dimension, error_context)
  return iota_p.bind(dtype=dtype, shape=shape, dimension=dimension)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
  """Like numpy.eye, create a 2D array with ones on a diagonal."""
  offset = int(offset)
  dtype = dtypes.canonicalize_dtype(dtype)
  # Entry (i, j) is set where i + offset == j.
  shifted_rows = add(broadcasted_iota(np.int32, shape, 0), np.int32(offset))
  cols = broadcasted_iota(np.int32, shape, 1)
  bool_eye = eq(shifted_rows, cols)
  return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
  """This utility function exists for creating Kronecker delta arrays."""
  axes = tuple(map(int, axes))
  dtype = dtypes.canonicalize_dtype(dtype)
  base_shape = tuple(np.take(shape, axes))  # type: ignore[arg-type]
  n_axes = len(base_shape)
  iotas = [broadcasted_iota(np.uint32, base_shape, i) for i in range(n_axes)]
  # Consecutive index arrays must agree; zip stops at the shorter sequence.
  eyes = [eq(prev, nxt) for prev, nxt in zip(iotas, iotas[1:])]
  all_equal = _reduce(operator.and_, eyes)
  result = convert_element_type_p.bind(all_equal,
                                       new_dtype=dtype, weak_type=False)
  return broadcast_in_dim(result, shape, axes)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
  """Like numpy.tri, create a 2D array with ones below a diagonal."""
  offset = int(offset)
  dtype = dtypes.canonicalize_dtype(dtype)
  # Entry (i, j) is set where i + offset >= j.
  shifted_rows = add(broadcasted_iota(np.int32, shape, 0), np.int32(offset))
  cols = broadcasted_iota(np.int32, shape, 1)
  bool_tri = ge(shifted_rows, cols)
  return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-12-09 22:42:04 -05:00
|
|
|
def stop_gradient(x: T) -> T:
  """Stops gradient computation.

  Operationally ``stop_gradient`` is the identity function, that is, it returns
  argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
  gradients during forward or reverse-mode automatic differentiation. If there
  are multiple nested gradient computations, ``stop_gradient`` stops gradients
  for all of them.

  For example:

  >>> jax.grad(lambda x: x**2)(3.)
  DeviceArray(6., dtype=float32, weak_type=True)
  >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
  DeviceArray(0., dtype=float32, weak_type=True)
  >>> jax.grad(jax.grad(lambda x: x**2))(3.)
  DeviceArray(2., dtype=float32, weak_type=True)
  >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
  DeviceArray(0., dtype=float32, weak_type=True)
  """
  def stop(x):
    leaf_dtype = _dtype(x)
    is_float = dtypes.issubdtype(leaf_dtype, np.floating)
    if not (is_float or dtypes.issubdtype(leaf_dtype, np.complexfloating)):
      return x  # only bind primitive on inexact dtypes, to avoid some staging
    return ad_util.stop_gradient_p.bind(x)
  # `x` may be a pytree; `stop` is applied to every leaf.
  return tree_map(stop, x)
|
|
|
|
|
2021-12-09 22:42:04 -05:00
|
|
|
def reduce_precision(operand: Union[float, Array],
                     exponent_bits: int,
                     mantissa_bits: int) -> Array:
  """Wraps XLA's `ReducePrecision
  <https://www.tensorflow.org/xla/operation_semantics#reduceprecision>`_
  operator.

  ``exponent_bits`` and ``mantissa_bits`` must be concrete integers.
  """
  def as_concrete_index(value, context):
    return core.concrete_or_error(operator.index, value, context)

  exponent_bits = as_concrete_index(
      exponent_bits, "exponent_bits argument of lax.reduce_precision")
  mantissa_bits = as_concrete_index(
      mantissa_bits, "mantissa_bits argument of lax.reduce_precision")
  return reduce_precision_p.bind(operand, exponent_bits=exponent_bits,
                                 mantissa_bits=mantissa_bits)
|
|
|
|
|
2021-11-24 11:15:04 -08:00
|
|
|
def squeeze(array: Array, dimensions: Sequence[int]) -> Array:
  """Squeeze any number of size 1 dimensions from an array.

  ``dimensions`` are canonicalized against the rank of ``array`` and sorted.
  With no dimensions given, ``array`` itself is returned.
  """
  rank = np.ndim(array)
  dimensions = tuple(sorted(canonicalize_axis(i, rank) for i in dimensions))
  if dimensions:
    return squeeze_p.bind(array, dimensions=dimensions)
  return array
|
|
|
|
|
2021-11-24 11:15:04 -08:00
|
|
|
def expand_dims(array: Array, dimensions: Sequence[int]) -> Array:
  """Insert any number of size 1 dimensions into an array.

  Args:
    array: the input array.
    dimensions: positions, in the *output* array, at which size-1 dimensions
      are inserted. Each is canonicalized against the output rank
      ``np.ndim(array) + len(dimensions)``.

  Returns:
    The result of ``broadcast_in_dim`` with the new size-1 axes inserted.

  Raises:
    ValueError: if the same output axis is named more than once.
  """
  ndim_out = np.ndim(array) + len(dimensions)
  dims = [canonicalize_axis(i, ndim_out) for i in dimensions]
  dims_set = frozenset(dims)
  # Compare after canonicalization so that aliases of the same axis are
  # detected too. Without this check the deduplicated set would disagree with
  # `ndim_out` and fail later with an unrelated broadcast error.
  if len(dims_set) != len(dims):
    raise ValueError(f'repeated axis in lax.expand_dims: {dimensions}')
  result_shape = list(np.shape(array))
  # Insert in increasing order so earlier insertions do not shift later ones.
  for i in sorted(dims_set):
    result_shape.insert(i, 1)
  broadcast_dims = [i for i in range(ndim_out) if i not in dims_set]
  return broadcast_in_dim(array, result_shape, broadcast_dims)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
### convenience wrappers around traceables
|
|
|
|
|
|
|
|
def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
              shape: Optional[Shape] = None) -> Array:
  """Create a full array like np.full based on the example array `x`.

  Args:
    x: example array-like, used for shape and dtype information.
    fill_value: a scalar value to fill the entries of the output array.
    dtype: optional, a dtype parameter for the output ndarray.
    shape: optional, a shape parameter for the output ndarray.

  Returns:
    An ndarray with the same shape as `x` (or `shape`, when given) with its
    entries set equal to `fill_value`, similar to the output of np.full.
  """
  if shape is None:
    fill_shape = np.shape(x)
  else:
    fill_shape = canonicalize_shape(shape)
  # Weak typing is inherited from `x` only when no dtype was requested.
  weak_type = dtype is None and dtypes.is_weakly_typed(x)
  dtype = dtype or _dtype(x)
  typed_fill = _convert_element_type(fill_value, dtype, weak_type)
  return full(fill_shape, typed_fill)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
def collapse(operand: Array, start_dimension: int,
             stop_dimension: int) -> Array:
  """Collapses dimensions of an array into a single dimension.

  For example, if ``operand`` is an array with shape ``[2, 3, 4]``,
  ``collapse(operand, 0, 2).shape == [6, 4]``. The elements of the collapsed
  dimension are laid out major-to-minor, i.e., with the lowest-numbered
  dimension as the slowest varying dimension.

  Args:
    operand: an input array.
    start_dimension: the start of the dimensions to collapse (inclusive).
    stop_dimension: the end of the dimensions to collapse (exclusive).

  Returns:
    An array where dimensions ``[start_dimension, stop_dimension)`` have been
    collapsed (raveled) into a single dimension.
  """
  shape = operand.shape
  lo, hi = start_dimension, stop_dimension
  collapsed_size = prod(shape[lo:hi])
  return reshape(operand, shape[:lo] + (collapsed_size,) + shape[hi:])
|
|
|
|
|
|
|
|
|
|
|
|
def batch_matmul(lhs: Array, rhs: Array,
                 precision: PrecisionLike = None) -> Array:
  """Batch matrix multiplication.

  Both arguments must have the same rank, which must be at least 2. The last
  axis of ``lhs`` is contracted with the second-to-last axis of ``rhs``; all
  leading axes are treated as batch dimensions.
  """
  lhs_rank, rhs_rank = lhs.ndim, rhs.ndim
  if _min(lhs_rank, rhs_rank) < 2:
    raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}'
                     .format(lhs_rank, rhs_rank))
  if lhs_rank != rhs_rank:
    raise ValueError('Arguments to batch_matmul must have same ndim, got {}, {}'
                     .format(lhs_rank, rhs_rank))
  contracting = ((lhs_rank - 1,), (rhs_rank - 2,))
  batch = tuple(range(lhs_rank - 2))
  return dot_general(lhs, rhs, (contracting, (batch, batch)),
                     precision=precision)
|
|
|
|
|
|
|
|
|
|
|
|
# These functions also exist in the XLA client library, but we treat them
|
|
|
|
# as non-primitive to maintain a smaller set of autodiff primitives.
|
|
|
|
|
|
|
|
def square(x: Array) -> Array:
  r"""Elementwise square: :math:`x^2`.

  Implemented as ``integer_pow`` with exponent 2.
  """
  exponent = 2
  return integer_pow(x, exponent)
|
|
|
|
|
|
|
|
def reciprocal(x: Array) -> Array:
  r"""Elementwise reciprocal: :math:`1 \over x`.

  Implemented as ``integer_pow`` with exponent -1.
  """
  exponent = -1
  return integer_pow(x, exponent)
|
|
|
|
|
|
|
|
def _upcast_fp16_for_computation(f):
  """Decorator that runs ``f`` in float32 for float16 and bfloat16 inputs.

  For those dtypes the input is converted to float32, ``f`` is applied, and
  the result is converted back to the input dtype. Other inputs go straight
  to ``f``.
  """
  @functools.wraps(f)
  def f_wrapped(x):
    dtype = _dtype(x)
    if not (dtype == np.float16 or dtype == dtypes.bfloat16):
      return f(x)
    upcast_result = f(convert_element_type(x, np.float32))
    return convert_element_type(upcast_result, dtype)

  return f_wrapped
|
|
|
|
|
|
|
|
def tan(x: Array) -> Array:
  r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.

  Binds the ``tan_p`` primitive.
  """
  result = tan_p.bind(x)
  return result
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def asin(x: Array) -> Array:
  r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.

  Binds the ``asin_p`` primitive.
  """
  result = asin_p.bind(x)
  return result
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def acos(x: Array) -> Array:
  r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.

  Binds the ``acos_p`` primitive.
  """
  result = acos_p.bind(x)
  return result
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def atan(x: Array) -> Array:
  r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.

  Binds the ``atan_p`` primitive.
  """
  result = atan_p.bind(x)
  return result
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def sinh(x: Array) -> Array:
  r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`.

  Binds the ``sinh_p`` primitive.
  """
  result = sinh_p.bind(x)
  return result
|
|
|
|
|
|
|
|
def cosh(x: Array) -> Array:
  r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`.

  Binds the ``cosh_p`` primitive.
  """
  result = cosh_p.bind(x)
  return result
|
|
|
|
|
|
|
|
def asinh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic sine: :math:`\mathrm{asinh}(x)`.

  Binds the ``asinh_p`` primitive.
  """
  result = asinh_p.bind(x)
  return result
|
|
|
|
|
|
|
|
def acosh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic cosine: :math:`\mathrm{acosh}(x)`.

  Binds the ``acosh_p`` primitive.
  """
  result = acosh_p.bind(x)
  return result
|
|
|
|
|
|
|
|
def atanh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic tangent: :math:`\mathrm{atanh}(x)`.

  Binds the ``atanh_p`` primitive.
  """
  result = atanh_p.bind(x)
  return result
|
|
|
|
|
|
|
|
|
|
|
|
# Add some methods to ShapedArray that rely on lax primitives
|
|
|
|
|
|
|
|
# Attach lax-backed methods to ShapedArray via core.aval_method.
ShapedArray.broadcast = core.aval_method(broadcast)
# The next two are clobbered by lax_numpy.
ShapedArray.transpose = core.aval_method(transpose)
ShapedArray.reshape = core.aval_method(reshape)
|
|
|
|
|
|
|
|
def _iter(tracer):
  """Iterates over the leading axis of ``tracer``.

  Raises ``TypeError`` for a 0-d value. Otherwise every slice along axis 0 is
  built eagerly into a list and an iterator over that list is returned.
  """
  if tracer.ndim == 0:
    raise TypeError("iteration over a 0-d array")  # same as numpy error
  n = int(tracer.shape[0])
  # A list (not a generator) is used so all slices are created up front.
  slices = [slicing.index_in_dim(tracer, i, keepdims=False)
            for i in range(n)]
  return iter(slices)

# Install as a static method on ShapedArray.
ShapedArray._iter = staticmethod(_iter)
|
|
|
|
|
|
|
|
# Add some ad handlers that use (or could use) lax primitives
|
|
|
|
|
2021-12-09 22:42:04 -05:00
|
|
|
def zeros_like_array(x: Array) -> Array:
  """Returns ``full_like(x, 0)``: zeros with the shape and dtype of ``x``."""
  zero = 0
  return full_like(x, zero)
|
|
|
|
|
2020-11-02 22:39:45 +01:00
|
|
|
# Register `add` as the adder for Python scalar types, NumPy array types,
# device array types and sharded device arrays.
for t in itertools.chain(
    dtypes.python_scalar_dtypes.keys(),
    array_types,
    device_array.device_array_types,
    [pxla.ShardedDeviceArray, pxla.pmap_lib.ShardedDeviceArray]):
  ad_util.jaxval_adders[t] = add

# Register `zeros_like_array` for the concrete JAX array types.
for t in (device_array._DeviceArray,
          device_array.Buffer,
          pxla.ShardedDeviceArray,
          pxla.pmap_lib.ShardedDeviceArray):
  ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
### primitives
|
|
|
|
|
|
|
|
|
|
|
|
_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
|
|
|
|
_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype
|
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
_strip_weak_type = lambda *args, **_: False
|
2021-03-09 13:48:15 -08:00
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
  """Dtype rule for unary operations.

  Returns ``result_dtype(aval.dtype)`` when ``aval.dtype`` is a subtype of at
  least one entry of ``accepted_dtypes``; otherwise raises ``TypeError`` naming
  the operation and the accepted dtypes.
  """
  if any(dtypes.issubdtype(aval.dtype, t) for t in accepted_dtypes):
    return result_dtype(aval.dtype)
  msg = '{} does not accept dtype {}. Accepted dtypes are subtypes of {}.'
  typename = str(np.dtype(aval.dtype).name)
  accepted_typenames = ', '.join(t.__name__ for t in accepted_dtypes)
  raise TypeError(msg.format(name, typename, accepted_typenames))
|
|
|
|
|
|
|
|
|
|
|
|
def unop(result_dtype, accepted_dtypes, name, translation_rule=None):
  """Create a unary primitive from dtype constraints.

  The primitive is built with ``standard_primitive``. Its shape rule returns
  the operand's ``shape`` attribute, its dtype rule is ``unop_dtype_rule`` and
  its weak-type rule is ``_naryop_weak_type_rule``. The result is then passed
  to ``batching.defvectorized`` and ``masking.defvectorized``.

  Args:
    result_dtype: callable mapping the operand dtype to the result dtype.
    accepted_dtypes: dtype classes the operand is allowed to be a subtype of.
    name: name of the primitive.
    translation_rule: optional translation rule forwarded to
      ``standard_primitive``.

  Returns:
    The newly created primitive.
  """
  prim = standard_primitive(
      _attrgetter('shape'),
      partial(unop_dtype_rule, result_dtype, accepted_dtypes, name),
      name,
      translation_rule=translation_rule,
      weak_type_rule=partial(_naryop_weak_type_rule, name))
  batching.defvectorized(prim)
  masking.defvectorized(prim)
  return prim
|
|
|
|
# Unary primitive factory whose result-dtype rule is `_identity`.
standard_unop = partial(unop, _identity)


def _attrgetter(name):
  """Build a rule returning attribute ``name`` of its first argument.

  Keyword arguments passed to the returned rule are accepted and ignored.
  """
  def get_attribute(x, **kwargs):
    return getattr(x, name)
  return get_attribute
|
|
|
|
|
|
|
|
|
|
|
|
def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
  """Validate operand dtypes of an n-ary primitive and compute its result dtype.

  Args:
    result_dtype: callable applied to ``*avals`` to produce the result dtype.
    accepted_dtypes: one collection of accepted dtype classes per operand
      position; an operand passes if its dtype is a subtype of any of them.
    name: primitive name, used in error messages.
    *avals: abstract values of the operands; only ``.dtype`` is read here.
    **kwargs: ignored.

  Returns:
    ``result_dtype(*avals)``, after ``_check_same_dtypes(name, False, ...)``
    has been called on the operand dtypes.

  Raises:
    TypeError: if an operand's dtype is not accepted at its position. float0
      operands get a dedicated, more explanatory message.
  """
  aval_dtypes = [aval.dtype for aval in avals]
  for i, (aval_dtype, types) in enumerate(zip(aval_dtypes, accepted_dtypes)):
    if any(dtypes.issubdtype(aval_dtype, t) for t in types):
      continue
    # Compare by value: an equal float0 dtype need not be the very same object
    # as `dtypes.float0`, and an identity test would then fall through to the
    # generic (and much less helpful) message below.
    if aval_dtype == dtypes.float0:
      raise TypeError(
          f"Called {name} with a float0 at position {i}. "
          "float0s do not support any operations by design, because they "
          "are not compatible with non-trivial vector spaces. No implicit dtype "
          "conversion is done. You can use np.zeros_like(arr, dtype=float) "
          "to cast a float0 array to a regular zeros array. \n"
          "If you didn't expect to get a float0 you might have accidentally "
          "taken a gradient with respect to an integer argument.")
    msg = ('{} does not accept dtype {} at position {}. '
           'Accepted dtypes at position {} are subtypes of {}.')
    typename = str(np.dtype(aval_dtype).name)
    typenames = ', '.join(t.__name__ for t in types)
    raise TypeError(msg.format(name, typename, i, i, typenames))
  _check_same_dtypes(name, False, *aval_dtypes)
  return result_dtype(*avals)
|
|
|
|
|
|
|
|
|
|
|
|
def _broadcasting_shape_rule(name, *avals):
|
2020-11-13 14:55:04 -08:00
|
|
|
shapes = [aval.shape for aval in avals if aval.shape]
|
|
|
|
if not shapes:
|
2020-10-17 14:33:26 -04:00
|
|
|
return ()
|
|
|
|
if len({len(shape) for shape in shapes}) != 1:
|
2021-08-17 14:37:27 -07:00
|
|
|
msg = '{}: arrays must have same number of dimensions, got {}.'
|
2020-10-17 14:33:26 -04:00
|
|
|
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
|
2021-12-10 23:22:11 -08:00
|
|
|
result_shape = []
|
|
|
|
for ds in zip(*shapes):
|
|
|
|
if all(d is ds[0] for d in ds):
|
|
|
|
# if all axes are identical objects, the resulting size is the object
|
|
|
|
result_shape.append(ds[0])
|
|
|
|
else:
|
|
|
|
# if all dims are equal (or 1), the result is the non-1 size
|
|
|
|
non_1s = {d for d in ds if not core.symbolic_equal_dim(d, 1)}
|
|
|
|
if len(non_1s) > 1:
|
|
|
|
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
|
|
|
|
f'{", ".join(map(str, map(tuple, shapes)))}.')
|
|
|
|
result_shape.append(non_1s.pop() if non_1s else 1)
|
|
|
|
return tuple(result_shape)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
def _naryop_weak_type_rule(name, *avals, **kwargs):
|
|
|
|
if any(aval.dtype is dtypes.float0 for aval in avals):
|
|
|
|
pos = next(i for i, aval in enumerate(avals) if aval.dtype is dtypes.float0)
|
|
|
|
raise TypeError(
|
|
|
|
f"Called {name} with a float0 at position {pos}. "
|
|
|
|
"float0s do not support any operations by design, because they "
|
|
|
|
"are not compatible with non-trivial vector spaces. No implicit dtype "
|
|
|
|
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
|
|
|
|
"to cast a float0 array to a regular zeros array. \n"
|
|
|
|
"If you didn't expect to get a float0 you might have accidentally "
|
|
|
|
"taken a gradient with respect to an integer argument.")
|
|
|
|
return all(aval.weak_type for aval in avals)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def naryop(result_dtype, accepted_dtypes, name, translation_rule=None):
  """Create an n-ary primitive from per-operand dtype constraints.

  The primitive is built with ``standard_primitive``, using
  ``_broadcasting_shape_rule``, ``naryop_dtype_rule`` and
  ``_naryop_weak_type_rule`` (each partially applied to ``name``). The result
  is then passed to ``batching.defbroadcasting`` and ``masking.defnaryop``.

  Args:
    result_dtype: callable applied to the operand avals to get the result
      dtype.
    accepted_dtypes: one collection of accepted dtype classes per operand.
    name: name of the primitive.
    translation_rule: optional translation rule forwarded to
      ``standard_primitive``.

  Returns:
    The newly created primitive.
  """
  prim = standard_primitive(
      partial(_broadcasting_shape_rule, name),
      partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name),
      name,
      translation_rule=translation_rule,
      weak_type_rule=partial(_naryop_weak_type_rule, name))
  batching.defbroadcasting(prim)
  masking.defnaryop(prim)
  return prim
|
|
|
|
# N-ary primitive factory whose result-dtype rule is `_input_dtype`.
standard_naryop = partial(naryop, _input_dtype)
|
|
|
|
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _broadcast_translate(op, ctx, avals_in, avals_out, *args):
|
|
|
|
"""Variant of _standard_translate that performs explicit broadcasting.
|
|
|
|
|
|
|
|
Not all XLA library functions perform their own broadcasting."""
|
|
|
|
aval_out, = avals_out
|
|
|
|
broadcasted_args = []
|
|
|
|
for aval_in, arg in zip(avals_in, args):
|
|
|
|
if aval_out.shape != aval_in.shape:
|
|
|
|
bcast_dims = tuple(range(len(aval_out.shape) - len(aval_in.shape),
|
|
|
|
len(aval_out.shape)))
|
|
|
|
arg = xops.BroadcastInDim(arg, aval_out.shape, bcast_dims)
|
|
|
|
broadcasted_args.append(arg)
|
|
|
|
return [op(*broadcasted_args)]
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
# Like autograd.numpy.numpy_vjps.unbroadcast, this utility handles transposition
# involving linear primitives with implicit broadcasting.
def _unbroadcast(aval, x):
  """Sum ``x`` down to the shape of ``aval``.

  Args:
    aval: abstract value carrying the target shape; must be a ``ShapedArray``.
    x: value to reduce.

  Returns:
    ``x`` itself when the shapes already agree. Otherwise ``x`` summed over
    all axes (empty target shape), or summed over the axes whose size differs
    from the target and reshaped to the target shape.

  Raises:
    TypeError: if ``aval`` is not a ``ShapedArray``.
  """
  if not isinstance(aval, ShapedArray):
    raise TypeError("transpose with implicit broadcasting of unshaped values")
  x_shape = np.shape(x)
  if core.symbolic_equal_shape(aval.shape, x_shape):
    return x
  assert not aval.shape or len(x_shape) == len(aval.shape)
  if not aval.shape:
    every_axis = list(range(len(x_shape)))
    return _reduce_sum(x, every_axis)
  dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape))
          if not core.symbolic_equal_dim(a, b)]
  # Axes that need summing must have size 1 in the target shape.
  if config.jax_enable_checks: assert all(aval.shape[i] == 1 for i in dims)
  return reshape(_reduce_sum(x, dims), aval.shape)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
def _maybe_broadcast(target_shape, x):
  """Broadcast ``x`` to ``target_shape`` unless it already has that shape.

  The axes of ``x`` whose size equals the corresponding target size are kept;
  ``x`` is reshaped to just those sizes and then passed to
  ``broadcast_in_dim`` with the kept axis indices as broadcast dimensions.
  """
  x_shape = np.shape(x)
  if core.symbolic_equal_shape(x_shape, target_shape):
    return x
  dims = []
  squeeze_shape = []
  for i, (x_dim, target_dim) in enumerate(zip(x_shape, target_shape)):
    if core.symbolic_equal_dim(x_dim, target_dim):
      dims.append(i)
      squeeze_shape.append(x_dim)
  return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
def broadcast_mhlo(
    aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray],
    args: Sequence[ir.Value]) -> Sequence[ir.Value]:
  """Broadcasts MHLO values with broadcast-compatible shapes to the same shape.

  Args:
    aval_out: abstract value whose shape is the broadcast target.
    avals: abstract values of ``args``, in the same order.
    args: the values to broadcast.

  Returns:
    A list with one value per argument. Arguments whose shape already equals
    the target shape are returned unchanged; the others are replaced by the
    result of an ``mhlo.BroadcastInDimOp`` that maps their axes onto the
    trailing axes of the target.
  """
  out_rank = len(aval_out.shape)
  results = []
  for aval, arg in zip(avals, args):
    if aval.shape != aval_out.shape:
      assert len(aval.shape) <= len(aval_out.shape), (aval, aval_out)
      dims = mlir.dense_int_elements(
          range(out_rank - len(aval.shape), out_rank))
      broadcast_type = mlir.aval_to_ir_type(aval.update(shape=aval_out.shape))
      arg = mhlo.BroadcastInDimOp(broadcast_type, arg, dims).result
    results.append(arg)
  return results
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _nary_lower_mhlo(op: Callable, ctx,
                     *args: Union[ir.Value, Sequence[ir.Value]],
                     explicit_type=False, **params):
  """Lowers an elementwise operator to its MHLO/CHLO equivalent.

  The operands are first broadcast to the output shape with
  ``broadcast_mhlo``; ``op`` is then built from them and its ``results`` are
  returned.

  Args:
    op: the MHLO/CHLO operator constructor.
    ctx: lowering rule context; ``avals_in`` and ``avals_out`` are read.
    *args: the operands.
    explicit_type: does the MHLO/CHLO operator require its output type to be
      provided?
    **params: ignored.
  """
  del params
  aval_out, = ctx.avals_out
  broadcasted_args = broadcast_mhlo(aval_out, ctx.avals_in, args)
  if not explicit_type:
    return op(*broadcasted_args).results
  # The operator wants its result type as the leading argument.
  return op(mlir.aval_to_ir_type(aval_out), *broadcasted_args).results
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Sets of dtype classes used as `accepted_dtypes` by the primitives below.
_float = {np.floating}
_complex = {np.complexfloating}
_complex_elem_types = {np.float32, np.float64}
_int = {np.integer}
_bool = {np.bool_}

_num = _int | _float | _complex
_any = _int | _float | _complex | _bool
_bool_or_int = _int | _bool

# neg: linear in its operand, so its transpose is again a negation.
neg_p = standard_unop(_num, 'neg')
ad.deflinear2(neg_p, lambda t, _: [neg(t)])
mlir.register_lowering(neg_p, partial(_nary_lower_mhlo, mhlo.NegOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _sign_translation_rule(ctx, avals_in, avals_out, x):
  """XLA translation rule for ``sign``.

  Unsigned integer operands are not handed to ``xops.Sign``; their result is
  selected elementwise instead: 0 where ``x`` equals 0 and 1 elsewhere. All
  other dtypes use ``xops.Sign`` directly.

  Returns:
    A one-element list holding the resulting XLA op.
  """
  c = ctx.builder
  x_aval, = avals_in
  dtype = x_aval.dtype
  if not dtypes.issubdtype(dtype, np.unsignedinteger):
    return [xops.Sign(x)]
  zero = xops.Constant(c, np.array(0, dtype=dtype))
  is_zero = xops.Eq(x, zero)
  zeros = xops.Broadcast(zero, x_aval.shape)
  ones = xops.Broadcast(xops.Constant(c, np.array(1, dtype=dtype)),
                        x_aval.shape)
  return [xops.Select(is_zero, zeros, ones)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# sign accepts the `_num` dtypes and uses `_sign_translation_rule` for XLA.
sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
|
|
|
|
ad.defjvp_zero(sign_p)
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _sign_lower_mhlo(ctx, x):
  """MHLO lowering rule for ``sign``.

  Unsigned integer operands are lowered to a select on an unsigned ``EQ``
  comparison with zero: 0 where ``x`` equals 0 and 1 elsewhere. All other
  dtypes are lowered to ``mhlo.SignOp``.
  """
  x_aval, = ctx.avals_in
  if not dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
    return mhlo.SignOp(x).results
  bool_type = mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_)))
  is_zero = mhlo.CompareOp(
      bool_type, x, mlir.full_like_aval(0, x_aval),
      ir.StringAttr.get("EQ"), ir.StringAttr.get("UNSIGNED")).result
  return mhlo.SelectOp(
      is_zero,
      mlir.full_like_aval(0, x_aval),
      mlir.full_like_aval(1, x_aval)).results
|
|
|
|
|
|
|
|
mlir.register_lowering(sign_p, _sign_lower_mhlo)

# nextafter: the XLA translation goes through `_broadcast_translate`, which
# broadcasts the operands explicitly before calling `xops.NextAfter`.
_nextafter_translation_rule = partial(_broadcast_translate, xops.NextAfter)
nextafter_p = standard_naryop(
    [_float, _float], 'nextafter',
    translation_rule=_nextafter_translation_rule)
mlir.register_lowering(
    nextafter_p, partial(_nary_lower_mhlo, chlo.NextAfterOp))

# floor and ceil are registered with a zero JVP.
floor_p = standard_unop(_float, 'floor')
ad.defjvp_zero(floor_p)
mlir.register_lowering(floor_p, partial(_nary_lower_mhlo, mhlo.FloorOp))

ceil_p = standard_unop(_float, 'ceil')
ad.defjvp_zero(ceil_p)
mlir.register_lowering(ceil_p, partial(_nary_lower_mhlo, mhlo.CeilOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2020-11-23 15:33:04 +01:00
|
|
|
def _round_to_nearest_even(x):
  """Round ``x`` to the nearest integer value, sending ties to the even one.

  ``floor(x)`` is bumped by one when the fractional part exceeds one half, or
  when it equals one half and ``floor(x)`` is odd.
  """
  half = _const(x, 0.5)
  one = _const(x, 1)
  round_val = floor(x)
  fraction = x - round_val
  # round_val - 2 * floor(x / 2): compared against one to detect an odd floor.
  nearest_even_int = sub(
      round_val, mul(_const(x, 2), floor(mul(half, x))))
  is_odd = eq(nearest_even_int, one)
  above_half = gt(fraction, half)
  tie_on_odd = bitwise_and(eq(fraction, half), is_odd)
  round_up = bitwise_or(above_half, tie_on_odd)
  return select(round_up, add(round_val, one), round_val)
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _round_translation_rule(ctx, avals_in, avals_out, x, *, rounding_method):
  """XLA translation rule for ``round``.

  ``RoundingMethod.AWAY_FROM_ZERO`` maps directly onto ``xops.Round``; any
  other method is translated by lowering ``_round_to_nearest_even``.
  """
  if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
    return [xops.Round(x)]
  # rounding_method is RoundingMethod.TO_NEAREST_EVEN
  rounding_fun = xla.lower_fun(_round_to_nearest_even, multiple_results=False,
                               new_style=True)
  return rounding_fun(ctx, avals_in, avals_out, x)
|
2020-11-23 15:33:04 +01:00
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
# round only accepts floating-point operands (`_float`).
round_p = standard_unop(_float, 'round')
|
2021-10-18 18:06:48 -07:00
|
|
|
xla.register_translation(round_p, _round_translation_rule)
|
2020-10-17 14:33:26 -04:00
|
|
|
ad.defjvp_zero(round_p)
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _round_lower(ctx, x, *, rounding_method):
  """MHLO lowering rule for ``round``.

  ``RoundingMethod.AWAY_FROM_ZERO`` lowers to ``mhlo.RoundOp``. The only other
  supported method, ``RoundingMethod.TO_NEAREST_EVEN``, is lowered from
  ``_round_to_nearest_even`` via ``mlir.lower_fun`` wrapped in
  ``mlir.cache_lowering``.
  """
  if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
    return mhlo.RoundOp(x).results
  assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
  lowered = mlir.lower_fun(_round_to_nearest_even, multiple_results=False)
  round_nearest = mlir.cache_lowering(lowered)
  return round_nearest(ctx, x)
|
2021-11-23 18:57:45 -08:00
|
|
|
mlir.register_lowering(round_p, _round_lower)

# is_finite: floating-point input, boolean output, zero JVP.
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
ad.defjvp_zero(is_finite_p)
mlir.register_lowering(is_finite_p, partial(_nary_lower_mhlo, mhlo.IsFiniteOp))

# exp: the JVP reuses the primal output, tangent * exp(x).
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda tangent, ans, _: mul(tangent, ans))
iad.definverse(exp_p, lambda r, x: log(r))
# For exp_p it is more efficient to use the reconstructed output for the vjp
# rule instead of computing it again from the input.
iad.primitive_ivjps[exp_p] = lambda x, y, ct: [[log(y[0])], [ct[0] * y[0]]]
mlir.register_lowering(exp_p, partial(_nary_lower_mhlo, mhlo.ExpOp))

# log: JVP is tangent / x.
log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda tangent, operand: div(tangent, operand))
iad.definverse(log_p, lambda r, x: exp(r))
mlir.register_lowering(log_p, partial(_nary_lower_mhlo, mhlo.LogOp))

# expm1: JVP is tangent * (ans + 1).
expm1_p = standard_unop(_float | _complex, 'expm1')
ad.defjvp2(expm1_p,
           lambda tangent, ans, _: mul(tangent, add(ans, _one(ans))))
mlir.register_lowering(expm1_p, partial(_nary_lower_mhlo, mhlo.Expm1Op))

# log1p: JVP is tangent / (x + 1).
log1p_p = standard_unop(_float | _complex, 'log1p')
ad.defjvp(log1p_p,
          lambda tangent, operand: div(tangent, add(operand, _one(operand))))
mlir.register_lowering(log1p_p, partial(_nary_lower_mhlo, mhlo.Log1pOp))

# tanh: JVP is (tangent + tangent * ans) * (1 - ans).
tanh_p = standard_unop(_float | _complex, 'tanh')
ad.defjvp2(tanh_p,
           lambda tangent, ans, operand: mul(add(tangent, mul(tangent, ans)),
                                             sub(_one(operand), ans)))
mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp))

# sin: JVP is tangent * cos(x).
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda tangent, operand: mul(tangent, cos(operand)))
mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SinOp))

# cos: JVP is -(tangent * sin(x)).
cos_p = standard_unop(_float | _complex, 'cos')
ad.defjvp(cos_p, lambda tangent, operand: neg(mul(tangent, sin(operand))))
mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
@partial(xla.lower_fun, multiple_results=False, new_style=True)
@_upcast_fp16_for_computation
def tan_translation_rule(x):
  """Compute ``tan(x)`` as ``sin(x) / cos(x)``."""
  sine = sin(x)
  cosine = cos(x)
  return div(sine, cosine)
|
|
|
|
|
|
|
|
# tan: JVP reuses the primal output, tangent * (1 + ans**2).
tan_p = standard_unop(
    _float | _complex, 'tan', translation_rule=tan_translation_rule)
ad.defjvp2(
    tan_p,
    lambda tangent, ans, operand: mul(tangent,
                                      _const(operand, 1) + square(ans)))
|
2020-12-07 17:34:27 -05:00
|
|
|
|
|
|
|
|
|
|
|
def asin_translation_rule(x):
  """Express ``asin(x)`` through other lax operations.

  Complex inputs use ``-1j * asinh(1j * x)``; all other inputs use
  ``2 * atan2(x, 1 + sqrt(1 - x**2))``.
  """
  if not dtypes.issubdtype(_dtype(x), np.complexfloating):
    two = _const(x, 2)
    one = _const(x, 1)
    root = sqrt(sub(_const(x, 1), square(x)))
    return mul(two, atan2(x, add(one, root)))
  minus_i = _const(x, -1j)
  return mul(minus_i, asinh(mul(_const(x, 1j), x)))
|
|
|
|
|
|
|
|
# asin: JVP is tangent * rsqrt(1 - x**2).
asin_p = standard_unop(
    _float | _complex, 'asin',
    translation_rule=xla.lower_fun(asin_translation_rule,
                                   multiple_results=False,
                                   new_style=True))
ad.defjvp(
    asin_p,
    lambda tangent, operand: mul(
        tangent, rsqrt(_const(operand, 1) - square(operand))))
|
|
|
|
|
|
|
|
|
|
|
|
def acos_translation_rule(x):
  """Express ``acos(x)`` through other lax operations.

  Complex inputs use ``1j * acosh(x)``, negated where its real part is not
  positive. Other inputs use ``2 * atan2(sqrt(1 - x**2), 1 + x)``, except
  where ``x`` equals -1, which yields pi.
  """
  if dtypes.issubdtype(_dtype(x), np.complexfloating):
    result = mul(_const(x, 1j), acosh(x))
    # By convention, numpy chooses the branch with positive real part.
    rpart = real(result)
    real_is_positive = gt(rpart, _const(rpart, 0))
    return select(real_is_positive, result, neg(result))

  not_minus_one = ne(x, _const(x, -1.0))
  two = _const(x, 2)
  root = sqrt(sub(_const(x, 1), square(x)))
  one_plus_x = add(_const(x, 1), x)
  general_case = mul(two, atan2(root, one_plus_x))
  return select(not_minus_one, general_case, full_like(x, np.pi))
|
|
|
|
|
|
|
|
# acos: JVP is tangent * -rsqrt(1 - x**2).
acos_p = standard_unop(
    _float | _complex, 'acos',
    translation_rule=xla.lower_fun(acos_translation_rule,
                                   multiple_results=False,
                                   new_style=True))
ad.defjvp(
    acos_p,
    lambda tangent, operand: mul(
        tangent, -rsqrt(_const(operand, 1) - square(operand))))
|
|
|
|
|
|
|
|
def atan_translation_rule(x):
  """Compute ``atan(x)`` as ``atan2(x, 1)``."""
  one = _const(x, 1)
  return atan2(x, one)
|
2020-12-07 17:34:27 -05:00
|
|
|
|
|
|
|
# atan: JVP is tangent / (1 + x**2).
atan_p = standard_unop(
    _float | _complex, 'atan',
    translation_rule=xla.lower_fun(atan_translation_rule,
                                   multiple_results=False,
                                   new_style=True))
ad.defjvp(
    atan_p,
    lambda tangent, operand: div(tangent,
                                 _const(operand, 1) + square(operand)))

# atan2: one JVP rule per operand.
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(atan2_p,
          lambda g, x, y: g * (y / (square(x) + square(y))),
          lambda g, x, y: g * -x / (square(x) + square(y)))
mlir.register_lowering(atan2_p, partial(_nary_lower_mhlo, mhlo.Atan2Op))

# sinh: JVP is tangent * cosh(x).
sinh_p = standard_unop(_float | _complex, 'sinh')
ad.defjvp(sinh_p, lambda tangent, operand: mul(tangent, cosh(operand)))
# TODO(b/209505237): the CHLO lowering of chlo.sinh is less accurate than that
# in the XLA client library. Use the fallback path for now.
# mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))

# cosh: JVP is tangent * sinh(x).
cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda tangent, operand: mul(tangent, sinh(operand)))

# asinh: JVP is tangent * rsqrt(x**2 + 1).
asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(
    asinh_p,
    lambda tangent, operand: mul(tangent,
                                 rsqrt(square(operand) + _one(operand))))

# acosh: JVP is tangent * rsqrt((x - 1) * (x + 1)).
acosh_p = standard_unop(_float | _complex, 'acosh')
ad.defjvp(
    acosh_p,
    lambda tangent, operand: mul(
        tangent,
        rsqrt((operand - _one(operand)) * (operand + _one(operand)))))

# atanh: JVP is (1 / (1 + x)) * (tangent / (1 - x)).
atanh_p = standard_unop(_float | _complex, 'atanh')
ad.defjvp(
    atanh_p,
    lambda tangent, operand: mul(reciprocal(_one(operand) + operand),
                                 div(tangent, (_one(operand) - operand))))

regularized_incomplete_beta_p = standard_naryop(
    [_float, _float, _float], 'regularized_incomplete_beta',
    translation_rule=partial(_broadcast_translate,
                             xops.RegularizedIncompleteBeta))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def betainc_gradx(g, a, b, x):
  """JVP of ``regularized_incomplete_beta`` with respect to ``x``.

  Computes ``g * exp((b - 1) * log1p(-x) + (a - 1) * log(x) - lbeta)`` where
  ``lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)``.
  """
  lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)
  log_unnormalized = (b - 1) * log1p(-x) + (a - 1) * log(x)
  partial_x = exp(log_unnormalized - lbeta)
  return partial_x * g
|
|
|
|
|
|
|
|
def betainc_grad_not_implemented(g, a, b, x):
  """JVP placeholder for the ``a`` and ``b`` operands: always raises.

  Raises:
    ValueError: unconditionally.
  """
  message = "Betainc gradient with respect to a and b not supported."
  raise ValueError(message)
|
|
|
|
|
|
|
|
# Only the last operand (x) of regularized_incomplete_beta has a usable JVP;
# the rules for the first two operands raise.
ad.defjvp(
    regularized_incomplete_beta_p,
    betainc_grad_not_implemented,
    betainc_grad_not_implemented,
    betainc_gradx)

# lgamma: JVP is tangent * digamma(x).
lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda tangent, operand: mul(tangent, digamma(operand)))
mlir.register_lowering(lgamma_p, partial(_nary_lower_mhlo, chlo.LgammaOp))

digamma_p = standard_unop(_float, 'digamma')
mlir.register_lowering(digamma_p, partial(_nary_lower_mhlo, chlo.DigammaOp))

igamma_p = standard_naryop(
    [_float, _float], 'igamma',
    translation_rule=partial(_broadcast_translate, xops.Igamma))
igamma_grad_a_p = standard_naryop(
    [_float, _float], 'igamma_grad_a',
    translation_rule=partial(_broadcast_translate, xops.IgammaGradA))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def igamma_gradx(g, a, x):
  """JVP of ``igamma`` with respect to ``x``.

  Computes ``g * exp(-x + (a - 1) * log(x) - lgamma(a))``.
  """
  log_density = -x + (a - _ones(a)) * log(x) - lgamma(a)
  return g * exp(log_density)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def igamma_grada(g, a, x):
  """JVP of ``igamma`` with respect to ``a``: ``g * igamma_grad_a(a, x)``."""
  grad_a = igamma_grad_a(a, x)
  return g * grad_a
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# JVP rules are given in operand order: first `a`, then `x`.
ad.defjvp(igamma_p, igamma_grada, igamma_gradx)

igammac_p = standard_naryop(
    [_float, _float], 'igammac',
    translation_rule=partial(_broadcast_translate, xops.Igammac))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def igammac_gradx(g, a, x):
  """JVP of ``igammac`` with respect to ``x``: the negated ``igamma`` rule."""
  igamma_part = igamma_gradx(g, a, x)
  return -igamma_part
|
|
|
|
|
|
|
|
def igammac_grada(g, a, x):
  """JVP of ``igammac`` with respect to ``a``: the negated ``igamma`` rule."""
  igamma_part = igamma_grada(g, a, x)
  return -igamma_part
|
|
|
|
|
|
|
|
# JVP rules are given in operand order: first `a`, then `x`.
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)

random_gamma_grad_p = standard_naryop(
    [_float, _float], 'random_gamma_grad',
    translation_rule=partial(_broadcast_translate, xops.RandomGammaGrad))

# bessel_i0e: JVP is tangent * (bessel_i1e(x) - sign(x) * ans).
bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
ad.defjvp2(
    bessel_i0e_p,
    lambda tangent, ans, operand: tangent * (bessel_i1e(operand)
                                             - sign(operand) * ans))

bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
|
|
|
|
def _bessel_i1e_jvp(g, y, x):
  """JVP rule for ``bessel_i1e``.

  Args:
    g: input tangent.
    y: primal output, ``bessel_i1e(x)``.
    x: primal input.

  Returns:
    ``g`` times ``bessel_i0e(x) - y * (sign(x) + 1 / x)``. Where ``abs(x)``
    does not exceed the dtype's machine epsilon, the factor is 0.5 instead; the
    formula is evaluated at epsilon there so that ``reciprocal`` never sees
    such a tiny ``x``.
  """
  eps = dtypes.finfo(_dtype(x)).eps
  x_is_not_tiny = abs(x) > eps
  # Swap tiny inputs for eps before evaluating the formula.
  safe_x = select(x_is_not_tiny, x, full_like(x, eps))
  formula = bessel_i0e(safe_x) - y * (sign(safe_x) + reciprocal(safe_x))
  dy_dx = select(x_is_not_tiny, formula, full_like(x, 0.5))
  return g * dy_dx
|
|
|
|
ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)

# erf: JVP is (2 / sqrt(pi)) * tangent * exp(-x**2).
erf_p = standard_unop(_float, 'erf')
ad.defjvp(
    erf_p,
    lambda tangent, operand: mul(_const(operand, 2. / np.sqrt(np.pi)),
                                 mul(tangent, exp(neg(square(operand))))))

# erfc: JVP is (-2 / sqrt(pi)) * tangent * exp(-x**2).
erfc_p = standard_unop(_float, 'erfc')
ad.defjvp(
    erfc_p,
    lambda tangent, operand: mul(_const(operand, -2. / np.sqrt(np.pi)),
                                 mul(tangent, exp(neg(square(operand))))))

# erf_inv: JVP reuses the primal output, (sqrt(pi) / 2) * tangent * exp(ans**2).
erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(
    erf_inv_p,
    lambda tangent, ans, operand: mul(_const(operand, np.sqrt(np.pi) / 2.),
                                      mul(tangent, exp(square(ans)))))

# real: the transpose embeds the cotangent as the real part of a complex value.
real_p = unop(_complex_basetype, _complex, 'real')
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
mlir.register_lowering(real_p, partial(_nary_lower_mhlo, mhlo.RealOp))

# imag: the transpose puts the negated cotangent in the imaginary part.
imag_p = unop(_complex_basetype, _complex, 'imag')
ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])
mlir.register_lowering(imag_p, partial(_nary_lower_mhlo, mhlo.ImagOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
|
|
|
|
def _complex_transpose_rule(t, x, y):
  """Transpose rule for ``complex_p``.

  Args:
    t: cotangent of the output; may be an ``ad_util.Zero``.
    x: first operand (the real part).
    y: second operand (the imaginary part).

  Returns:
    A two-element list ``[ct_x, ct_y]``. An entry is ``None`` when the
    corresponding operand is not an undefined primal. Otherwise it is an
    ``ad_util.Zero`` of the operand's aval when ``t`` is an ``ad_util.Zero``,
    and else ``real(t)`` (for ``x``) or ``imag(neg(t))`` (for ``y``) passed
    through ``_unbroadcast``.
  """
  x_is_linear = ad.is_undefined_primal(x)
  y_is_linear = ad.is_undefined_primal(y)
  assert x_is_linear or y_is_linear
  t_is_zero = type(t) is ad_util.Zero

  ct_x = ct_y = None
  if x_is_linear:
    ct_x = (ad_util.Zero(x.aval) if t_is_zero
            else _unbroadcast(x.aval, real(t)))
  if y_is_linear:
    ct_y = (ad_util.Zero(y.aval) if t_is_zero
            else _unbroadcast(y.aval, imag(neg(t))))
  return [ct_x, ct_y]
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _complex_dtype(dtype, *args):
  """Return the dtype NumPy gives to a ``dtype`` zero plus a complex64 zero.

  Any further positional arguments are ignored.
  """
  promoted = np.zeros((), dtype) + np.zeros((), np.complex64)
  return promoted.dtype


# complex: both operands must be float32 or float64.
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
                   'complex')
ad.deflinear2(complex_p, _complex_transpose_rule)
mlir.register_lowering(complex_p, partial(_nary_lower_mhlo, mhlo.ComplexOp))

# conj accepts float32/float64 as well as complex operands.
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
|
|
|
|
|
2021-11-30 06:08:26 -08:00
|
|
|
def _conj_impl(x, **kw):
  """Implement ``conj`` with ``complex``, ``real`` and ``imag``.

  Complex inputs have their imaginary part negated; other inputs are paired
  with a zero imaginary part. Keyword arguments are ignored.
  """
  if not dtypes.issubdtype(x.dtype, np.complexfloating):
    return complex(x, _zeros(x))
  return complex(real(x), -imag(x))
|
|
|
|
|
|
|
|
# The MLIR lowering of conj_p is derived from `_conj_impl`.
mlir.register_lowering(
    conj_p, mlir.lower_fun(_conj_impl, multiple_results=False))
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _conj_transpose_rule(t, x, *, input_dtype):
  """Transpose rule for ``conj_p``.

  Args:
    t: cotangent of the output.
    x: the operand, which must be an undefined primal.
    input_dtype: dtype of the operand.

  Returns:
    ``[conj(t)]`` for a complex ``input_dtype``, otherwise ``[real(t)]``.
  """
  assert ad.is_undefined_primal(x)
  input_is_complex = dtypes.issubdtype(input_dtype, np.complexfloating)
  cotangent = conj(t) if input_is_complex else real(t)
  return [cotangent]
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
# conj: XLA translation, plus JVP and transpose rules.
xla.register_translation(
    conj_p,
    lambda ctx, avals_in, avals_out, x, **kwargs: [xops.Conj(x)])
ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
ad.primitive_transposes[conj_p] = _conj_transpose_rule

# abs accepts the `_num` dtypes.
abs_p = unop(_complex_basetype, _num, 'abs')
mlir.register_lowering(abs_p, partial(_nary_lower_mhlo, mhlo.AbsOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _abs_jvp_rule(g, ans, x):
  """JVP rule for ``abs``.

  Args:
    g: input tangent.
    ans: primal output, ``abs(x)``.
    x: primal input.

  Returns:
    For complex ``x``: ``g * conj(x) / ans``, with ``ans`` converted to
    ``x``'s dtype and passed through ``_replace_zero``, and the result passed
    through ``_maybe_real``. Otherwise ``g`` where ``x >= 0`` and ``-g``
    elsewhere.
  """
  if not _iscomplex(x):
    return select(ge(x, _zero(x)), g, neg(g))
  x_conj = _maybe_conj(x)
  safe_ans = _replace_zero(convert_element_type(ans, _dtype(x)))
  return _maybe_real(mul(g, div(x_conj, safe_ans)))
|
|
|
|
ad.defjvp2(abs_p, _abs_jvp_rule)


def _maybe_conj(x):
  """Return ``conj(x)`` if ``x`` is complex, else ``x`` unchanged."""
  return conj(x) if _iscomplex(x) else x


def _maybe_real(x):
  """Return ``real(x)`` if ``x`` is complex, else ``x`` unchanged."""
  return real(x) if _iscomplex(x) else x


# sqrt: JVP reuses the primal output, tangent * (0.5 / ans).
sqrt_p = standard_unop(_float | _complex, 'sqrt')
ad.defjvp2(
    sqrt_p,
    lambda tangent, ans, operand: mul(tangent,
                                      div(_const(operand, 0.5), ans)))
mlir.register_lowering(sqrt_p, partial(_nary_lower_mhlo, mhlo.SqrtOp))

# rsqrt: JVP reuses the primal output, tangent * (-0.5 * ans / x).
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
ad.defjvp2(
    rsqrt_p,
    lambda tangent, ans, operand: mul(
        tangent, mul(_const(operand, -0.5), div(ans, operand))))
mlir.register_lowering(rsqrt_p, partial(_nary_lower_mhlo, mhlo.RsqrtOp))

# cbrt: JVP reuses the primal output, tangent * (1/3) * ans**-2.
cbrt_p = standard_unop(_float, 'cbrt')
ad.defjvp2(
    cbrt_p,
    lambda tangent, ans, operand: mul(
        tangent, mul(_const(operand, 1/3), integer_pow(ans, -2))))
mlir.register_lowering(cbrt_p, partial(_nary_lower_mhlo, mhlo.CbrtOp))

pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
|
|
|
|
|
|
|
|
def _pow_jvp_lhs(g, ans, x, y):
|
|
|
|
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
|
2021-04-15 15:16:29 -07:00
|
|
|
return mul(g, jac)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _pow_jvp_rhs(g, ans, x, y):
|
2021-04-15 15:16:29 -07:00
|
|
|
return mul(g, mul(log(_replace_zero(x)), ans))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
2021-11-23 18:57:45 -08:00
|
|
|
mlir.register_lowering(pow_p, partial(_nary_lower_mhlo, mhlo.PowOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
def _integer_pow_dtype_rule(x, *, y):
  """Dtype rule for `integer_pow`.

  Returns the dtype computed by `unop_dtype_rule` for int, float and complex
  operands.

  Raises:
    TypeError: if the exponent `y` is negative and the dtype is an integer
      type.
  """
  dtype = unop_dtype_rule(_identity, _int | _float | _complex, 'integer_pow', x)
  negative_integer_power = y < 0 and dtypes.issubdtype(dtype, np.integer)
  if negative_integer_power:
    raise TypeError("Integers cannot be raised to negative powers, got "
                    f"integer_pow({x}, {y})")
  return dtype
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _integer_pow_translation_rule(ctx, avals_in, avals_out, x, *, y):
  """XLA translation of `integer_pow` by repeated squaring.

  A zero exponent yields a broadcast constant one. Otherwise the power is
  built from `Mul` ops over the binary digits of `|y|`, and the result is
  passed through `Reciprocal` when `y` is negative.
  """
  # This should be kept in sync with the jax2tf translation rule.
  (x_aval,) = avals_in
  if y == 0:
    one = xla.pyval_to_ir_constant(ctx.builder, np.array(1, dtype=x_aval.dtype))
    return [xops.Broadcast(one, x_aval.shape)]

  is_reciprocal = y < 0
  exponent = -y if is_reciprocal else y

  result = None
  base = x
  while exponent > 0:
    if exponent & 1:
      # This bit is set: fold the current power of the base into the result.
      result = base if result is None else xops.Mul(result, base)
    exponent >>= 1
    if exponent > 0:
      base = xops.Mul(base, base)

  if is_reciprocal:
    return [xops.Reciprocal(result)]
  return [result]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _integer_pow_jvp(g, x, *, y):
  """JVP rule for `integer_pow`: `g * y * x ** (y - 1)`, or zeros if `y == 0`."""
  if y == 0:
    return _zeros(g)
  jacobian = mul(_const(x, y), integer_pow(x, y - 1))
  return mul(g, jacobian)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# `integer_pow` primitive: the output shape is read from the operand's `shape`
# attribute and the dtype comes from `_integer_pow_dtype_rule`.
integer_pow_p = standard_primitive(
    _attrgetter('shape'),
    _integer_pow_dtype_rule,
    'integer_pow',
    translation_rule=_integer_pow_translation_rule,
)
# Register the batching, masking and differentiation rules.
batching.defvectorized(integer_pow_p)
masking.defvectorized(integer_pow_p)
ad.defjvp(integer_pow_p, _integer_pow_jvp)
|
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
def _integer_pow(x, *, y):
  """Compute `x ** y` for a Python integer `y` by repeated squaring.

  A zero exponent yields `full_like(x, 1)`. A negative exponent yields one
  divided by the positive power.
  """
  # This should be kept in sync with the jax2tf translation rule.
  if y == 0:
    return full_like(x, 1)

  is_reciprocal = y < 0
  exponent = -y if is_reciprocal else y

  result = None
  base = x
  while exponent > 0:
    if exponent & 1:
      # This bit is set: fold the current power of the base into the result.
      result = base if result is None else mul(result, base)
    exponent >>= 1
    if exponent > 0:
      # We don't call square because it calls integer_pow.
      base = mul(base, base)

  if is_reciprocal:
    return div(full_like(result, 1), result)
  return result
|
|
|
|
|
2021-12-16 08:34:10 -08:00
|
|
|
# MLIR lowering for `integer_pow`: lower the `_integer_pow` function, wrapped
# in `mlir.cache_lowering`.
_integer_pow_lowering = mlir.cache_lowering(
    mlir.lower_fun(_integer_pow, multiple_results=False))
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
|
2021-11-23 18:57:45 -08:00
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _replace_zero(x):
  """Return `x` with every element equal to zero replaced by one."""
  is_zero = eq(x, _const(x, 0))
  return select(is_zero, _ones(x), x)
|
|
|
|
|
|
|
|
# Bitwise primitives on boolean or integer operands. Each gets a zero JVP and
# lowers directly to the corresponding MHLO op.
not_p = standard_unop(_bool_or_int, 'not')
ad.defjvp_zero(not_p)
mlir.register_lowering(not_p, functools.partial(_nary_lower_mhlo, mhlo.NotOp))

and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
ad.defjvp_zero(and_p)
mlir.register_lowering(and_p, functools.partial(_nary_lower_mhlo, mhlo.AndOp))

or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
ad.defjvp_zero(or_p)
mlir.register_lowering(or_p, functools.partial(_nary_lower_mhlo, mhlo.OrOp))

xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
ad.defjvp_zero(xor_p)
mlir.register_lowering(xor_p, functools.partial(_nary_lower_mhlo, mhlo.XorOp))

# Integer-only bit-counting primitives; no JVP rule is registered for them
# here.
population_count_p = standard_unop(_int, 'population_count')
mlir.register_lowering(
    population_count_p,
    functools.partial(_nary_lower_mhlo, mhlo.PopulationCountOp))

clz_p = standard_unop(_int, 'clz')
mlir.register_lowering(clz_p, functools.partial(_nary_lower_mhlo, mhlo.ClzOp))
|
2021-03-19 22:35:31 -07:00
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
def _add_jvp(primals, tangents):
  """JVP rule for `add_p`.

  Returns `(primal_out, tangent_out)`. Symbolic-zero tangents
  (`ad_util.Zero`) are never added explicitly: if both are zero the output
  tangent is a symbolic zero, and if only one is zero the other tangent is
  broadcast to the output shape.
  """
  lhs, rhs = primals
  lhs_dot, rhs_dot = tangents
  primal_out = add(lhs, rhs)

  lhs_is_zero = type(lhs_dot) is ad_util.Zero
  rhs_is_zero = type(rhs_dot) is ad_util.Zero
  if lhs_is_zero and rhs_is_zero:
    tangent_out = ad_util.Zero.from_value(primal_out)
  elif lhs_is_zero:
    tangent_out = _maybe_broadcast(primal_out.shape, rhs_dot)
  elif rhs_is_zero:
    tangent_out = _maybe_broadcast(primal_out.shape, lhs_dot)
  else:
    tangent_out = add(lhs_dot, rhs_dot)
  return primal_out, tangent_out
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _add_transpose(t, x, y):
  """Transpose rule for `add_p`: send the cotangent `t` to both operands.

  Each returned cotangent is reduced to its operand's aval with
  `_unbroadcast`. A symbolic-zero `t` yields symbolic zeros.
  """
  # Morally the following assertion is true, but because we instantiate zeros in
  # some places (e.g. in custom_jvp) it may not always hold. For example, see
  # api_test.py's CustomJVPTest.test_jaxpr_zeros.
  # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
  def _operand_aval(operand):
    if ad.is_undefined_primal(operand):
      return operand.aval
    return _abstractify(operand)

  x_aval = _operand_aval(x)
  y_aval = _operand_aval(y)
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
  return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _add_inverse(r, x, y):
|
|
|
|
xr = r - y
|
|
|
|
yr = r - x
|
|
|
|
return xr, yr
|
2021-04-15 15:16:29 -07:00
|
|
|
|
2021-07-26 13:44:57 +01:00
|
|
|
# TODO(slebedev): Why does mypy fail to infer the type here?
add_p: Primitive = standard_naryop([_num, _num], 'add')
# Differentiation and inverse rules for `add_p`.
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
iad.definverse(add_p, _add_inverse)
# `add` lowers directly to the MHLO `AddOp`.
mlir.register_lowering(add_p, functools.partial(_nary_lower_mhlo, mhlo.AddOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
def _sub_jvp(primals, tangents):
  """JVP rule for `sub_p`.

  Returns `(primal_out, tangent_out)`. Symbolic-zero tangents
  (`ad_util.Zero`) are handled without an explicit subtraction: both zero
  gives a symbolic zero, a zero lhs tangent gives the negated rhs tangent, and
  a zero rhs tangent gives the lhs tangent, each broadcast to the output
  shape.
  """
  lhs, rhs = primals
  lhs_dot, rhs_dot = tangents
  primal_out = sub(lhs, rhs)

  lhs_is_zero = type(lhs_dot) is ad_util.Zero
  rhs_is_zero = type(rhs_dot) is ad_util.Zero
  if lhs_is_zero and rhs_is_zero:
    tangent_out = ad_util.Zero.from_value(primal_out)
  elif lhs_is_zero:
    tangent_out = _maybe_broadcast(primal_out.shape, neg(rhs_dot))
  elif rhs_is_zero:
    tangent_out = _maybe_broadcast(primal_out.shape, lhs_dot)
  else:
    tangent_out = sub(lhs_dot, rhs_dot)
  return primal_out, tangent_out
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _sub_transpose(t, x, y):
  """Transpose rule for `sub_p`: `t` goes to `x` and `-t` goes to `y`.

  Each returned cotangent is reduced to its operand's aval with
  `_unbroadcast`. A symbolic-zero `t` yields symbolic zeros.
  """
  # Morally the following assertion is true, but see the comment in add_p's
  # transpose rule.
  # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
  def _operand_aval(operand):
    if ad.is_undefined_primal(operand):
      return operand.aval
    return _abstractify(operand)

  x_aval = _operand_aval(x)
  y_aval = _operand_aval(y)
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
  x_ct = _unbroadcast(x_aval, t)
  y_ct = _unbroadcast(y_aval, neg(t))
  return [x_ct, y_ct]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# `sub` primitive with its differentiation rules; it lowers directly to the
# MHLO `SubOp`.
sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, functools.partial(_nary_lower_mhlo, mhlo.SubOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-04-15 15:16:29 -07:00
|
|
|
|
|
|
|
def _mul_transpose(ct, x, y):
  """Transpose rule for `mul_p`.

  Exactly one of `x` and `y` must be an undefined primal (asserted). The
  cotangent for that operand is `ct` multiplied by the other operand and
  reduced to the operand's aval, or a symbolic zero if `ct` is one. The
  defined operand gets `None`.
  """
  assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y)
  ct_is_zero = type(ct) is ad_util.Zero
  if ad.is_undefined_primal(x):
    if ct_is_zero:
      x_ct = ad_util.Zero(x.aval)
    else:
      x_ct = _unbroadcast(x.aval, mul(ct, y))
    return [x_ct, None]
  if ct_is_zero:
    y_ct = ad_util.Zero(y.aval)
  else:
    y_ct = _unbroadcast(y.aval, mul(x, ct))
  return [None, y_ct]
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _mul_inverse(r, x, y):
|
|
|
|
xr = r / y
|
|
|
|
yr = r / x
|
|
|
|
return xr, yr
|
2021-04-15 15:16:29 -07:00
|
|
|
|
|
|
|
def _mul_jvp_lhs(xdot, x, y):
  """JVP of `mul` with respect to the first operand: `xdot * y`."""
  return mul(xdot, y)


def _mul_jvp_rhs(ydot, x, y):
  """JVP of `mul` with respect to the second operand: `x * ydot`."""
  return mul(x, ydot)


mul_p = standard_naryop([_num, _num], 'mul')
ad.defjvp(mul_p, _mul_jvp_lhs, _mul_jvp_rhs)
ad.primitive_transposes[mul_p] = _mul_transpose
iad.definverse(mul_p, _mul_inverse)
mlir.register_lowering(mul_p, functools.partial(_nary_lower_mhlo, mhlo.MulOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _div_transpose_rule(cotangent, x, y):
  """Transpose rule for `div_p`, defined only with respect to the numerator.

  Asserts that `x` is an undefined primal and `y` is not. Returns
  `cotangent / y` reduced to the aval of `x` (or a symbolic zero), and `None`
  for `y`.
  """
  assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
  if type(cotangent) is ad_util.Zero:
    x_ct = ad_util.Zero(x.aval)
  else:
    x_ct = _unbroadcast(x.aval, div(cotangent, y))
  return [x_ct, None]
|
2020-10-17 14:33:26 -04:00
|
|
|
def _div_jvp_lhs(g, x, y):
  """JVP of `div` with respect to the numerator: `g / y`."""
  return div(g, y)


def _div_jvp_rhs(g, x, y):
  """JVP of `div` with respect to the denominator: `-g * x * y ** -2`."""
  return mul(mul(neg(g), x), integer_pow(y, -2))


div_p = standard_naryop([_num, _num], 'div')
ad.defjvp(div_p, _div_jvp_lhs, _div_jvp_rhs)
ad.primitive_transposes[div_p] = _div_transpose_rule
mlir.register_lowering(div_p, functools.partial(_nary_lower_mhlo, mhlo.DivOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _rem_jvp_lhs(g, x, y):
  """JVP of `rem(x, y)` with respect to the dividend `x`.

  The derivative is one, so the tangent `g` is passed through, broadcast to
  the output shape.
  """
  return _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g)


def _rem_jvp_rhs(g, x, y):
  """JVP of `rem(x, y)` with respect to the divisor `y`.

  `rem` is the truncated remainder, `x - y * trunc(x / y)`, whose result takes
  the sign of the dividend. Its derivative with respect to `y` is therefore
  `-trunc(x / y)`. Using `floor` here is wrong whenever the quotient is
  negative and non-integral (e.g. `x = -5.5`, `y = 2` has `trunc = -2` but
  `floor = -3`).
  """
  quotient = div(x, y)
  # trunc(q) rounds towards zero: floor(q) for q >= 0, and ceil(q), written
  # here as -floor(-q), for q < 0.
  truncated = select(ge(quotient, _zeros(quotient)),
                     floor(quotient),
                     neg(floor(neg(quotient))))
  return mul(neg(g), truncated)


rem_p = standard_naryop([_num, _num], 'rem')
ad.defjvp(rem_p, _rem_jvp_lhs, _rem_jvp_rhs)
mlir.register_lowering(rem_p, partial(_nary_lower_mhlo, mhlo.RemOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
def _broadcasting_select(c, which, x, y):
  """Wrapper around XLA `Select` that broadcasts its arguments.

  The shapes of `which`, `x` and `y` are read from the builder `c` and
  broadcast together. Each operand is then expanded with `BroadcastInDim`,
  mapping its dimensions onto the trailing dimensions of the output shape.
  """
  which_shape = c.get_shape(which).dimensions()
  x_shape = c.get_shape(x).dimensions()
  y_shape = c.get_shape(y).dimensions()
  out_shape = broadcast_shapes(which_shape, x_shape, y_shape)
  out_rank = len(out_shape)

  def _broadcast_to_out(operand, shape):
    trailing_dims = tuple(range(out_rank - len(shape), out_rank))
    return xops.BroadcastInDim(operand, out_shape, trailing_dims)

  which = _broadcast_to_out(which, which_shape)
  x = _broadcast_to_out(x, x_shape)
  y = _broadcast_to_out(y, y_shape)
  return xops.Select(which, x, y)
|
|
|
|
|
|
|
|
|
2021-06-04 11:02:50 +03:00
|
|
|
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
  """Elementwise min/max of complex operands using lexicographic ordering.

  Real parts are compared with `lax_cmp_pick_x`; where they are equal the
  imaginary parts are compared instead. `x` is returned wherever the
  comparison holds and `y` elsewhere.
  """
  out_shape = broadcast_shapes(np.shape(x), np.shape(y))
  x = _maybe_broadcast(out_shape, x)
  y = _maybe_broadcast(out_shape, y)
  real_x = real(x)
  real_y = real(y)
  reals_equal = eq(real_x, real_y)
  pick_by_imag = lax_cmp_pick_x(imag(x), imag(y))
  pick_by_real = lax_cmp_pick_x(real_x, real_y)
  pick_x = select(reals_equal, pick_by_imag, pick_by_real)
  return select(pick_x, x, y)
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _minmax_translation_rule(ctx, avals_in, avals_out, x, y, *, op_minmax=None,
                             lax_cmp_pick_x=None):
  """XLA translation shared by `max_p` and `min_p`.

  Non-complex operands are handed to `op_minmax` directly. Complex operands
  (decided by the dtype of the first input aval) are lowered through
  `_minmax_complex_lowering` with the comparison `lax_cmp_pick_x`.
  """
  x_aval, _ = avals_in
  if not dtypes.issubdtype(x_aval.dtype, np.complexfloating):
    return [op_minmax(x, y)]
  complex_rule = xla.lower_fun(
      partial(_minmax_complex_lowering, lax_cmp_pick_x=lax_cmp_pick_x),
      multiple_results=False,
      new_style=True)
  return complex_rule(ctx, avals_in, avals_out, x, y)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _max_jvp_x(g, ans, x, y):
  """JVP of `max` with respect to `x`: `g` weighted by `_balanced_eq(x, ans, y)`."""
  return mul(g, _balanced_eq(x, ans, y))


def _max_jvp_y(g, ans, x, y):
  """JVP of `max` with respect to `y`: `g` weighted by `_balanced_eq(y, ans, x)`."""
  return mul(g, _balanced_eq(y, ans, x))


_max_translation_rule = partial(
    _minmax_translation_rule, op_minmax=xops.Max, lax_cmp_pick_x=gt)
max_p: core.Primitive = standard_naryop(
    [_any, _any], 'max', translation_rule=_max_translation_rule)
ad.defjvp2(max_p, _max_jvp_x, _max_jvp_y)
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _min_jvp_x(g, ans, x, y):
  """JVP of `min` with respect to `x`: `g` weighted by `_balanced_eq(x, ans, y)`."""
  return mul(g, _balanced_eq(x, ans, y))


def _min_jvp_y(g, ans, x, y):
  """JVP of `min` with respect to `y`: `g` weighted by `_balanced_eq(y, ans, x)`."""
  return mul(g, _balanced_eq(y, ans, x))


_min_translation_rule = partial(
    _minmax_translation_rule, op_minmax=xops.Min, lax_cmp_pick_x=lt)
min_p: core.Primitive = standard_naryop(
    [_any, _any], 'min', translation_rule=_min_translation_rule)
ad.defjvp2(min_p, _min_jvp_x, _min_jvp_y)
mlir.register_lowering(min_p, partial(_nary_lower_mhlo, mlir.min_mhlo))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Integer shift primitives. Each gets a zero JVP and lowers directly to the
# corresponding MHLO op.
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)
mlir.register_lowering(
    shift_left_p, functools.partial(_nary_lower_mhlo, mhlo.ShiftLeftOp))

shift_right_arithmetic_p = standard_naryop(
    [_int, _int], 'shift_right_arithmetic')
ad.defjvp_zero(shift_right_arithmetic_p)
mlir.register_lowering(
    shift_right_arithmetic_p,
    functools.partial(_nary_lower_mhlo, mhlo.ShiftRightArithmeticOp))

shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
ad.defjvp_zero(shift_right_logical_p)
mlir.register_lowering(
    shift_right_logical_p,
    functools.partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp))
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _compare_lower_mhlo(direction: str, ctx, x, y):
  """MLIR lowering shared by the comparison primitives.

  Both operands are broadcast to the output shape (keeping the dtype of the
  first input) and compared with an MHLO `CompareOp`. The compare type is
  "FLOAT" for inexact dtypes, "SIGNED" for signed integers and "UNSIGNED" for
  everything else.

  Args:
    direction: comparison direction string passed to `CompareOp` (e.g. "EQ").
    ctx: per-rule lowering context providing `avals_in` and `avals_out`.
    x: first operand.
    y: second operand.
  """
  x_aval, _ = ctx.avals_in
  (aval_out,) = ctx.avals_out
  broadcast_aval = aval_out.update(dtype=x_aval.dtype)
  x, y = broadcast_mhlo(broadcast_aval, ctx.avals_in, (x, y))

  operand_dtype = x_aval.dtype
  if dtypes.issubdtype(operand_dtype, np.inexact):
    compare_type = "FLOAT"
  elif dtypes.issubdtype(operand_dtype, np.signedinteger):
    compare_type = "SIGNED"
  else:
    compare_type = "UNSIGNED"

  compare_op = mhlo.CompareOp(
      mlir.aval_to_ir_type(aval_out), x, y,
      ir.StringAttr.get(direction),
      ir.StringAttr.get(compare_type))
  return compare_op.results
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _register_comparison_rules(prim, direction):
  """Give the comparison primitive `prim` a zero JVP and its MHLO lowering."""
  ad.defjvp_zero(prim)
  mlir.register_lowering(prim, partial(_compare_lower_mhlo, direction))


# Comparison primitives: any input dtype, boolean output.
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
_register_comparison_rules(eq_p, "EQ")

ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne')
_register_comparison_rules(ne_p, "NE")

ge_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ge')
_register_comparison_rules(ge_p, "GE")

gt_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'gt')
_register_comparison_rules(gt_p, "GT")

le_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'le')
_register_comparison_rules(le_p, "LE")

lt_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'lt')
_register_comparison_rules(lt_p, "LT")
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
|
2020-10-17 14:33:26 -04:00
|
|
|
return operand.shape
|
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type):
|
2020-10-17 14:33:26 -04:00
|
|
|
return new_dtype
|
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type):
|
|
|
|
return weak_type
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _convert_element_type_translation_rule(ctx, avals_in, avals_out, operand, *,
                                           new_dtype, weak_type):
  """XLA translation of `convert_element_type`.

  When converting from a complex dtype to a non-complex one, the real part is
  taken first. The operand is then converted with `ConvertElementType`.
  """
  (aval_in,) = avals_in
  from_complex = dtypes.issubdtype(aval_in.dtype, np.complexfloating)
  if from_complex and not dtypes.issubdtype(new_dtype, np.complexfloating):
    operand = xops.Real(operand)
  target_etype = xla.dtype_to_primitive_type(new_dtype)
  return [xops.ConvertElementType(operand, new_element_type=target_etype)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type):
  """Transpose rule for `convert_element_type`.

  Converts the cotangent `ct` back to the operand's dtype and weak-type flag.
  A symbolic-zero `ct` yields a symbolic zero for the operand's aval. If the
  operand's tangent dtype is `float0`, a `float0` symbolic zero is returned
  instead.
  """
  assert ad.is_undefined_primal(operand)
  operand_aval = operand.aval
  old_dtype = operand_aval.dtype
  old_weak_type = dtypes.is_weakly_typed(operand)
  if type(ct) is ad_util.Zero:
    return [ad_util.Zero(operand_aval)]
  if core.primal_dtype_to_tangent_dtype(old_dtype) is dtypes.float0:
    float0_aval = operand_aval.update(dtype=dtypes.float0, weak_type=False)
    return [ad_util.Zero(float0_aval)]
  converted_ct = convert_element_type_p.bind(
      ct, new_dtype=old_dtype, weak_type=old_weak_type)
  return [converted_ct]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
def _convert_element_type_jvp_rule(tangent, operand, *, new_dtype, weak_type):
  """JVP rule for `convert_element_type`.

  The tangent is converted with the same `new_dtype` and `weak_type` as the
  primal, unless the tangent dtype of `new_dtype` is `float0`, in which case a
  `float0` symbolic zero is returned.
  """
  if core.primal_dtype_to_tangent_dtype(new_dtype) is not dtypes.float0:
    return convert_element_type_p.bind(tangent, new_dtype=new_dtype,
                                       weak_type=weak_type)
  float0_aval = tangent.aval.update(dtype=dtypes.float0, weak_type=False)
  return ad_util.Zero(float0_aval)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-11-15 21:21:29 -08:00
|
|
|
def _convert_elt_type_folding_rule(consts, eqn):
|
|
|
|
c, = consts
|
|
|
|
if type(c) in core.literalable_types and not np.shape(c):
|
|
|
|
return [np.array(c, eqn.params['new_dtype'])], None
|
|
|
|
else:
|
|
|
|
return [None], eqn
|
|
|
|
|
|
|
|
def _convert_elt_type_fwd_rule(eqn):
|
|
|
|
v, = eqn.invars
|
|
|
|
if (v.aval.dtype == eqn.params['new_dtype'] and
|
|
|
|
v.aval.weak_type == eqn.params['weak_type']):
|
|
|
|
return [v], None
|
|
|
|
else:
|
|
|
|
return [None], eqn
|
|
|
|
|
|
|
|
# `convert_element_type` is built by hand rather than through
# `standard_primitive`; its abstract evaluation takes a dedicated weak-type
# rule in addition to the shape and dtype rules.
convert_element_type_p = Primitive('convert_element_type')
convert_element_type_p.def_impl(
    partial(xla.apply_primitive, convert_element_type_p))
_convert_element_type_abstract_eval = partial(
    standard_abstract_eval,
    convert_element_type_p,
    _convert_element_type_shape_rule,
    _convert_element_type_dtype_rule,
    _convert_element_type_weak_type_rule,
    standard_named_shape_rule)
convert_element_type_p.def_abstract_eval(_convert_element_type_abstract_eval)
xla.register_translation(convert_element_type_p,
                         _convert_element_type_translation_rule)
# Differentiation, batching and masking rules.
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
ad.primitive_transposes[convert_element_type_p] = (
    _convert_element_type_transpose_rule)
batching.defvectorized(convert_element_type_p)
masking.defvectorized(convert_element_type_p)
# Partial-evaluation rules: constant folding and no-op forwarding.
pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-12-07 07:12:08 -08:00
|
|
|
def _real_dtype(dtype): return np.finfo(dtype).dtype
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
  """MLIR lowering of `convert_element_type`.

  When converting from a complex dtype to a non-complex one, the real part is
  taken first and the input aval is updated to the matching real dtype. The
  conversion itself is delegated to `mlir.convert_mhlo`.
  """
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out
  from_complex = dtypes.issubdtype(aval_in.dtype, np.complexfloating)
  if from_complex and not dtypes.issubdtype(new_dtype, np.complexfloating):
    operand = mhlo.RealOp(operand).result
    aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
  return [mlir.convert_mhlo(operand, aval_in, aval_out)]


mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
|
|
|
|
return operand.shape
|
|
|
|
|
|
|
|
def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
|
2021-06-10 12:42:40 +02:00
|
|
|
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
|
|
|
|
if dtypes.issubdtype(old_dtype, np.bool_) or dtypes.issubdtype(old_dtype, np.complexfloating):
|
|
|
|
if old_dtype != new_dtype:
|
|
|
|
raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) cannot have different destination type ({new_dtype})")
|
|
|
|
if np.dtype(old_dtype).itemsize != np.dtype(new_dtype).itemsize:
|
|
|
|
raise TypeError(f"`bitcast_convert_type` for operand type ({old_dtype}) must have destination type ({new_dtype}) of same size.")
|
2020-10-17 14:33:26 -04:00
|
|
|
return new_dtype
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _bitcast_convert_type_translation_rule(ctx, avals_in, avals_out, operand, *,
                                           new_dtype):
  """XLA translation of `bitcast_convert_type` to a `BitcastConvertType` op."""
  del ctx, avals_in, avals_out  # Only the operand and target dtype are needed.
  target_etype = xla.dtype_to_primitive_type(new_dtype)
  return [xops.BitcastConvertType(operand, new_element_type=target_etype)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# `bitcast_convert_type` primitive. Its weak-type rule is `_strip_weak_type`
# and it is given a zero JVP.
bitcast_convert_type_p = standard_primitive(
    _bitcast_convert_type_shape_rule,
    _bitcast_convert_type_dtype_rule,
    'bitcast_convert_type',
    _bitcast_convert_type_translation_rule,
    weak_type_rule=_strip_weak_type,
)
ad.defjvp_zero(bitcast_convert_type_p)
batching.defvectorized(bitcast_convert_type_p)
masking.defvectorized(bitcast_convert_type_p)
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
  """MLIR lowering of `bitcast_convert_type` to an MHLO `BitcastConvertOp`.

  The result type is taken from the output aval, so `new_dtype` is not read
  here.
  """
  (aval_out,) = ctx.avals_out
  result_type = mlir.aval_to_ir_type(aval_out)
  return mhlo.BitcastConvertOp(result_type, operand).results


mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-04-21 23:58:34 +02:00
|
|
|
def _validate_preferred_element_type(input_dtype, preferred_element_type):
|
|
|
|
allowed_types = (np.integer, np.floating, np.complexfloating)
|
|
|
|
if any(dtypes.issubdtype(input_dtype, t) and not dtypes.issubdtype(preferred_element_type, t) for t in allowed_types):
|
|
|
|
raise TypeError("`preferred_element_type` and the original type must both be integral, both be floating point, or both complex.")
|
|
|
|
if dtypes.issubdtype(input_dtype, np.signedinteger) and not dtypes.issubdtype(preferred_element_type, np.signedinteger):
|
|
|
|
raise TypeError("`preferred_element_type` must have the same signedness as the original type.")
|
|
|
|
input_bitwidth = np.dtype(input_dtype).itemsize
|
|
|
|
preferred_bitwidth = np.dtype(preferred_element_type).itemsize
|
|
|
|
if preferred_bitwidth < input_bitwidth:
|
|
|
|
raise TypeError("`preferred_element_type` must not be narrower than the original type.")
|
|
|
|
|
2021-11-23 12:35:23 -08:00
|
|
|
def _precision_config(precision):
  """Build an `xla_client.PrecisionConfig` from `precision`.

  Returns `None` when `precision` is `None`. A tuple is used as the operand
  precisions as given; any other value is applied to both operands.
  """
  if precision is None:
    return None
  precision_config = xla_client.PrecisionConfig()
  if isinstance(precision, tuple):
    operand_precisions = precision
  else:
    operand_precisions = (precision, precision)
  precision_config.operand_precision.extend(operand_precisions)
  return precision_config
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _masked(padded_value, logical_shape, dimensions, value=0):
  """
  Sets all padding to the given value (default is 0) in the given dimensions.
  All values outside the logical shape are considered padding.
  """
  if len(dimensions) == 0:
    return padded_value

  # One mask per dimension: true where the index along that dimension lies
  # inside the logical shape.
  in_bounds_masks = [
      broadcasted_iota(np.int32, padded_value.shape, d) < logical_shape[d]
      for d in dimensions
  ]
  # An element is kept only if it is in bounds along every listed dimension.
  keep_mask = functools.reduce(operator.and_, in_bounds_masks)
  fill = full_like(padded_value, value)
  return select(keep_mask, padded_value, fill)
|
|
|
|
|
|
|
|
|
2020-12-10 02:29:40 +00:00
|
|
|
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: Optional[DType]):
  """Shape rule for `dot_general`.

  Validates `dimension_numbers`, given as
  `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`, against the
  operand shapes and returns the output shape computed by
  `_dot_general_shape_computation`.

  Raises:
    TypeError: if any dimension number is out of range, if the batch
      dimension counts differ, if batch or contracting dimensions are repeated
      or overlap each other, or if the batch or contracting shapes of the two
      operands differ.
  """
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

  def _in_range(dims, ndim):
    return np.all(np.greater_equal(dims, 0)) and np.all(np.less(dims, ndim))

  # Every dimension number must index a valid axis of its operand.
  if not all(_in_range(d, lhs.ndim) for d in (lhs_contracting, lhs_batch)):
    raise TypeError(
        "dot_general requires lhs dimension numbers to be nonnegative and "
        "less than the number of axes of the lhs value, got "
        f"lhs_batch of {lhs_batch} and lhs_contracting of {lhs_contracting} "
        f"for lhs of rank {lhs.ndim}")
  if not all(_in_range(d, rhs.ndim) for d in (rhs_contracting, rhs_batch)):
    raise TypeError(
        "dot_general requires rhs dimension numbers to be nonnegative and "
        "less than the number of axes of the rhs value, got "
        f"rhs_batch of {rhs_batch} and rhs_contracting of {rhs_contracting} "
        f"for rhs of rank {rhs.ndim}")

  if len(lhs_batch) != len(rhs_batch):
    raise TypeError(
        ("dot_general requires equal numbers of lhs_batch and rhs_batch "
         "dimensions, got lhs_batch {} and rhs_batch {}.").format(
             lhs_batch, rhs_batch))

  lhs_contracting_set, lhs_batch_set = set(lhs_contracting), set(lhs_batch)
  rhs_contracting_set, rhs_batch_set = set(rhs_contracting), set(rhs_batch)

  # Dimensions must not repeat within a group ...
  if len(lhs_batch_set) != len(lhs_batch):
    raise TypeError(
        "dot_general requires lhs batch dimensions to be distinct, got "
        f"lhs_batch {lhs_batch}.")
  if len(rhs_batch_set) != len(rhs_batch):
    raise TypeError(
        "dot_general requires rhs batch dimensions to be distinct, got "
        f"rhs_batch {rhs_batch}.")
  if len(lhs_contracting_set) != len(lhs_contracting):
    raise TypeError(
        "dot_general requires lhs contracting dimensions to be distinct, "
        f"got lhs_contracting {lhs_contracting}.")
  if len(rhs_contracting_set) != len(rhs_contracting):
    raise TypeError(
        "dot_general requires rhs contracting dimensions to be distinct, "
        f"got rhs_contracting {rhs_contracting}.")

  # ... and a dimension cannot be both batch and contracting.
  if lhs_contracting_set & lhs_batch_set:
    raise TypeError(
        ("dot_general requires lhs batch dimensions to be disjoint from "
         "contracting dimensions, got lhs_batch {} and lhs_contracting {}."
        ).format(lhs_batch, lhs_contracting))
  if rhs_contracting_set & rhs_batch_set:
    raise TypeError(
        ("dot_general requires rhs batch dimensions to be disjoint from "
         "contracting dimensions, got rhs_batch {} and rhs_contracting {}."
        ).format(rhs_batch, rhs_contracting))

  # Paired batch and contracting dimensions must have matching sizes.
  lhs_batch_shape = np.take(lhs.shape, lhs_batch)
  rhs_batch_shape = np.take(rhs.shape, rhs_batch)
  if not core.symbolic_equal_shape(lhs_batch_shape, rhs_batch_shape):
    raise TypeError(
        ("dot_general requires lhs batch dimensions and rhs batch dimensions "
         "to have the same shape, got {} and {}.").format(
             lhs_batch_shape, rhs_batch_shape))
  lhs_contracting_shape = np.take(lhs.shape, lhs_contracting)
  rhs_contracting_shape = np.take(rhs.shape, rhs_contracting)
  if not core.symbolic_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
    raise TypeError(
        ("dot_general requires contracting dimensions to have the same "
         "shape, got {} and {}.").format(
             lhs_contracting_shape, rhs_contracting_shape))

  return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)
|
|
|
|
|
|
|
|
def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
  """Compute the output shape of a dot_general from its operand shapes.

  The result is the concatenation of three groups of sizes, in this order:
  the batch dimensions (sizes read from ``lhs_shape``), the lhs dimensions
  that are neither contracted nor batched, and the rhs dimensions that are
  neither contracted nor batched.

  Args:
    lhs_shape: shape of the left operand.
    rhs_shape: shape of the right operand.
    dimension_numbers: ``((lhs_contracting, rhs_contracting),
      (lhs_batch, rhs_batch))``.

  Returns:
    A tuple with the sizes of the output dimensions.
  """
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

  def free_sizes(shape, contracting, batch):
    # Remove every contracted or batched axis; the remaining sizes are the
    # tensor-product part contributed by this operand.
    consumed = tuple(sorted(tuple(contracting) + tuple(batch)))
    return tuple(np.delete(shape, consumed))

  batch_sizes = tuple(np.take(lhs_shape, lhs_batch))
  lhs_free = free_sizes(lhs_shape, lhs_contracting, lhs_batch)
  rhs_free = free_sizes(rhs_shape, rhs_contracting, rhs_batch)
  return batch_sizes + lhs_free + rhs_free
|
|
|
|
|
2020-12-10 02:29:40 +00:00
|
|
|
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
                            preferred_element_type: Optional[DType]):
  """Return the result dtype of a dot_general.

  The operand dtype is obtained through ``naryop_dtype_rule``. When a
  ``preferred_element_type`` is given it is validated against that operand
  dtype and returned instead.
  """
  operand_dtype = naryop_dtype_rule(_input_dtype, [_any, _any], 'dot_general', lhs, rhs)
  if preferred_element_type is not None:
    _validate_preferred_element_type(operand_dtype, preferred_element_type)
    return preferred_element_type
  return operand_dtype
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
                               preferred_element_type: Optional[DType],
                               swap_ans=False):
  """Transpose rule of dot_general for its first operand.

  Contracts ``g`` with ``y`` over the dimensions of ``y`` that were neither
  contracted nor batched in the forward dot_general, then permutes the result
  with ``transpose``.

  Args:
    g: the value to be transposed (it has the layout of the forward output).
    y: the other operand of the forward dot_general.
    dimension_numbers: the forward ``((x_contract, y_contract),
      (x_batch, y_batch))``.
    precision: forwarded unchanged to ``dot_general``.
    preferred_element_type: forwarded unchanged to ``dot_general``.
    swap_ans: when True, the kept dimensions of ``y`` are taken to come
      before those of ``x`` in ``g`` (used by the rhs transpose rule, which
      calls this function with the operand roles exchanged).
  """
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  # g has (batch + x_kept + y_kept) dims and y has (batch + contract + y_kept),
  # so this expression recovers the rank of the operand being transposed.
  x_ndim = g.ndim - y.ndim + len(x_batch) + 2 * len(x_contract)
  x_kept = remaining(range(x_ndim), x_contract, x_batch)
  y_kept = remaining(range(y.ndim), y_contract, y_batch)
  if swap_ans:
    ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept)
  else:
    ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
  transpose_dnums = ((ans_y, y_kept), (ans_batch, y_batch))
  # Order the contracting dims of x the way y's contracting dims sort.
  y_order = np.argsort(y_contract)
  x_contract_sorted_by_y = list(np.take(x_contract, y_order))  # type: ignore[arg-type]
  out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
  contracted = dot_general(g, y, transpose_dnums, precision=precision,
                           preferred_element_type=preferred_element_type)
  return transpose(contracted, tuple(out_axes))
|
|
|
|
|
2020-12-10 02:29:40 +00:00
|
|
|
def _dot_general_transpose_rhs(g, x, *, dimension_numbers, precision,
                               preferred_element_type: Optional[DType]):
  """Transpose rule of dot_general for its second operand.

  Delegates to ``_dot_general_transpose_lhs`` with the lhs/rhs entries of
  ``dimension_numbers`` exchanged and ``swap_ans=True``.
  """
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  contract_swapped = (y_contract, x_contract)
  batch_swapped = (y_batch, x_batch)
  return _dot_general_transpose_lhs(
      g, x,
      dimension_numbers=(contract_swapped, batch_swapped),
      precision=precision,
      preferred_element_type=preferred_element_type,
      swap_ans=True)
|
|
|
|
|
|
|
|
|
|
|
|
def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
                            precision,
                            preferred_element_type: Optional[DType]):
  """Batching rule for dot_general.

  Rewrites ``dimension_numbers`` to account for the batched axes (see
  ``_dot_general_batch_dim_nums``) and applies ``dot_general`` once to the
  batched operands.

  Returns:
    ``(batched_result, result_batch_dim)``.
  """
  lhs, rhs = batched_args
  operand_ranks = (lhs.ndim, rhs.ndim)
  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
      operand_ranks, batch_dims, dimension_numbers)
  result = dot_general(lhs, rhs, new_dimension_numbers,
                       precision=precision,
                       preferred_element_type=preferred_element_type)
  return result, result_batch_dim
|
|
|
|
|
|
|
|
def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
  """Rewrite dot_general dimension numbers for batched operands.

  Args:
    ndims: ``(lhs_ndim, rhs_ndim)``, the ranks passed by the caller.
    batch_dims: ``(lbd, rbd)``, the batched axis of each operand, or None for
      an operand that is not batched. At least one must be set.
    dimension_numbers: the original ``((lhs_contract, rhs_contract),
      (lhs_batch, rhs_batch))``.

  Returns:
    ``(new_dimension_numbers, result_batch_dim)`` where ``result_batch_dim``
    is a Python int giving the batched axis of the result.
  """
  # there are three kinds of dimensions in a dot_general:
  # - contraction dimensions appear in lhs and rhs but not the result
  # - batch dimensions appear in lhs, rhs, and result
  # - tensor product dimensions appear in the result and one of lhs or rhs
  lhs_ndim, rhs_ndim = ndims
  lbd, rbd = batch_dims
  assert lbd is not None or rbd is not None
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

  def shift_past(dims, inserted):
    # Every index at or after the inserted axis moves up by one.
    return tuple(np.add(dims, np.greater_equal(dims, inserted)))

  def free_dims(ndim, batch, contract):
    # Axes of one operand that are neither batched nor contracted.
    return tuple(d for d in range(ndim)
                 if d not in batch and d not in contract)

  if lbd is not None and rbd is not None:
    # adding a batch dimension
    lhs_batch = (lbd,) + shift_past(lhs_batch, lbd)
    rhs_batch = (rbd,) + shift_past(rhs_batch, rbd)
    lhs_contract = shift_past(lhs_contract, lbd)
    rhs_contract = shift_past(rhs_contract, rbd)
    result_batch_dim = 0
  elif lbd is not None:
    # adding a tensor product dimension on the lhs
    other = free_dims(lhs_ndim, lhs_batch, lhs_contract)
    result_batch_dim = len(lhs_batch) + sum(np.less(other, lbd))
    lhs_batch = shift_past(lhs_batch, lbd)
    lhs_contract = shift_past(lhs_contract, lbd)
  else:
    # adding a tensor product dimension on the rhs
    other = free_dims(rhs_ndim, rhs_batch, rhs_contract)
    result_batch_dim = (lhs_ndim - len(lhs_contract) +
                        sum(np.less(other, rbd)))
    rhs_batch = shift_past(rhs_batch, rbd)
    rhs_contract = shift_past(rhs_contract, rbd)

  new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
  return new_dimension_numbers, int(result_batch_dim)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _dot_general_translation_rule(ctx, avals_in, avals_out, lhs, rhs, *,
                                  dimension_numbers, precision,
                                  preferred_element_type: Optional[DType]):
  """XLA translation rule for dot_general.

  Emits a single ``xops.DotGeneral``. A non-None ``preferred_element_type``
  is first converted with ``xla.dtype_to_primitive_type``.
  """
  xla_preferred = (
      None if preferred_element_type is None
      else xla.dtype_to_primitive_type(preferred_element_type))
  dot = xops.DotGeneral(
      lhs, rhs,
      xc.make_dot_dimension_numbers(dimension_numbers),
      precision_config=_precision_config(precision),
      preferred_element_type=xla_preferred)
  return [dot]
|
|
|
|
|
|
|
|
def _dot_general_cpu_translation_rule(ctx, avals_in, avals_out, lhs, rhs, *,
                                      dimension_numbers, precision,
                                      preferred_element_type: Optional[DType]):
  """CPU-specific XLA translation rule for dot_general.

  Same as the generic rule, except that when the first input aval is float16
  both operands are converted to float32 before the dot, and the preferred
  element type falls back to float16 if none was requested.
  """
  if preferred_element_type is not None:
    preferred_element_type = xla.dtype_to_primitive_type(preferred_element_type)

  # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
  lhs_aval = avals_in[0]
  if lhs_aval.dtype == np.float16:
    f32 = xla.dtype_to_primitive_type(np.dtype(np.float32))
    lhs = xops.ConvertElementType(lhs, f32)
    rhs = xops.ConvertElementType(rhs, f32)
    preferred_element_type = (
        preferred_element_type or
        xla.dtype_to_primitive_type(np.dtype(np.float16)))

  dot = xops.DotGeneral(
      lhs, rhs,
      xc.make_dot_dimension_numbers(dimension_numbers),
      precision_config=_precision_config(precision),
      preferred_element_type=preferred_element_type)
  return [dot]
|
2021-08-03 00:02:59 -04:00
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
                              precision,
                              preferred_element_type: Optional[DType]):
  """Masking rule for dot_general.

  Applies ``_masked`` to the lhs along its contracting dimensions, using the
  lhs logical shape, and then performs the dot_general on the padded values.
  """
  lhs, rhs = padded_vals
  # Only need to mask off contraction dims of one side - we mask the lhs here
  # but this is arbitrary. Could check the sizes of lhs and rhs and mask
  # whichever is smallest.
  lhs_shape, _ = logical_shapes
  (lhs_contract, _), _ = dimension_numbers
  masked_lhs = _masked(lhs, lhs_shape, lhs_contract)
  return dot_general(masked_lhs, rhs, dimension_numbers,
                     precision=precision,
                     preferred_element_type=preferred_element_type)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# The dot_general primitive and its rule registrations.
dot_general_p = standard_primitive(
    _dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general',
    translation_rule=_dot_general_translation_rule)
ad.defbilinear(
    dot_general_p, _dot_general_transpose_lhs, _dot_general_transpose_rhs)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
masking.masking_rules[dot_general_p] = _dot_general_masking_rule
# A dedicated translation rule is registered for the "cpu" platform.
xla.register_translation(
    dot_general_p, _dot_general_cpu_translation_rule, platform="cpu")
|
2021-08-03 00:02:59 -04:00
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
  """Build the MLIR array attribute describing a dot precision.

  ``None`` maps to a pair of ``Precision.DEFAULT``; a value that is not a
  tuple is duplicated into a pair; a tuple is used as given. Each entry is
  stored as the string ``str(p)``.
  """
  if precision is None:
    entries = (Precision.DEFAULT, Precision.DEFAULT)
  elif isinstance(precision, tuple):
    entries = precision
  else:
    entries = (precision, precision)
  string_attrs = [ir.StringAttr.get(str(p)) for p in entries]
  return ir.ArrayAttr.get(string_attrs)
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                       precision, preferred_element_type: Optional[np.dtype]):
  """MLIR (MHLO) lowering rule for dot_general.

  The result type is taken from ``ctx.avals_out``. On the "cpu" platform,
  each float16 operand is converted to float32 before the dot.
  """
  del preferred_element_type  # Implied by the output aval
  lhs_aval, rhs_aval = ctx.avals_in
  aval_out, = ctx.avals_out
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

  def upcast_if_f16(value, aval):
    # Leave non-float16 operands untouched; otherwise convert to a float32
    # tensor of the same shape.
    if aval.dtype != np.float16:
      return value
    f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
    return mhlo.ConvertOp(ir.RankedTensorType.get(aval.shape, f32),
                          value).result

  # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
  if ctx.module_context.platform == "cpu":
    lhs = upcast_if_f16(lhs, lhs_aval)
    rhs = upcast_if_f16(rhs, rhs_aval)

  dot_dnums = mhlo.DotDimensionNumbers.get(
      lhs_batching_dimensions=list(lhs_batch),
      rhs_batching_dimensions=list(rhs_batch),
      lhs_contracting_dimensions=list(lhs_contracting),
      rhs_contracting_dimensions=list(rhs_contracting))
  dot_op = mhlo.DotGeneralOp(mlir.aval_to_ir_type(aval_out), lhs, rhs,
                             dot_dnums, precision_attr(precision))
  return [dot_op.result]

mlir.register_lowering(dot_general_p, _dot_general_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
  """Validate broadcast_in_dim parameters and return the result shape.

  Checks, in order: both parameters are shape-like; there is one broadcast
  dimension per operand axis; the target rank is not below the operand rank;
  every broadcast dimension indexes the target shape; each operand size is 1
  or equals the target size it maps to; the broadcast dimensions are strictly
  increasing.

  Returns:
    ``shape`` unchanged.

  Raises:
    TypeError: if any of the checks performed here fails.
  """
  _check_shapelike('broadcast_in_dim', 'shape', shape)
  _check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
                   broadcast_dimensions)
  operand_ndim = np.ndim(operand)
  target_ndim = len(shape)

  if operand_ndim != len(broadcast_dimensions):
    msg = ('broadcast_in_dim broadcast_dimensions must have length equal to '
           'operand ndim; got broadcast_dimensions {} for operand ndim {}.')
    raise TypeError(msg.format(broadcast_dimensions, operand_ndim))

  if target_ndim < operand_ndim:
    msg = ('broadcast_in_dim target broadcast shape must have equal or higher rank '
           'to the operand shape; got operand ndim {} and target broadcast ndim {}.')
    raise TypeError(msg.format(operand_ndim, target_ndim))

  if not set(broadcast_dimensions).issubset(set(range(target_ndim))):
    msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
           'dimensions, got {} for operand ndim {} and shape {}.')
    raise TypeError(msg.format(broadcast_dimensions, operand_ndim, shape))

  sizes_compatible = all(
      core.symbolic_equal_one_of_dim(operand.shape[i],
                                     [1, shape[broadcast_dimensions[i]]])
      for i in range(operand_ndim))
  if not sizes_compatible:
    msg = (
        "broadcast_in_dim operand dimension sizes must either be 1, or be "
        "equal to their corresponding dimensions in the target broadcast "
        "shape; got operand of shape {}, target broadcast shape {}, "
        "broadcast_dimensions {} ")
    raise TypeError(msg.format(operand.shape, shape, broadcast_dimensions))

  has_duplicates = len(broadcast_dimensions) != len(set(broadcast_dimensions))
  if (has_duplicates or
      tuple(broadcast_dimensions) != tuple(sorted(broadcast_dimensions))):
    msg = ("broadcast_in_dim broadcast_dimensions must be strictly increasing; "
           "got broadcast_dimensions {}")
    raise TypeError(msg.format(broadcast_dimensions))

  return shape
|
|
|
|
|
2020-12-29 10:30:22 -08:00
|
|
|
def _broadcast_in_dim_transpose_rule(ct, operand, *, shape, broadcast_dimensions):
  """Transpose rule for broadcast_in_dim.

  Sums ``ct`` over every output axis that does not come from a non-unit
  operand axis, then re-inserts the operand's size-1 axes with
  ``expand_dims``.
  """
  operand_shape = operand.aval.shape
  # Operand axes of size 1 were broadcast, so they are summed over as well
  # and restored afterwards.
  unit_dimensions = tuple(
      axis for axis, size in enumerate(operand_shape)
      if core.symbolic_equal_dim(size, 1))
  kept_out_axes = tuple(np.delete(broadcast_dimensions, unit_dimensions))
  summed_axes = tuple(np.delete(range(len(shape)), kept_out_axes))
  reduced = _reduce_sum(ct, summed_axes)
  return [expand_dims(reduced, unit_dimensions)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
                                 broadcast_dimensions):
  """Batching rule for broadcast_in_dim.

  Moves the batched axis of the operand to the front, prepends its size to
  the target shape, and shifts every broadcast dimension up by one with a
  leading 0 for the batch axis. The result is batched along axis 0.
  """
  operand, = batched_args
  bdim, = batch_dims
  operand_front = batching.moveaxis(operand, bdim, 0)
  batched_shape = (operand.shape[bdim],) + shape
  batched_dims = (0,) + tuple(np.add(1, broadcast_dimensions))
  out = broadcast_in_dim(operand_front, batched_shape, batched_dims)
  return out, 0
|
|
|
|
|
2021-11-15 21:21:29 -08:00
|
|
|
def _broadcast_in_dim_fwd_rule(eqn):
  """Forwarding rule for broadcast_in_dim equations.

  When the target shape equals the input variable's shape, the input is
  returned in place of the output and no equation is kept; otherwise the
  equation is returned unchanged.
  """
  v, = eqn.invars
  shape_unchanged = core.symbolic_equal_shape(eqn.params['shape'], v.aval.shape)
  if not shape_unchanged:
    return [None], eqn
  return [v], None
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# The broadcast_in_dim primitive and its rule registrations.
broadcast_in_dim_p = standard_primitive(_broadcast_in_dim_shape_rule,
                                        _input_dtype,
                                        'broadcast_in_dim')
ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _broadcast_in_dim_lower(ctx, x, *, shape, broadcast_dimensions):
  """MLIR (MHLO) lowering rule for broadcast_in_dim.

  The result type comes from ``ctx.avals_out``; the ``shape`` parameter is
  not used.
  """
  del shape
  aval_out, = ctx.avals_out
  result_type = mlir.aval_to_ir_type(aval_out)
  dims_attr = mlir.dense_int_elements(broadcast_dimensions)
  return mhlo.BroadcastInDimOp(result_type, x, dims_attr).results

mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _clamp_shape_rule(min, operand, max):
  """Return the result shape of clamp, which is ``operand.shape``.

  ``min`` and ``max`` must each either have an empty shape or have exactly
  the operand's shape.

  Raises:
    TypeError: if ``min`` or ``max`` has a non-empty shape that differs from
      ``operand.shape``.
  """
  min_ok = (not min.shape) or min.shape == operand.shape
  if not min_ok:
    raise TypeError("clamp requires min.shape == operand.shape or min.shape == "
                    f"(), got min.shape={min.shape}, "
                    f"operand.shape={operand.shape}.")
  max_ok = (not max.shape) or max.shape == operand.shape
  if not max_ok:
    raise TypeError("clamp requires max.shape == operand.shape or max.shape == "
                    f"(), got max.shape={max.shape}, "
                    f"operand.shape={operand.shape}.")
  return operand.shape
|
|
|
|
|
|
|
|
# dtype rule for clamp: naryop_dtype_rule specialised for three operands
# (min, operand, max), each accepting any dtype.
_clamp_dtype_rule = partial(
    naryop_dtype_rule, _input_dtype, [_any] * 3, 'clamp')
|
|
|
|
|
2021-06-20 19:09:13 -07:00
|
|
|
def _clamp_batch_rule(batched_args, batch_dims, **params):
  """Batching rule for clamp.

  A few cheap cases bind ``clamp_p`` directly and keep the operand's batch
  axis. Otherwise every non-scalar argument gets its batch axis moved to the
  front and the bounds/operand are broadcast to a common shape, producing a
  result batched along axis 0.

  Returns:
    ``(result, result_batch_dim)``.
  """
  lo, x, hi = batched_args
  lo_bdim, x_bdim, hi_bdim = batch_dims
  size = next(arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims)
              if bdim is not None)

  # avoid transposes and some broadcasts in special cases
  if lo_bdim == x_bdim == hi_bdim:
    if (np.shape(lo) == np.shape(x) == np.shape(hi)
        or np.ndim(lo) == np.ndim(hi) == 0):
      return clamp_p.bind(lo, x, hi), x_bdim
    if np.ndim(lo) == np.ndim(hi) == 1:
      lo = broadcast_in_dim(lo, x.shape, [lo_bdim])
      hi = broadcast_in_dim(hi, x.shape, [hi_bdim])
      return clamp_p.bind(lo, x, hi), x_bdim
  elif x_bdim is not None and np.ndim(lo) == 0 and np.ndim(hi) == 0:
    return clamp_p.bind(lo, x, hi), x_bdim

  # General case: put the batch axis first on everything that has a shape.
  lo = batching.bdim_at_front(lo, lo_bdim, size) if np.shape(lo) else lo
  hi = batching.bdim_at_front(hi, hi_bdim, size) if np.shape(hi) else hi
  x = batching.bdim_at_front(x, x_bdim, size) if np.shape(x) else x

  # Scalar bounds are expanded to the operand's shape.
  if np.ndim(lo) == 0 and np.ndim(x) > 0:
    lo = broadcast(lo, x.shape)
  if np.ndim(hi) == 0 and np.ndim(x) > 0:
    hi = broadcast(hi, x.shape)

  # Bounds of lower rank than the operand must be 1-d here; they are
  # broadcast along the leading axis.
  if 0 < np.ndim(lo) < np.ndim(x):
    assert np.ndim(lo) == 1, np.ndim(lo)
    lo = broadcast_in_dim(lo, x.shape, [0])
  if 0 < np.ndim(hi) < np.ndim(x):
    assert np.ndim(hi) == 1, np.ndim(hi)
    hi = broadcast_in_dim(hi, x.shape, [0])

  # A scalar operand with a higher-rank lower bound is expanded instead.
  if np.ndim(lo) > np.ndim(x):
    assert np.ndim(x) == 0, np.ndim(x)
    x = broadcast(x, lo.shape)
  return clamp_p.bind(lo, x, hi), 0
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
# The clamp primitive and its rule registrations.
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
# One JVP rule per argument, in the order (min, operand, max). Each passes
# the tangent g through where its condition holds and yields zeros elsewhere.
ad.defjvp(
    clamp_p,
    lambda g, min, operand, max:
        select(bitwise_and(gt(min, operand), lt(min, max)),
               g, _zeros(operand)),
    lambda g, min, operand, max:
        select(bitwise_and(gt(operand, min), lt(operand, max)),
               g, _zeros(operand)),
    lambda g, min, operand, max:
        select(lt(max, operand), g, _zeros(operand)))
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
mlir.register_lowering(
    clamp_p,
    partial(_nary_lower_mhlo, mhlo.ClampOp, explicit_type=True))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _concatenate_shape_rule(*operands, **kwargs):
  """Validate concatenate operands and return the result shape.

  The result has the shape of the first operand, except along ``dimension``
  where the size is the sum of all operand sizes.

  Raises:
    TypeError: if there are no operands, an operand is not an
      ``UnshapedArray``, the ranks differ, ``dimension`` is out of bounds, or
      the shapes differ outside of ``dimension``.
  """
  dimension = kwargs.pop('dimension')
  if not operands:
    msg = "concatenate expects at least one operand, got 0."
    raise TypeError(msg)

  for candidate in operands:
    if not isinstance(candidate, UnshapedArray):
      msg = "All objects to concatenate must be arrays, got {}."
      raise TypeError(msg.format(type(candidate)))

  ranks = {operand.ndim for operand in operands}
  if len(ranks) != 1:
    msg = "Cannot concatenate arrays with different numbers of dimensions: got {}."
    raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))

  if not 0 <= dimension < operands[0].ndim:
    msg = "concatenate dimension out of bounds: dimension {} for shapes {}."
    raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands])))

  # Shapes with the concatenation axis removed must all agree; comparing the
  # list against itself shifted by one checks every adjacent pair.
  reduced_shapes = [operand.shape[:dimension] + operand.shape[dimension+1:]
                    for operand in operands]
  if not reduced_shapes[:-1] == reduced_shapes[1:]:
    msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
           "other than the one being concatenated: concatenating along "
           "dimension {} for shapes {}.")
    full_shapes = [operand.shape for operand in operands]
    raise TypeError(msg.format(dimension, ", ".join(map(str, full_shapes))))

  concat_size = sum(o.shape[dimension] for o in operands)
  first_shape = operands[0].shape
  return first_shape[:dimension] + (concat_size,) + first_shape[dimension+1:]
|
|
|
|
|
|
|
|
def _concatenate_dtype_rule(*operands, **kwargs):
  """Return the result dtype of concatenate: the first operand's dtype.

  The operand dtypes are first passed to ``_check_same_dtypes``.
  """
  operand_dtypes = (o.dtype for o in operands)
  _check_same_dtypes('concatenate', False, *operand_dtypes)
  first = operands[0]
  return first.dtype
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _concatenate_translation_rule(ctx, avals_in, avals_out, *operands,
                                  dimension, **kw):
  """XLA translation rule for concatenate: a single ``xops.ConcatInDim``."""
  concatenated = xops.ConcatInDim(ctx.builder, operands, dimension)
  return [concatenated]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _concatenate_transpose_rule(t, *operands, dimension):
  """Transpose rule for concatenate.

  Splits ``t`` back into one slice per operand along ``dimension``. Operands
  that are not undefined primals get ``None``.
  """
  operand_shapes = [o.aval.shape if ad.is_undefined_primal(o) else o.shape
                    for o in operands]
  if type(t) is ad_util.Zero:
    return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
            for o in operands]

  n_operands = len(operands)
  # Cumulative sizes along the concatenation axis give each slice's end;
  # each slice starts where the previous one ended.
  ends_along_dim = np.cumsum([shape[dimension] for shape in operand_shapes])
  starts = np.zeros((n_operands, t.ndim), dtype=int)
  starts[1:, dimension] = ends_along_dim[:-1]
  limits = np.tile(t.shape, (n_operands, 1))
  limits[:, dimension] = ends_along_dim

  return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o)
          else None for o, start, limit in zip(operands, starts, limits)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
  """Batching rule for concatenate.

  Batched operands get their batch axis moved to the front; unbatched
  operands are broadcast with a new leading axis of the batch size. The
  concatenation then runs along ``dimension + 1`` and the result is batched
  along axis 0.
  """
  size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
              if bdim is not None)

  def batch_first(op, bdim):
    if bdim is None:
      return broadcast(op, (size,))
    return batching.moveaxis(op, bdim, 0)

  operands = [batch_first(op, bdim)
              for op, bdim in zip(batched_args, batch_dims)]
  return concatenate(operands, dimension + 1), 0
|
|
|
|
|
|
|
|
# The concatenate_p masking rule requires use of a while-loop construct and so
# is defined in lax_control_flow.py

# The concatenate primitive and its rule registrations.
concatenate_p = standard_primitive(
    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
    translation_rule=_concatenate_translation_rule)
ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _concatenate_lower(ctx, *xs, dimension):
  """MLIR (MHLO) lowering rule for concatenate."""
  concat_op = mhlo.ConcatenateOp(xs, mlir.i64_attr(dimension))
  return concat_op.results

mlir.register_lowering(concatenate_p, _concatenate_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _pad_dtype_rule(operand, padding_value, *, padding_config):
  """Return the result dtype of pad.

  Raises:
    TypeError: if ``operand`` and ``padding_value`` have different dtypes.
  """
  if operand.dtype == padding_value.dtype:
    return _input_dtype(operand, padding_value)

  msg = "pad operand and padding_value must be same dtype: got {} and {}."
  raise TypeError(msg.format(operand.dtype, padding_value.dtype))
|
|
|
|
|
|
|
|
def _pad_shape_rule(operand, padding_value, *, padding_config):
  """Validate ``padding_config`` and return the padded shape.

  Each entry of ``padding_config`` is a ``(low, high, interior)`` triple for
  one operand axis. The resulting size of an axis is computed as
  ``core.sum_dim(low, high, core.dilate_dim(size, interior + 1))``.

  Raises:
    ValueError: if the number of triples differs from the operand rank, if
      any interior padding is negative, or if a resulting size is not at
      least 0.
  """
  del padding_value
  op_shape = np.shape(operand)
  if len(padding_config) != np.ndim(operand):
    raise ValueError("length of padding_config must equal the number of axes "
                     f"of operand, got padding_config {padding_config} "
                     f"for operand shape {op_shape}")

  interiors_ok = all(i >= 0 for _, _, i in padding_config)
  if not interiors_ok:
    raise ValueError("interior padding in padding_config must be nonnegative, "
                     f"got padding_config {padding_config}")

  result = tuple(core.sum_dim(l, h, core.dilate_dim(d, i + 1))
                 for (l, h, i), d in zip(padding_config, op_shape))
  if all(core.greater_equal_dim(d, 0) for d in result):
    return result

  # Negative low/high padding removed more than the axis holds.
  msg = (f"Dimension size after padding is not at least 0, "
         f"got result shape {result}, for padding_config {padding_config}"
         f" and operand shape {op_shape}")
  raise ValueError(msg)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _pad_transpose(t, operand, padding_value, *, padding_config):
  """Transpose rule for pad.

  Args:
    t: the value being transposed, or an ``ad_util.Zero``.
    operand: the padded operand (possibly an undefined primal).
    padding_value: the padding scalar (possibly an undefined primal).
    padding_config: sequence of ``(low, high, interior)`` triples.

  Returns:
    ``[t_operand, t_padv]``; an entry is ``None`` when the corresponding
    argument is not an undefined primal.
  """
  wants_operand = ad.is_undefined_primal(operand)
  wants_padv = ad.is_undefined_primal(padding_value)

  if type(t) is ad_util.Zero:
    t_operand = ad_util.Zero(operand.aval) if wants_operand else None
    t_padv = ad_util.Zero(padding_value.aval) if wants_padv else None
    return [t_operand, t_padv]

  # zip(*[]) yields nothing to unpack, so an empty padding_config (rank-0
  # operand) is handled explicitly.
  if padding_config:
    lo, hi, interior = zip(*padding_config)
  else:
    lo, hi, interior = (), (), ()
  total = lambda x: _reduce_sum(x, list(range(t.ndim)))

  def t_op():
    # Undo the low/high padding with negative padding, then undo the interior
    # padding by slicing with a stride of interior + 1.
    unpad_config = safe_zip(np.negative(lo), np.negative(hi),
                            np.zeros_like(interior))
    unpadded = pad(t, np.array(0., t.dtype), unpad_config)
    return slicing.slice(unpadded, np.zeros_like(lo), unpadded.shape,
                         np.add(interior, 1))

  # The operand-position part of t is needed for either result: it is the
  # operand's transpose, and it is subtracted from the total to obtain the
  # padding value's transpose. Previously it was only computed when the
  # operand was an undefined primal, so total(None) was evaluated when only
  # padding_value was.
  t_at_operand = t_op() if (wants_operand or wants_padv) else None
  t_operand = t_at_operand if wants_operand else None
  t_padv = sub(total(t), total(t_at_operand)) if wants_padv else None
  return [t_operand, t_padv]
|
|
|
|
|
|
|
|
def _pad_batch_rule(batched_args, batch_dims, *, padding_config):
  """Batching rule for pad.

  The batch axis of the operand receives a ``(0, 0, 0)`` padding entry. When
  the padding value is batched too, the operand is padded with zeros and the
  per-example padding values are filled in with a mask and ``select``.

  Returns:
    ``(result, operand_bdim)``.
  """
  operand, padding_value = batched_args
  operand_bdim, padding_value_bdim = batch_dims
  if operand_bdim is None:
    # Only the padding value is batched: give the operand a leading batch
    # axis of matching size.
    batch_size = padding_value.shape[padding_value_bdim]
    operand = broadcast(operand, (batch_size,))
    operand_bdim = 0

  # The batch axis itself is never padded.
  padding_config = list(padding_config)
  padding_config.insert(operand_bdim, (0, 0, 0))
  if padding_value_bdim is None:
    return pad(operand, padding_value, padding_config), operand_bdim

  assert padding_value_bdim == 0, padding_value_bdim

  zero_padded = pad(operand, _zero(operand), padding_config)
  # True where the output holds operand data, False where it holds padding.
  is_data = pad(full_like(operand, True, np.bool_), False, padding_config)
  per_example_padding = broadcast_in_dim(padding_value, zero_padded.shape,
                                         (operand_bdim,))
  return select(is_data, zero_padded, per_example_padding), operand_bdim
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _pad_translation_rule(ctx, avals_in, avals_out, operand, padding_value, *,
                          padding_config):
  """XLA translation rule for ``pad``: emits a single ``xops.Pad`` op."""
  xla_config = xc.make_padding_config(padding_config)
  padded = xops.Pad(operand, padding_value, xla_config)
  return [padded]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _pad_masking_rule(padded_vals, logical_shapes, padding_config):
  """Masking rule for ``pad``.

  Pads the (already padded) operand, then hands the result to ``_masked``
  together with the logical output shape and the list of axes whose padding
  entry is not ``(0, 0, 0)``.
  """
  operand, padding_value = padded_vals
  shape, _ = logical_shapes

  out = pad(operand, padding_value, padding_config)

  logical_out_shape = []
  touched_axes = []
  for axis, entry in enumerate(padding_config):
    lo, _hi, interior = entry
    # The high padding is not counted in the logical extent computed here.
    logical_out_shape.append(lo + shape[axis] * (interior + 1))
    if entry != (0, 0, 0):
      touched_axes.append(axis)
  return _masked(out, logical_out_shape, touched_axes, padding_value)
|
|
|
|
|
|
|
|
# The `pad` primitive and the rules attached to it.
pad_p = standard_primitive(
    _pad_shape_rule,
    _pad_dtype_rule,
    'pad',
    translation_rule=_pad_translation_rule,
)
ad.deflinear2(pad_p, _pad_transpose)
batching.primitive_batchers[pad_p] = _pad_batch_rule
masking.masking_rules[pad_p] = _pad_masking_rule
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _pad_lower(ctx, x, padding_value, *, padding_config):
  """MLIR lowering for ``pad``: builds an ``mhlo.PadOp`` and returns its results."""
  (aval_out,) = ctx.avals_out
  # Split the per-axis (lo, hi, interior) triples into three parallel sequences.
  low, high, interior = util.unzip3(padding_config)
  result_type = mlir.aval_to_ir_type(aval_out)
  low_attr = mlir.dense_int_elements(low)
  high_attr = mlir.dense_int_elements(high)
  interior_attr = mlir.dense_int_elements(interior)
  pad_op = mhlo.PadOp(result_type, x, padding_value, low_attr, high_attr,
                      interior_attr)
  return pad_op.results

mlir.register_lowering(pad_p, _pad_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# The squeeze primitive exists for the benefit of masking and other
|
|
|
|
# transformations that need to keep track of axis identity.
|
|
|
|
# For example, consider reshaping a 2D array with shape (1, N) into a 1D array
|
|
|
|
# with shape (N,). This results in the following JAXpr:
|
|
|
|
# reshape[ dimension=None new_sizes=(N,) ]
|
|
|
|
# For N > 1, we can match up the output array axis with the second axis of the
|
|
|
|
# input. But for N = 1, it is not clear how axes match up: all we know from the
|
|
|
|
# JAXpr is that we are reshaping from (1, 1) to (1,).
|
|
|
|
# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.
|
|
|
|
|
|
|
|
|
|
|
|
def _squeeze_dtype_rule(operand, *, dimensions):
|
|
|
|
return operand.dtype
|
|
|
|
|
|
|
|
def _squeeze_shape_rule(operand, *, dimensions):
  """Shape rule for ``squeeze``; delegates to ``_compute_squeeze_shape``."""
  operand_shape = np.shape(operand)
  return _compute_squeeze_shape(operand_shape, dimensions)
|
|
|
|
|
|
|
|
def _compute_squeeze_shape(shape, dimensions):
|
|
|
|
dims_set = set(dimensions)
|
|
|
|
if len(dims_set) != len(dimensions):
|
|
|
|
raise ValueError(f"dimensions are not unique: {dimensions}")
|
|
|
|
if not all(0 <= d < len(shape) for d in dims_set):
|
|
|
|
raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")
|
[jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials.
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multi-variate dimension polynomials. These polynomials overload addition, subtraction,
multiplication. They also partially support equality and inequality checking.
Equality and inequality are supported only when the operation result is the
same for all valuations of variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, for
the following a `core.InconclusiveDimensionOperation` is raised: `a = b`, `a
>= 2`.
Division is supported only in the cases when either there is no remainder,
or the divisor is a constant.
This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:
```
y = x.reshape((2, -1))
z = ... y ...
return z.reshape(x.shape)
```
2021-05-20 14:07:52 +03:00
|
|
|
if any(not core.symbolic_equal_dim(shape[d], 1) for d in dimensions):
|
2020-10-17 14:33:26 -04:00
|
|
|
raise ValueError(
|
|
|
|
"cannot select an axis to squeeze out which has size not equal to "
|
|
|
|
f"one, got shape={shape} and dimensions={dimensions}")
|
|
|
|
return tuple(s for i, s in enumerate(shape) if i not in dims_set)
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _squeeze_translation_rule(ctx, avals_in, avals_out, arg, *, dimensions):
  """XLA translation rule for ``squeeze``: a reshape to the output aval's shape."""
  target_shape = avals_out[0].shape
  reshaped = xops.Reshape(arg, target_shape)
  return [reshaped]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _squeeze_transpose_rule(t, operand, *, dimensions):
  """Transpose rule for ``squeeze``: re-inserts the squeezed axes into ``t``."""
  assert ad.is_undefined_primal(operand)
  operand_cotangent = expand_dims(t, dimensions)
  return [operand_cotangent]
|
|
|
|
|
|
|
|
def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for ``squeeze``.

  Moves the batch axis to the front, shifts every squeezed axis index up by
  one to account for it, and reports the batch axis of the result as 0.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  batch_leading = batching.moveaxis(operand, bdim, 0)
  shifted_dimensions = tuple(np.add(1, dimensions))
  return squeeze(batch_leading, dimensions=shifted_dimensions), 0
|
|
|
|
|
|
|
|
# The `squeeze` primitive and the rules attached to it.
squeeze_p = standard_primitive(
    _squeeze_shape_rule,
    _squeeze_dtype_rule,
    'squeeze',
    _squeeze_translation_rule,
)
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _squeeze_lower(ctx, operand, *, dimensions):
  """MLIR lowering for ``squeeze``: an ``mhlo.ReshapeOp`` to the output type."""
  del dimensions  # Implied by the output aval.
  (aval_out,) = ctx.avals_out
  result_type = mlir.aval_to_ir_type(aval_out)
  reshape_op = mhlo.ReshapeOp(result_type, operand)
  return reshape_op.results

mlir.register_lowering(squeeze_p, _squeeze_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
2021-11-23 16:34:33 -08:00
|
|
|
def _shape_as_value(shape):
  """Converts a shape that may contain Poly values into a JAX value.

  An empty shape yields the result of ``full((0,), <int64 zero>)``; otherwise
  each dimension is converted with ``core.dimension_as_value``, cast to int64,
  given a leading axis, and the pieces are concatenated along axis 0.
  """
  if not len(shape):
    return full((0,), np.array(0, np.int64))

  pieces = []
  for d in shape:
    dim_value = convert_element_type(core.dimension_as_value(d), np.int64)
    pieces.append(expand_dims(dim_value, (0,)))
  return concatenate(pieces, dimension=0)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _is_singleton_reshape(old, new):
|
|
|
|
# A singleton reshape is one where only singleton dimensions are added. We
|
|
|
|
# want to detect them because they can be expressed as (lazy) broadcasts.
|
|
|
|
old, new = iter(old), iter(new)
|
|
|
|
d1, d2 = next(old, None), next(new, None)
|
|
|
|
bcast_dims = []
|
|
|
|
i = 0
|
|
|
|
while True:
|
|
|
|
if d1 is d2 is None:
|
|
|
|
return bcast_dims
|
|
|
|
elif d1 == d2:
|
|
|
|
bcast_dims.append(i)
|
|
|
|
i += 1
|
|
|
|
d1, d2 = next(old, None), next(new, None)
|
|
|
|
elif d2 == 1:
|
|
|
|
i += 1
|
|
|
|
d2 = next(new, None)
|
|
|
|
else:
|
|
|
|
return None
|
|
|
|
|
|
|
|
def _reshape_shape_rule(operand, *, new_sizes, dimensions):
  """Shape rule for ``reshape``.

  Returns:
    ``new_sizes`` as a tuple.

  Raises:
    TypeError: if a size in ``new_sizes`` fails the ``>= 0`` check, if the
      total size differs from the operand's, or if ``dimensions`` is given
      and is not a permutation of the operand's axes.
  """
  if any(not core.greater_equal_dim(d, 0) for d in new_sizes):
    raise TypeError(
        'reshape new_sizes must all be positive, got {}.'.format(new_sizes))

  operand_shape = np.shape(operand)
  if not core.same_shape_sizes(operand_shape, new_sizes):
    raise TypeError(
        'reshape total size must be unchanged, got new_sizes {} for shape {}.'
        .format(new_sizes, operand_shape))

  if (dimensions is not None
      and set(dimensions) != set(range(np.ndim(operand)))):
    raise TypeError(
        ('reshape dimensions must be a permutation of operand dimensions, '
         'got dimensions {} for shape {}.').format(dimensions,
                                                   np.shape(operand)))
  return tuple(new_sizes)
|
|
|
|
|
|
|
|
def _reshape_dtype_rule(operand, *, new_sizes, dimensions):
|
|
|
|
return operand.dtype
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _reshape_translation_rule(ctx, avals_in, avals_out, operand, *, new_sizes,
                              dimensions):
  """XLA translation rule for ``reshape``.

  ``dimensions`` is passed to ``xops.Reshape`` only when it is not ``None``.
  """
  if dimensions is None:
    reshape_args = (new_sizes,)
  else:
    reshape_args = (dimensions, new_sizes)
  return [xops.Reshape(operand, *reshape_args)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions):
  """Transpose rule for ``reshape``.

  Reshapes the cotangent back to the operand's shape.  When ``dimensions``
  is given, the cotangent is first reshaped to the operand shape permuted by
  ``dimensions`` and then transposed with the inverse permutation.
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  if dimensions is None:
    return [reshape(t, operand_shape)]

  permuted_shape = np.take(operand_shape, dimensions)
  permuted_ct = reshape(t, permuted_shape)
  inverse_permutation = np.argsort(dimensions)
  return [transpose(permuted_ct, inverse_permutation)]
|
|
|
|
|
|
|
|
def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
  """Batching rule for ``reshape``.

  Moves the batch axis to the front and keeps it as the leading axis of the
  new shape.  If ``dimensions`` is given, it is shifted by one and prefixed
  with 0.  The result's batch axis is 0.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  operand = batching.moveaxis(operand, bdim, 0)
  if dimensions is not None:
    shifted = tuple(np.add(1, dimensions))
    dimensions = (0,) + shifted
  batched_sizes = operand.shape[:1] + new_sizes
  return reshape(operand, batched_sizes, dimensions), 0
|
|
|
|
|
|
|
|
def _reshape_masking_rule(padded_args, logical_shapes, polymorphic_shapes,
                          new_sizes, dimensions):
  """Masking rule for ``reshape``.

  Raises:
    NotImplementedError: if the old and new shapes disagree once the sizes
      between polymorphic dimensions have been multiplied together.
  """
  (operand,) = padded_args
  (old_shape,) = polymorphic_shapes

  def _is_polymorphic(size):
    return type(size) is masking.Poly and not size.is_constant

  def _merged_sizes(shape):
    """Products of the segments starting at index 0 and at each polymorphic size."""
    poly_positions = [i for i, size in enumerate(shape)
                      if _is_polymorphic(size)]
    starts = [0] + poly_positions
    stops = poly_positions + [len(shape)]
    merged = []
    for start, stop in zip(starts, stops):
      merged.append(prod(shape[start:stop]))
    return merged

  if _merged_sizes(old_shape) != _merged_sizes(new_sizes):
    raise NotImplementedError(
        "Reshape on padded dimensions causing fragmentation is not supported.")

  padded_new_sizes = masking.padded_shape_as_value(new_sizes)
  return reshape(operand, new_sizes=padded_new_sizes, dimensions=dimensions)
|
|
|
|
|
|
|
|
# The `reshape` primitive and the rules attached to it.
reshape_p = standard_primitive(
    _reshape_shape_rule,
    _reshape_dtype_rule,
    'reshape',
    _reshape_translation_rule,
)
ad.deflinear2(reshape_p, _reshape_transpose_rule)
batching.primitive_batchers[reshape_p] = _reshape_batch_rule
masking.masking_rules[reshape_p] = _reshape_masking_rule
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _reshape_lower(ctx, x, *, new_sizes, dimensions):
  """MLIR lowering for ``reshape``.

  When ``dimensions`` is given, an ``mhlo.TransposeOp`` is emitted first.
  The final op is always an ``mhlo.ReshapeOp`` to the output type.
  """
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out
  if dimensions is not None:
    transposed_aval = core.ShapedArray(np.take(aval_in.shape, dimensions),
                                       aval_in.dtype)
    # Below extension version 49 the TransposeOp builder also takes the
    # result type as its first argument.
    if jax._src.lib._xla_extension_version < 49:
      transpose_args = (mlir.aval_to_ir_type(transposed_aval), x,
                        mlir.dense_int_elements(dimensions))
    else:
      transpose_args = (x, mlir.dense_int_elements(dimensions))
    x = mhlo.TransposeOp(*transpose_args).result
  result_type = mlir.aval_to_ir_type(aval_out)
  return mhlo.ReshapeOp(result_type, x).results

mlir.register_lowering(reshape_p, _reshape_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _rev_shape_rule(operand, *, dimensions):
  """Shape rule for ``rev``: validates ``dimensions`` and returns the operand shape.

  Raises:
    TypeError: if ``dimensions`` contains duplicates or an index that is not
      less than ``operand.ndim``.
  """
  _check_shapelike('rev', 'dimensions', dimensions)

  has_duplicates = len(set(dimensions)) != len(dimensions)
  if has_duplicates:
    raise TypeError('rev dimensions must be unique, got {}.'.format(dimensions))

  # The emptiness check comes first so `_max` is never called on no dimensions.
  if dimensions and not _max(dimensions) < operand.ndim:
    template = ('rev dimensions must all be less than operand ndim, got dimensions '
                '{} for operand ndim {}.')
    raise TypeError(template.format(dimensions, operand.ndim))
  return operand.shape
|
|
|
|
|
|
|
|
def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for ``rev``.

  The batch axis stays where it is; reversed axes at or after it are shifted
  up by one.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  shifted_dimensions = []
  for axis in dimensions:
    shifted_dimensions.append(axis + 1 if axis >= bdim else axis)
  return rev(operand, shifted_dimensions), bdim
|
|
|
|
|
2020-12-15 10:15:49 -08:00
|
|
|
# The `rev` primitive and the rules attached to it.
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
ad.deflinear2(
    rev_p,
    # The transpose applies the same reversal to the cotangent.
    lambda t, _, dimensions: [rev(t, dimensions)],
)
batching.primitive_batchers[rev_p] = _rev_batch_rule
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _rev_lower(ctx, x, *, dimensions):
  """MLIR lowering for ``rev``: a single ``mhlo.ReverseOp``."""
  dimensions_attr = mlir.dense_int_elements(dimensions)
  reverse_op = mhlo.ReverseOp(x, dimensions_attr)
  return reverse_op.results

mlir.register_lowering(rev_p, _rev_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _transpose_shape_rule(operand, *, permutation):
|
|
|
|
if not isinstance(permutation, (tuple, list, np.ndarray)):
|
|
|
|
msg = "transpose permutation must be a tuple/list/ndarray, got {}."
|
|
|
|
raise TypeError(msg.format(type(permutation)))
|
|
|
|
if tuple(sorted(permutation)) != tuple(range(operand.ndim)):
|
|
|
|
msg = ("transpose permutation isn't a permutation of operand dimensions, "
|
|
|
|
"got permutation {} for operand shape {}.")
|
|
|
|
raise TypeError(msg.format(permutation, operand.shape))
|
|
|
|
return tuple(np.take(operand.shape, permutation))
|
|
|
|
|
|
|
|
def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
  """Batching rule for ``transpose``.

  Builds a permutation that puts the batch axis first and shifts every
  original axis at or after it up by one.  The result's batch axis is 0.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  shifted_axes = [axis if axis < bdim else axis + 1 for axis in permutation]
  batched_permutation = (bdim,) + tuple(shifted_axes)
  return transpose(operand, batched_permutation), 0
|
|
|
|
|
|
|
|
def _transpose_masking_rule(padded_vals, logical_shapes, permutation):
  """Masking rule for ``transpose``: transposes the padded values directly."""
  del logical_shapes  # Not used by this rule.
  transposed = transpose(*padded_vals, permutation=permutation)
  return transposed
|
|
|
|
|
|
|
|
# The `transpose` primitive and the rules attached to it.
transpose_p = standard_primitive(
    _transpose_shape_rule,
    _input_dtype,
    'transpose',
)
ad.deflinear2(
    transpose_p,
    # The transpose applies the inverse permutation to the cotangent.
    lambda t, _, permutation: [transpose(t, np.argsort(permutation))])  # type: ignore[arg-type]
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _transpose_lower(ctx, x, *, permutation):
  """MLIR lowering for ``transpose``: a single ``mhlo.TransposeOp``."""
  (aval_out,) = ctx.avals_out
  # Below extension version 49 the TransposeOp builder also takes the result
  # type as its first argument.
  if jax._src.lib._xla_extension_version < 49:
    op_args = (mlir.aval_to_ir_type(aval_out), x,
               mlir.dense_int_elements(permutation))
  else:
    op_args = (x, mlir.dense_int_elements(permutation))
  return mhlo.TransposeOp(*op_args).results

mlir.register_lowering(transpose_p, _transpose_lower)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _select_shape_rule(pred, on_true, on_false):
|
|
|
|
if on_true.shape != on_false.shape:
|
|
|
|
msg = "select on_true and on_false must have the same shape, got {} and {}."
|
|
|
|
raise TypeError(msg.format(on_true.shape, on_false.shape))
|
|
|
|
if pred.shape and pred.shape != on_true.shape:
|
|
|
|
msg = ("select pred must be scalar or have the same shape as on_true and "
|
|
|
|
"on_false, got pred shape {} for on_true and on_false of shape {}.")
|
|
|
|
raise TypeError(msg.format(pred.shape, on_true.shape))
|
|
|
|
return on_true.shape
|
|
|
|
|
|
|
|
def _select_dtype_rule(pred, on_true, on_false):
  """Dtype rule for ``select``.

  Returns:
    The dtype of ``on_true``.

  Raises:
    TypeError: if ``pred`` does not have a boolean dtype.  The two branch
      dtypes are checked by ``_check_same_dtypes``.
  """
  _check_same_dtypes("select", False, on_true.dtype, on_false.dtype)
  pred_is_bool = dtypes.issubdtype(pred.dtype, np.bool_)
  if not pred_is_bool:
    raise TypeError(
        "select pred must be boolean type, got {}.".format(pred.dtype))
  return on_true.dtype
|
|
|
|
|
|
|
|
def _select_transpose_rule(t, pred, on_true, on_false):
  """Transpose rule for ``select``.

  Returns:
    ``[None, ct_on_true, ct_on_false]``.  A branch cotangent is ``None``
    unless that branch is an undefined primal.  For a symbolic-zero ``t`` the
    cotangents are symbolic zeros; otherwise ``t`` is routed to the branch
    chosen by ``pred`` and zeros go to the other.
  """
  assert not ad.is_undefined_primal(pred)
  wants_true = ad.is_undefined_primal(on_true)
  wants_false = ad.is_undefined_primal(on_false)

  ct_true = ct_false = None
  if type(t) is ad_util.Zero:
    if wants_true:
      ct_true = ad_util.Zero(on_true.aval)
    if wants_false:
      ct_false = ad_util.Zero(on_false.aval)
  else:
    zeros = full_like(t, 0)
    if wants_true:
      ct_true = select(pred, t, zeros)
    if wants_false:
      ct_false = select(pred, zeros, t)
  return [None, ct_true, ct_false]
|
|
|
|
|
|
|
|
def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
  """Batching rule for ``select``.

  Args:
    batched_args: ``(pred, on_true, on_false)``.
    batch_dims: batch axis of each argument, or ``None`` when unbatched.

  Returns:
    A ``(result, batch_axis)`` pair.
  """
  pred, on_true, on_false = batched_args
  pred_bdim, ot_bdim, of_bdim = batch_dims
  # Batch size, read from the first argument that has a batch axis.
  size = next(arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims)
              if bdim is not None)

  # avoid transposes and some broadcasts in special cases
  if pred_bdim == ot_bdim == of_bdim:
    if np.shape(pred) != np.shape(on_true):
      # vmapped function had a scalar pred with nonscalar args
      assert np.ndim(pred) == 1
      pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim])
    return select(pred, on_true, on_false), pred_bdim

  if np.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None:
    if ot_bdim == of_bdim:
      return select(pred, on_true, on_false), ot_bdim
    if np.shape(on_true) == np.shape(on_false):
      on_false = batching.moveaxis(on_false, of_bdim, ot_bdim)
      return select(pred, on_true, on_false), ot_bdim

  # General case: move every batch axis to the front.
  if np.shape(pred):
    pred = batching.bdim_at_front(pred, pred_bdim, size)
  branches_are_scalar = () == np.shape(on_true) == np.shape(on_false)
  if not branches_are_scalar:
    on_true = batching.bdim_at_front(on_true, ot_bdim, size)
    on_false = batching.bdim_at_front(on_false, of_bdim, size)
  assert np.shape(on_true) == np.shape(on_false)

  if 0 < np.ndim(pred) < np.ndim(on_true):
    # vmapped function had a scalar pred with nonscalar args
    assert np.ndim(pred) == 1
    pred = broadcast_in_dim(pred, on_true.shape, [0])
  if np.ndim(pred) > np.ndim(on_true):
    # Scalar branches with a batched predicate: broadcast the branches.
    assert np.ndim(on_true) == 0
    on_true = broadcast(on_true, pred.shape)
    on_false = broadcast(on_false, pred.shape)
  return select(pred, on_true, on_false), 0
|
|
|
|
|
|
|
|
def _select_masking_rule(padded_vals, logical_shapes):
  """Masking rule for ``select``.

  Asserts that all three padded arguments share one padded shape, then
  applies ``select`` to them.
  """
  padded_shapes = [masking.padded_shape_as_value(val.shape)
                   for val in padded_vals]
  pred_shape, true_shape, false_shape = padded_shapes
  assert np.array_equal(pred_shape, true_shape)
  assert np.array_equal(pred_shape, false_shape)
  return select(*padded_vals)
|
|
|
|
|
2020-10-26 15:32:31 -07:00
|
|
|
def _select_jvp(primals, tangents):
  """JVP rule for ``select``.

  The predicate's tangent is ignored.  If both branch tangents are symbolic
  zeros the output tangent is a symbolic zero; otherwise a symbolic-zero
  branch tangent is replaced by ``_zeros`` of the other branch's tangent and
  the tangents are selected with the same predicate.

  Returns:
    A ``(primal_out, tangent_out)`` pair.
  """
  pred, on_true, on_false = primals
  _, on_true_dot, on_false_dot = tangents
  out = select(pred, on_true, on_false)

  true_is_zero = type(on_true_dot) is ad_util.Zero
  false_is_zero = type(on_false_dot) is ad_util.Zero
  if true_is_zero and false_is_zero:
    out_dot = ad_util.Zero(on_true_dot.aval)
  else:
    true_tangent = _zeros(on_false_dot) if true_is_zero else on_true_dot
    false_tangent = _zeros(on_true_dot) if false_is_zero else on_false_dot
    out_dot = select(pred, true_tangent, false_tangent)
  return out, out_dot
|
|
|
|
|
2021-02-08 13:37:25 -08:00
|
|
|
# The `select` primitive and the rules attached to it.
select_p = standard_primitive(
    _select_shape_rule,
    _select_dtype_rule,
    'select',
    weak_type_rule=_argnum_weak_type(1, 2),
)
ad.primitive_jvps[select_p] = _select_jvp
ad.primitive_transposes[select_p] = _select_transpose_rule
batching.primitive_batchers[select_p] = _select_batch_rule
masking.masking_rules[select_p] = _select_masking_rule
mlir.register_lowering(select_p, partial(_nary_lower_mhlo, mhlo.SelectOp))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
2021-02-12 10:30:46 -08:00
|
|
|
def _reduce_shape_rule(*avals, computation, jaxpr, consts, dimensions):
  """Shape rule for ``reduce``.

  ``avals`` holds the operands in its first half and the initial values in
  its second half.

  Returns:
    One shape per operand, with the axes in ``dimensions`` removed.

  Raises:
    ValueError: if any initial value is not a scalar.
  """
  operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
  all_inits_scalar = all(arg.shape == () for arg in init_val_avals)
  if not all_inits_scalar:
    init_val_shapes = [a.shape for a in init_val_avals]
    raise ValueError(f'reduce found non-scalar initial value: {init_val_shapes}')

  out_shapes = []
  for op in operand_avals:
    out_shapes.append(tuple(np.delete(op.shape, dimensions)))
  return out_shapes
|
|
|
|
|
|
|
|
def _reduce_dtype_rule(*avals, computation, jaxpr, consts, dimensions):
  """Dtype rule for ``reduce``.

  Returns:
    The canonicalized dtype of each operand.

  Raises:
    TypeError: if the canonicalized operand dtypes differ from the
      canonicalized dtypes of the corresponding initial values.
  """
  operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])

  def _canonical_dtypes(group):
    return [dtypes.canonicalize_dtype(aval.dtype) for aval in group]

  operand_dtypes = _canonical_dtypes(operand_avals)
  init_val_dtypes = _canonical_dtypes(init_val_avals)
  if operand_dtypes != init_val_dtypes:
    raise TypeError(
        "reduce operand dtypes should match corresponding initial value dtypes, "
        f"got operands={operand_avals} and initial_values={init_val_avals}")
  return operand_dtypes
|
2020-11-10 15:57:19 -08:00
|
|
|
|
2021-02-12 10:30:46 -08:00
|
|
|
def _reduce_weak_type_rule(*avals, computation, jaxpr, consts, dimensions):
  """Weak-type rule for ``reduce``.

  Each output combines the ``weak_type`` of an operand with that of its
  initial value using ``and``.
  """
  operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
  weak_types = []
  for op, init_val in safe_zip(operand_avals, init_val_avals):
    weak_types.append(op.weak_type and init_val.weak_type)
  return weak_types
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _reduce_translation_rule(ctx, avals_in, avals_out, *values, computation,
                             jaxpr, consts, dimensions):
  """XLA translation rule for ``reduce``.

  ``values`` holds the operands in its first half and the initial values in
  its second half.  With one operand the ``xops.Reduce`` result is returned
  directly in a list; otherwise it is passed through ``xla.xla_destructure``.
  """
  builder = ctx.builder
  operands, init_values = split_list(values, [len(values) // 2])
  single_operand = len(operands) == 1

  if single_operand:
    reducer_body = _reduction_computation(ctx, jaxpr, consts, init_values[0])
  else:
    reducer_body = _reduction_computation(ctx, jaxpr, consts, init_values,
                                          singleton=False)

  reduced = xops.Reduce(builder, operands, init_values, reducer_body,
                        dimensions)
  if single_operand:
    return [reduced]
  return xla.xla_destructure(builder, reduced)
|
2020-11-10 15:57:19 -08:00
|
|
|
|
|
|
|
def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
                       consts, dimensions):
  """Batching rule for ``reduce``.

  Only unbatched initial values are supported.  Operand batch axes are moved
  to the front, the reduced axes are shifted up by one, and every output has
  batch axis 0.

  Raises:
    NotImplementedError: if any initial value is batched.
  """
  # TODO(mattjj,frostig): use batch_jaxpr, delete computation (assumes poly??)
  num_operands = len(batched_args) // 2
  operands, init_values = split_list(batched_args, [num_operands])
  operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands])

  inits_unbatched = all(bdim is batching.not_mapped
                        for bdim in init_value_bdims)
  if not inits_unbatched:
    raise NotImplementedError  # loop and stack

  # Batch size, read from the first argument that has a batch axis.
  size = next(arg.shape[ax] for arg, ax in zip(batched_args, batch_dims)
              if ax is not None)
  front_batched = [batching.bdim_at_front(arg, bdim, size)
                   for arg, bdim in zip(operands, operand_bdims)]
  shifted_dimensions = tuple(d + 1 for d in dimensions)
  out_bdims = [0] * num_operands
  outs = reduce_p.bind(*front_batched, *init_values,
                       computation=computation,
                       dimensions=shifted_dimensions,
                       consts=consts,
                       jaxpr=jaxpr)
  return outs, out_bdims
|
|
|
|
|
2021-11-15 18:26:05 -08:00
|
|
|
def _reduction_computation(ctx, jaxpr, consts, init_values, singleton=True):
  """Builds the XLA computation used as the body of a reduction.

  Args:
    ctx: translation context supplying the builder and platform.
    jaxpr: the reduction body.
    consts: constants of ``jaxpr``; asserted to be empty.
    init_values: a single initial value when ``singleton`` is true, otherwise
      a list of them.
    singleton: whether the reduction has one operand.

  Returns:
    The built computation.  It returns the single output when ``singleton``
    is true and a tuple of all outputs otherwise.
  """
  outer_builder = ctx.builder
  platform = ctx.platform
  if singleton:
    init_values = [init_values]
  # The body takes two parameters per initial value, all with the initial
  # values' shapes.
  param_shapes = safe_map(outer_builder.get_shape, init_values + init_values)
  axis_env = xla.AxisEnv(1, (), ())  # no parallel primitives inside reductions
  subc = xc.XlaBuilder("reduction_computation")
  assert len(consts) == 0, "Reduction computations cannot have constants"

  params = []
  for index, shape in enumerate(param_shapes):
    params.append(xla.parameter(subc, index, shape))
  body_ctx = xla.TranslationContext(subc, platform, axis_env, '')
  out_nodes = xla.jaxpr_subcomp(body_ctx, jaxpr, consts, *params)

  root = out_nodes[0] if singleton else xops.Tuple(subc, out_nodes)
  return subc.build(root)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-03-23 10:31:02 -04:00
|
|
|
def _reduce_jvp(reducer, init_values, primals, tangents, axes):
  """JVP of a general reduction, computed by differentiating a tree reduction.

  The reduced axes are moved to the front and flattened into one leading
  axis.  That axis is then halved repeatedly by applying the vmapped
  ``reducer`` to its two halves, and ``api.jvp`` is applied to the procedure.

  Args:
    reducer: callable taking the left operands followed by the right operands.
    init_values: initial values, used to pad odd-length halves and to fill
      the result when the reduced axis is empty.
    primals: the operands.
    tangents: tangents of the operands.
    axes: tuple of axes being reduced.
  """
  input_shape = np.array(primals[0].shape, dtype=np.int_)

  reduced_size = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  permutation = axes + tuple(non_axes)
  new_shape = (reduced_size,) + tuple(input_shape[non_axes])
  primals = tuple(reshape(x, new_shape, permutation) for x in primals)
  tangents = tuple(reshape(t, new_shape, permutation) for t in tangents)

  # One vmap per remaining axis, plus one for the flattened leading axis.
  batched_reducer = reducer
  for _ in range(len(non_axes) + 1):
    batched_reducer = api.vmap(batched_reducer)

  def _reduce_tree(*xs, axis=0):
    """Reduce by repeatedly splitting the array and combining the halves."""
    while xs[0].shape[axis] > 1:
      length = xs[0].shape[axis]
      left_len = (length + 1) // 2
      right_len = length - left_len
      left = [slicing.slice_in_dim(x, 0, left_len) for x in xs]
      right = [slicing.slice_in_dim(x, left_len, None) for x in xs]
      if right_len != left_len:
        # Odd length: extend the shorter half by one initial value.
        paddings = [(0, 0, 0)] * len(xs[0].shape)
        paddings[axis] = (0, 1, 0)
        right = [pad(half, init, paddings)
                 for half, init in zip(right, init_values)]
      xs = batched_reducer(*left, *right)
    if xs[0].shape[axis] == 0:
      # Nothing to reduce: the result is the initial value everywhere.
      return [full(input_shape[non_axes], init) for init in init_values]
    return tuple(squeeze(x, (axis,)) for x in xs)

  return api.jvp(_reduce_tree, primals, tangents)
|
|
|
|
|
|
|
|
def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr,
                     consts, dimensions):
  """JVP rule for ``reduce``.

  ``primals`` and ``tangents`` each hold the operands in their first half
  and the initial values in their second half.

  Raises:
    NotImplementedError: if any initial-value tangent is not a symbolic zero.
  """
  primal_xs, init_values = split_list(primals, [len(primals) // 2])
  tangent_xs, tangent_init = split_list(tangents, [len(tangents) // 2])
  # This test may be too strict, if a value is actually zero but we cannot prove
  # it is symbolically zero.
  inits_symbolically_zero = all(type(t) is ad_util.Zero for t in tangent_init)
  if not inits_symbolically_zero:
    raise NotImplementedError(
        "Gradient of general lax.reduce with non-zero tangents for "
        "initial values to reduction not implemented")
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  reducer = core.jaxpr_as_fun(closed_jaxpr)
  return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _masking_defreducer(prim, identity):
  """Register ``_reducer_masking_rule`` as the masking rule for ``prim``.

  ``prim`` and ``identity`` are pre-bound as the rule's first two arguments.
  """
  rule = partial(_reducer_masking_rule, prim, identity)
  masking.masking_rules[prim] = rule
|
|
|
|
|
|
|
|
def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
                          axes, input_shape=None, **reduce_kwargs):
  """Masking rule shared by the single-operand reducers.

  Positions at or beyond the logical size along each reduced axis are replaced
  by ``identity(padded_shape, dtype)`` before ``prim`` is bound on the padded
  operand.

  Args:
    prim: the reduction primitive to bind.
    identity: callable ``(shape, dtype) -> array`` giving the fill value.
    padded_vals: one-element sequence holding the padded operand.
    logical_shapes: one-element sequence holding the operand's logical shape.
    axes: the axes being reduced.
    input_shape: if not None, ``input_shape=padded_shape`` is forwarded to the
      bind call.
    **reduce_kwargs: extra parameters forwarded to ``prim.bind``.

  Returns:
    The result of binding ``prim`` on the masked operand.
  """
  (padded_val,), (logical_shape,) = padded_vals, logical_shapes
  padded_shape = masking.padded_shape_as_value(padded_val.shape)
  masks = [broadcasted_iota(np.int32, padded_shape, i) < d
           for i, d in enumerate(logical_shape) if i in axes]
  if masks:
    mask = _reduce(operator.and_, masks)
    masked_val = select(mask, padded_val,
                        identity(padded_shape, padded_val.dtype))
  else:
    # No reduced axes: there is nothing to mask, and folding an empty list of
    # masks without an initial value would raise instead of reducing.
    masked_val = padded_val
  prim_bind = partial(prim.bind, **reduce_kwargs)
  if input_shape is None:
    bind = prim_bind
  else:
    bind = partial(prim_bind, input_shape=padded_shape)
  return bind(masked_val, axes=axes)
|
|
|
|
|
2021-02-10 11:04:19 -08:00
|
|
|
def _reduce_named_shape_rule(*avals, computation, jaxpr, consts, dimensions):
  """Named-shape rule for ``reduce``: every output gets the join of the
  operands' named shapes. Named initial values are not supported."""
  # TODO(mattjj,frostig): see the TODOs noting limitations/assumptions in
  # _reduce_batching_rule. We're making the same assumptions here for now.
  operand_avals, init_avals = split_list(avals, [len(avals) // 2])
  for init_aval in init_avals:
    if init_aval.named_shape:
      raise NotImplementedError
  operand_named_shapes = [aval.named_shape for aval in operand_avals]
  joined = core.join_named_shapes(*operand_named_shapes)
  return [joined for _ in operand_named_shapes]
|
2021-02-10 11:04:19 -08:00
|
|
|
|
|
|
|
|
2021-02-12 10:30:46 -08:00
|
|
|
# General variadic reduction primitive and its rule registrations.
reduce_p = core.Primitive('reduce')
reduce_p.multiple_results = True
reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
reduce_p.def_abstract_eval(
    partial(
        standard_multi_result_abstract_eval,
        reduce_p,
        _reduce_shape_rule,
        _reduce_dtype_rule,
        _reduce_weak_type_rule,
        _reduce_named_shape_rule,
    ))
xla.register_translation(reduce_p, _reduce_translation_rule)
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
|
2020-10-17 14:33:26 -04:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
  """MLIR lowering for ``reduce_p``.

  Emits an ``mhlo.ReduceOp`` over the first half of ``values`` (operands) with
  the second half as initial values, and lowers ``jaxpr`` into the op's
  reducer region.
  """
  assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
  num_operands = len(values) // 2
  operands, init_values = util.split_list(values, [num_operands])
  init_value_avals = ctx.avals_in[num_operands:]
  result_types = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
  op = mhlo.ReduceOp(result_types, operands, init_values,
                     mlir.dense_int_elements(dimensions))
  ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
  # The reducer block receives two arguments of each initial-value type.
  reducer = op.regions[0].blocks.append(*ir_types, *ir_types)
  with ir.InsertionPoint(reducer):
    reducer_ctx = ctx.module_context.replace(name_stack='')
    block_args = [[arg] for arg in reducer.arguments]
    out_nodes = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, consts, *block_args)
    mhlo.ReturnOp(util.flatten(out_nodes))
  return op.results
|
|
|
|
|
|
|
|
# Register the MLIR lowering rule for the general reduce primitive.
mlir.register_lowering(reduce_p, _reduce_lower)
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _reduce_number_dtype_rule(name, operand, *args, **kw):
  """Dtype rule for numeric reductions.

  Returns the canonicalized operand dtype, or raises ``TypeError`` (mentioning
  ``name``) when the operand dtype is not a subtype of ``np.number``.
  """
  if dtypes.issubdtype(operand.dtype, np.number):
    return dtypes.canonicalize_dtype(operand.dtype)
  raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
                  "of number.".format(name, np.dtype(operand.dtype).name))
|
|
|
|
|
|
|
|
def _reduce_sum_shape_rule(operand, *, axes):
  """Shape rule for ``reduce_sum``; delegates to ``_reduce_op_shape_rule``."""
  result_shape = _reduce_op_shape_rule(operand, axes=axes)
  return result_shape
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _reduce_sum_translation_rule(ctx, avals_in, avals_out, operand, *, axes):
  """XLA translation for ``reduce_sum``: an ``xops.Reduce`` over ``axes`` with
  initial value 0 and an ``add_p`` scalar subcomputation."""
  (operand_aval,) = avals_in
  dtype = operand_aval.dtype
  scalar_aval = ShapedArray((), dtype)
  zero = xla.pyval_to_ir_constant(ctx.builder, np.array(0, dtype))
  add_computation = xla.primitive_subcomputation(
      ctx.platform, ctx.axis_env, add_p, scalar_aval, scalar_aval)
  return [xops.Reduce(ctx.builder, [operand], [zero], add_computation, axes)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
  """Transpose of ``reduce_sum``: broadcast the cotangent back to the operand's
  shape along the non-reduced dimensions."""
  assert ad.is_undefined_primal(operand)
  full_shape = operand.aval.shape
  all_dims = np.arange(len(full_shape))
  # Dimensions that survived the reduction map onto the cotangent's dimensions.
  kept_dims = tuple(np.delete(all_dims, axes))
  expanded = broadcast_in_dim(cotangent, full_shape, kept_dims)
  assert expanded.shape == full_shape
  return [expanded]
|
|
|
|
|
|
|
|
# Sum-reduction primitive: shape/dtype rules, linear transpose, batching and
# masking (masked-out positions are filled with 0).
reduce_sum_p = standard_primitive(
    _reduce_sum_shape_rule,
    partial(_reduce_number_dtype_rule, 'reduce_sum'),
    'reduce_sum',
    _reduce_sum_translation_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
_masking_defreducer(
    reduce_sum_p,
    lambda shape, dtype: np.broadcast_to(np.array(0, dtype), shape))
|
|
|
|
|
|
|
|
|
|
|
|
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
  """Shape rule for single-operand reductions.

  Returns ``operand.shape`` with the dimensions listed in ``axes`` removed.
  Raises ``ValueError`` if ``axes`` contains duplicates or an index outside
  ``[0, operand.ndim)``. ``input_shape`` is accepted but ignored.
  """
  del input_shape  # Unused.
  unique_axes = frozenset(axes)
  if len(unique_axes) != len(axes):
    raise ValueError(f"duplicate value in 'axes' of reduction: {axes}")
  if any(not (0 <= a < operand.ndim) for a in axes):
    raise ValueError(f"reduction axes {axes} contains out-of-bounds indices for {operand}.")
  kept = []
  for i, d in enumerate(operand.shape):
    if i not in unique_axes:
      kept.append(d)
  return tuple(kept)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _reduce_prod_translation_rule(ctx, avals_in, avals_out, operand, *, axes):
  """XLA translation for ``reduce_prod``: an ``xops.Reduce`` over ``axes`` with
  initial value 1 and a ``mul_p`` scalar subcomputation."""
  (operand_aval,) = avals_in
  dtype = operand_aval.dtype
  scalar_aval = ShapedArray((), dtype)
  one = xla.pyval_to_ir_constant(ctx.builder, np.array(1, dtype))
  mul_computation = xla.primitive_subcomputation(
      ctx.platform, ctx.axis_env, mul_p, scalar_aval, scalar_aval)
  return [xops.Reduce(ctx.builder, [operand], [one], mul_computation, axes)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
  """JVP rule for ``reduce_prod``, built on the generic ``_reduce_jvp`` with
  multiplication as the reducer and 1 as the initial value."""
  def multiply(x, y):
    return [mul(x, y)]

  unit = _const(primals[0], 1)
  primals_out, tangents_out = _reduce_jvp(multiply, [unit], primals, tangents,
                                          axes)
  return primals_out[0], tangents_out[0]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Product-reduction primitive: shape/dtype rules, JVP, batching and masking
# (masked-out positions are filled with 1).
reduce_prod_p = standard_primitive(
    _reduce_op_shape_rule,
    partial(_reduce_number_dtype_rule, 'reduce_prod'),
    'reduce_prod',
    _reduce_prod_translation_rule)
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
batching.defreducer(reduce_prod_p)
_masking_defreducer(
    reduce_prod_p,
    lambda shape, dtype: np.broadcast_to(np.array(1, dtype), shape))
|
|
|
|
|
|
|
|
|
|
|
|
def _reduce_chooser_shape_rule(operand, *, axes):
  """Return ``operand.shape`` with the dimensions in ``axes`` removed.

  ``np.delete`` is applied to the dimension *indices* rather than to the shape
  itself, so the surviving dimension objects are returned unchanged instead of
  being coerced to NumPy scalars. Axis handling (negative indices,
  out-of-bounds errors) is whatever ``np.delete`` provides.
  """
  shape = tuple(operand.shape)
  kept_dims = np.delete(np.arange(len(shape)), axes)
  return tuple(shape[i] for i in kept_dims)
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _reduce_chooser_translation_rule(prim, identity, ctx, avals_in, avals_out,
                                     operand, *, axes):
  """XLA translation shared by the choosing reductions.

  Emits an ``xops.Reduce`` over ``axes`` whose initial value is
  ``identity(dtype)`` and whose scalar subcomputation is ``prim``.
  """
  (operand_aval,) = avals_in
  dtype = operand_aval.dtype
  scalar_aval = ShapedArray((), dtype)
  init = xla.pyval_to_ir_constant(ctx.builder, identity(dtype))
  chooser = xla.primitive_subcomputation(
      ctx.platform, ctx.axis_env, prim, scalar_aval, scalar_aval)
  return [xops.Reduce(ctx.builder, [operand], [init], chooser, axes)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
  """JVP for the choosing reductions: average the tangent ``g`` over every
  position where ``operand`` equals the reduced answer."""
  # TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
  # locations in a single pass (rather than comparing equality) and use a
  # gather, and/or even push along the chosen elements of g (b/112040122)
  keepdims_shape = []
  for i, d in enumerate(operand.shape):
    keepdims_shape.append(1 if i in axes else d)
  is_chosen = _eq_meet(operand, reshape(ans, keepdims_shape))
  location_indicators = convert_element_type(is_chosen, g.dtype)
  counts = _reduce_sum(location_indicators, axes)
  chosen_tangent_sum = _reduce_sum(mul(g, location_indicators), axes)
  return div(chosen_tangent_sum, counts)
|
|
|
|
|
|
|
|
# Max-reduction primitive: translation rule, JVP, batching and masking.
_reduce_max_translation_rule = partial(_reduce_chooser_translation_rule, max_p,
                                       _get_max_identity)
reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
                                  'reduce_max', _reduce_max_translation_rule)
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p)
# Fill masked-out positions with the same dtype-aware identity the translation
# rule uses. A literal -inf cannot be stored in an integer dtype, so
# `np.array(-np.inf, dtype)` raises for integer operands.
_masking_defreducer(
    reduce_max_p,
    lambda shape, dtype: np.broadcast_to(_get_max_identity(dtype), shape))
|
|
|
|
|
|
|
|
|
|
|
|
# Min-reduction primitive: translation rule, JVP, batching and masking.
_reduce_min_translation_rule = partial(
    _reduce_chooser_translation_rule, min_p, _get_min_identity)
reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
                                  'reduce_min', _reduce_min_translation_rule)
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p)
# Fill masked-out positions with the same dtype-aware identity the translation
# rule uses. A literal +inf cannot be stored in an integer dtype, so
# `np.array(np.inf, dtype)` raises for integer operands.
_masking_defreducer(
    reduce_min_p,
    lambda shape, dtype: np.broadcast_to(_get_min_identity(dtype), shape))
|
|
|
|
|
|
|
|
|
|
|
|
def _argminmax_shape_rule(operand, *, axes, index_dtype):
  """Shape rule for ``argmin``/``argmax``.

  ``axes`` must hold exactly one axis. Raises ``ValueError`` if the axis is out
  of range or the reduced dimension is not known to be at least 1. Returns the
  operand shape with that dimension dropped; the remaining dimension objects
  are kept as-is rather than routed through ``np.delete``, which would coerce
  them to NumPy values.
  """
  axis, = axes
  if not (0 <= axis < len(operand.shape)):
    raise ValueError(f"Invalid axis {axis} for operand shape {operand.shape}")
  if not core.greater_equal_dim(operand.shape[axis], 1):
    raise ValueError("argmin and argmax require non-empty reduced dimension. "
                     f"operand.shape={operand.shape} axis={axis}")
  shape = tuple(operand.shape)
  # `axis` was validated above, so plain slicing removes exactly one dimension.
  return shape[:axis] + shape[axis + 1:]
|
|
|
|
|
|
|
|
def _argminmax_dtype_rule(operand, *, axes, index_dtype):
  """Dtype rule for ``argmin``/``argmax``: the result has ``index_dtype``,
  which must be an integer type (``TypeError`` otherwise)."""
  if dtypes.issubdtype(index_dtype, np.integer):
    return index_dtype
  raise TypeError("index_dtype must be an integer type, but got {}"
                  .format(np.dtype(index_dtype).name))
|
|
|
|
|
2021-07-12 01:11:17 -07:00
|
|
|
def _compute_argminmax(value_comparator, get_identity,
                       operand, *, index_dtype, axes):
  """Compute argmin/argmax along a single axis with a variadic ``reduce``.

  ``value_comparator`` is either lax.lt (for argmin) or lax.gt, and
  ``get_identity(operand.dtype)`` is inf for argmin or -inf for argmax.
  The operand is reduced together with an iota of positions along the axis,
  and the reduced positions are returned.
  """
  (axis,) = axes
  positions = broadcasted_iota(index_dtype, np.shape(operand), axis)

  def reducer_fn(op_val_index, acc_val_index):
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    # Take the incoming value if it wins the comparison or is NaN
    # (a value differs from itself only when it is NaN).
    op_wins = value_comparator(op_val, acc_val)
    op_is_nan = ne(op_val, op_val)
    pick_op_val = bitwise_or(op_wins, op_is_nan)
    # On a tie between non-NaN values, keep the smaller (first) index.
    values_tie = eq(op_val, acc_val)
    op_is_earlier = lt(op_index, acc_index)
    pick_op_index = bitwise_or(pick_op_val,
                               bitwise_and(values_tie, op_is_earlier))
    chosen_val = select(pick_op_val, op_val, acc_val)
    chosen_index = select(pick_op_index, op_index, acc_index)
    return (chosen_val, chosen_index)

  init_values = [get_identity(operand.dtype), np.array(0, index_dtype)]
  res = reduce([operand, positions], init_values, reducer_fn, axes)
  return res[1]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-07-12 01:11:17 -07:00
|
|
|
# XLA translation rules for argmin/argmax, derived from _compute_argminmax.
_argmin_translation_rule = xla.lower_fun(
    partial(_compute_argminmax, lt, _get_min_identity),
    multiple_results=False, new_style=True)

_argmax_translation_rule = xla.lower_fun(
    partial(_compute_argminmax, gt, _get_max_identity),
    multiple_results=False, new_style=True)

argmin_p = standard_primitive(
    _argminmax_shape_rule, _argminmax_dtype_rule,
    'argmin', _argmin_translation_rule,
    weak_type_rule=_strip_weak_type)
batching.defreducer(argmin_p)
ad.defjvp_zero(argmin_p)

argmax_p = standard_primitive(
    _argminmax_shape_rule, _argminmax_dtype_rule,
    'argmax', _argmax_translation_rule,
    weak_type_rule=_strip_weak_type)
batching.defreducer(argmax_p)
ad.defjvp_zero(argmax_p)

# MLIR lowerings, likewise derived from _compute_argminmax.
mlir.register_lowering(
    argmin_p,
    mlir.cache_lowering(mlir.lower_fun(
        partial(_compute_argminmax, lt, _get_min_identity),
        multiple_results=False)))

mlir.register_lowering(
    argmax_p,
    mlir.cache_lowering(mlir.lower_fun(
        partial(_compute_argminmax, gt, _get_max_identity),
        multiple_results=False)))
|
2021-11-23 18:57:45 -08:00
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _reduce_logical_shape_rule(operand, *, axes):
  """Shape rule for the logical reductions.

  Raises ``TypeError`` unless ``operand.dtype`` is bool. Returns the operand
  shape with the dimensions in ``axes`` removed. ``np.delete`` is applied to
  the dimension indices rather than to the shape itself, so the surviving
  dimension objects are returned unchanged instead of being coerced to NumPy
  scalars.
  """
  if operand.dtype != np.bool_:
    msg = "logical reduction requires operand dtype bool, got {}."
    raise TypeError(msg.format(operand.dtype))
  shape = tuple(operand.shape)
  kept_dims = np.delete(np.arange(len(shape)), axes)
  return tuple(shape[i] for i in kept_dims)
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _reduce_logical_translation_rule(prim, identity, ctx, avals_in, avals_out,
                                     operand, *, axes):
  """XLA translation shared by the logical reductions.

  Emits an ``xops.Reduce`` over ``axes`` with ``identity(np.bool_)`` as the
  initial value and ``prim`` as the boolean scalar subcomputation.
  """
  bool_scalar = ShapedArray((), np.bool_)
  init = xla.pyval_to_ir_constant(ctx.builder, identity(np.bool_))
  combiner = xla.primitive_subcomputation(
      ctx.platform, ctx.axis_env, prim, bool_scalar, bool_scalar)
  return [xops.Reduce(ctx.builder, [operand], [init], combiner, axes)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Logical-or reduction primitive.
_reduce_or_translation_rule = partial(
    _reduce_logical_translation_rule, or_p, _get_max_identity)
reduce_or_p = standard_primitive(
    _reduce_logical_shape_rule, _fixed_dtype(np.bool_),
    'reduce_or', _reduce_or_translation_rule,
    weak_type_rule=_strip_weak_type)
batching.defreducer(reduce_or_p)


# Logical-and reduction primitive.
_reduce_and_translation_rule = partial(
    _reduce_logical_translation_rule, and_p, _get_min_identity)
reduce_and_p = standard_primitive(
    _reduce_logical_shape_rule, _fixed_dtype(np.bool_),
    'reduce_and', _reduce_and_translation_rule,
    weak_type_rule=_strip_weak_type)
batching.defreducer(reduce_and_p)
|
|
|
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
  """MLIR lowering shared by the single-operand reductions.

  Emits an ``mhlo.ReduceOp`` over ``axes`` whose initial value is
  ``unit_factory(dtype)`` and whose body applies ``reducer`` to the two scalar
  block arguments.
  """
  (aval_out,) = ctx.avals_out
  dtype = aval_out.dtype
  result_types = [mlir.aval_to_ir_type(aval_out)]
  init_values = mlir.ir_constants(unit_factory(dtype))
  op = mhlo.ReduceOp(result_types, [x], init_values,
                     mlir.dense_int_elements(axes))
  scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
  body = op.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(body):
    combined = reducer(*body.arguments)
    mhlo.ReturnOp(combined.results)
  return op.results
|
|
|
|
|
|
|
|
# MLIR lowerings for the single-operand reductions: each pairs an MHLO
# combiner with a factory for the reduction's initial value.
mlir.register_lowering(
    reduce_sum_p,
    partial(_unary_reduce_lower, mhlo.AddOp, lambda dtype: np.array(0, dtype)))
mlir.register_lowering(
    reduce_prod_p,
    partial(_unary_reduce_lower, mhlo.MulOp, lambda dtype: np.array(1, dtype)))
mlir.register_lowering(
    reduce_or_p,
    partial(_unary_reduce_lower, mhlo.OrOp,
            lambda dtype: np.array(False, dtype)))
mlir.register_lowering(
    reduce_and_p,
    partial(_unary_reduce_lower, mhlo.AndOp,
            lambda dtype: np.array(True, dtype)))
mlir.register_lowering(
    reduce_min_p,
    partial(_unary_reduce_lower, mlir.min_mhlo, _get_min_identity))
mlir.register_lowering(
    reduce_max_p,
    partial(_unary_reduce_lower, mlir.max_mhlo, _get_max_identity))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2021-04-05 09:54:14 -07:00
|
|
|
def _reduce_precision_shape_rule(operand, *, exponent_bits, mantissa_bits):
  """Shape rule for ``reduce_precision``: validates the bit counts and returns
  the operand's shape unchanged.

  Both bit counts must be integer-like (``operator.index``); ``exponent_bits``
  must be at least 1 and ``mantissa_bits`` at least 0, else ``ValueError``.
  """
  exponent_bits = operator.index(exponent_bits)
  mantissa_bits = operator.index(mantissa_bits)
  if not exponent_bits >= 1:
    raise ValueError(f"reduce_precision: exponent_bits must be positive; got {exponent_bits}")
  if not mantissa_bits >= 0:
    raise ValueError(f"reduce_precision: mantissa_bits must be non-negative; got {mantissa_bits}")
  return operand.shape
|
|
|
|
|
|
|
|
|
|
|
|
# Precision-reduction primitive; elementwise, so batching and masking use the
# vectorized defaults.
reduce_precision_p = standard_primitive(
    _reduce_precision_shape_rule,
    partial(unop_dtype_rule, _identity, _float, 'reduce_precision'),
    name='reduce_precision')
batching.defvectorized(reduce_precision_p)
masking.defvectorized(reduce_precision_p)
|
2021-04-05 09:54:14 -07:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
  """MLIR lowering for ``reduce_precision``: a single
  ``mhlo.ReducePrecisionOp`` with the bit counts as i32 attributes."""
  (aval_out,) = ctx.avals_out
  result_type = mlir.aval_to_ir_type(aval_out)
  op = mhlo.ReducePrecisionOp(result_type, operand,
                              mlir.i32_attr(exponent_bits),
                              mlir.i32_attr(mantissa_bits))
  return op.results
|
|
|
|
|
|
|
|
# Register the MLIR lowering rule for the reduce_precision primitive.
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
|
2021-04-05 09:54:14 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
# Unsigned integer dtype for each supported bit width.
_UINT_DTYPES = {bits: np.dtype(f"uint{bits}") for bits in (16, 32, 64)}
|
|
|
|
|
|
|
|
# Signed integer dtype for each supported bit width.
_INT_DTYPES = {bits: np.dtype(f"int{bits}") for bits in (16, 32, 64)}
|
|
|
|
|
|
|
|
|
|
|
|
def _sort_abstract_eval(*args, **kwargs):
  """Abstract evaluation for ``sort``: each output has the shaped abstract
  value of the matching input. All inputs must share one shape
  (``TypeError`` otherwise)."""
  shaped = tuple(map(raise_to_shaped, args))
  for other in shaped[1:]:
    if other.shape != shaped[0].shape:
      shapes = " ".join(str(a.shape) for a in shaped)
      raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
  return shaped
|
|
|
|
|
|
|
|
|
|
|
|
def _float_to_int_for_sort(x):
  """Map floats to signed integers whose integer order matches the sort order.

  Given a float f, let
    bits = bit_cast<int>(f)
    key  = bits < 0 ? int_max - bits : bits
  Then comparing keys as signed integers orders finite values the obvious
  way. Left as-is, -0 would sort before 0, and -NaN and NaN would land at
  opposite ends of the ordering, which is a problem for stable sorts; zeros and
  NaNs are therefore given a single standard representation first.

  To keep the negation from overflowing, ``int_max - bits`` is computed in the
  unsigned type and then bit-cast back to signed.
  """
  if x.dtype == dtypes.bfloat16:
    x = convert_element_type(x, np.float32)
  width = np.finfo(x).bits
  signed_dtype = _INT_DTYPES[width]
  unsigned_dtype = _UINT_DTYPES[width]

  signed_bits = bitcast_convert_type(x, signed_dtype)
  unsigned_bits = bitcast_convert_type(x, unsigned_dtype)

  # Zeros cannot be standardized on x itself because XLA elides that in some
  # cases, and NaNs cannot because doing so triggers jax.debug_nans. Both
  # replacements are made on the signed integer representation instead.

  # Every zero (+0 and -0) becomes the all-zero bit pattern.
  is_zero = eq(x, _zero(x))
  signed_bits = select(is_zero, _zeros(signed_bits), signed_bits)
  # Every NaN becomes the bit pattern of the dtype's default NaN.
  canonical_nan_bits = x.dtype.type(np.nan).view(signed_dtype)
  signed_bits = select(_isnan(x), full_like(signed_bits, canonical_nan_bits),
                       signed_bits)

  signed_max = unsigned_dtype.type(np.iinfo(signed_dtype).max)
  flipped = bitcast_convert_type(sub(signed_max, unsigned_bits), signed_dtype)
  is_negative = lt(signed_bits, _zero(signed_bits))
  return select(is_negative, flipped, signed_bits)
|
|
|
|
|
|
|
|
# Default comparator that sorts the operands lexicographically on the
|
|
|
|
# first `num_keys` arguments.
|
|
|
|
# For floating point types, a total order is created where
|
2022-01-13 13:03:41 -08:00
|
|
|
# -infinity < ... < 0 < ... < infinity < NaN.
|
|
|
|
# 0.0 and -0.0 are treated as equivalent, as are all NaN representations.
|
2020-10-17 14:33:26 -04:00
|
|
|
# For complex types, the (real, imag) pairs are sorted lexicographically
|
|
|
|
# (following NumPy's semantics).
|
|
|
|
# This code adds complex-number support and lexicographic ordering to the algorithm from:
|
|
|
|
# https://github.com/tensorflow/tensorflow/blob/ba43780830f09da72081fe5061c436f1c6203a92/tensorflow/compiler/xla/client/lib/comparators.h#L33
|
|
|
|
def _sort_lt_comparator(*operands, num_keys=1):
  """Lexicographic strict less-than over the first ``num_keys`` key pairs.

  Returns None when there are no keys to compare.
  """
  x_keys, y_keys = _operands_to_keys(*operands, num_keys=num_keys)
  result = None
  # Fold from the least significant key to the most significant one.
  for xk, yk in zip(reversed(x_keys), reversed(y_keys)):
    if result is None:
      result = lt(xk, yk)
    else:
      result = bitwise_or(lt(xk, yk), bitwise_and(eq(xk, yk), result))
  return result
|
|
|
|
|
|
|
|
# Similar to sort_lt_comparator, but implements less than or equal. Used by
|
|
|
|
# the searchsorted() implementation.
|
|
|
|
def _sort_le_comparator(*operands, num_keys=1):
  """Lexicographic less-than-or-equal over the first ``num_keys`` key pairs.

  Returns None when there are no keys to compare.
  """
  x_keys, y_keys = _operands_to_keys(*operands, num_keys=num_keys)
  result = None
  # Fold from the least significant key to the most significant one; only the
  # least significant key is compared with `le`.
  for xk, yk in zip(reversed(x_keys), reversed(y_keys)):
    if result is None:
      result = le(xk, yk)
    else:
      result = bitwise_or(lt(xk, yk), bitwise_and(eq(xk, yk), result))
  return result
|
|
|
|
|
|
|
|
def _operands_to_keys(*operands, num_keys=1):
  """Split interleaved ``(x0, y0, x1, y1, ...)`` operands into comparison keys.

  Only the first ``num_keys`` pairs are used. Floating-point keys are mapped
  through ``_float_to_int_for_sort``; complex keys contribute their real part
  followed by their imaginary part; other dtypes are used as-is.

  Returns:
    A pair ``(x_keys, y_keys)`` of equal-length lists.
  """
  assert len(operands) >= 2 and len(operands) % 2 == 0, operands
  assert len(operands) // 2 >= num_keys, (operands, num_keys)
  key_operands = operands[:2 * num_keys]
  x_keys, y_keys = [], []
  for x, y in zip(key_operands[0::2], key_operands[1::2]):
    assert x.dtype == y.dtype, (x.dtype, y.dtype)
    if dtypes.issubdtype(x.dtype, np.complexfloating):
      x_real, x_imag = _float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))
      x_keys += [x_real, x_imag]
      y_real, y_imag = _float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))
      y_keys += [y_real, y_imag]
    elif dtypes.issubdtype(x.dtype, np.floating):
      x_keys.append(_float_to_int_for_sort(x))
      y_keys.append(_float_to_int_for_sort(y))
    else:
      x_keys.append(x)
      y_keys.append(y)
  return x_keys, y_keys
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _sort_translation_rule(ctx, avals_in, avals_out, *operands, dimension,
                           is_stable, num_keys):
  """XLA translation for ``sort``.

  Builds a scalar comparator subcomputation from ``_sort_lt_comparator`` (two
  parameters per operand) and emits ``xops.Sort`` along ``dimension``.
  """
  c = ctx.builder
  subc = xc.XlaBuilder("sort_lt_comparator")
  element_types = [c.get_shape(x).xla_element_type() for x in operands]
  params = []
  for i, typ in enumerate(element_types):
    for j in range(2):
      params.append(
          xla.parameter(subc, 2 * i + j, xc.Shape.array_shape(typ, ())))
  lowered_comparator = xla.lower_fun(
      partial(_sort_lt_comparator, num_keys=num_keys),
      backend=ctx.platform,
      multiple_results=False)
  result = lowered_comparator(subc, *params)
  comparator = subc.build(result)
  out = xops.Sort(c, operands, dimension=dimension, is_stable=is_stable,
                  comparator=comparator)
  if len(operands) != 1:
    return xla.xla_destructure(c, out)
  return [out]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
  """JVP rule for ``sort``.

  An iota along ``dimension`` is sorted together with the primals; the sorted
  iota is then used to index the tangents so they follow the same permutation.
  Symbolically-zero tangents are passed through untouched.
  """
  shape = primals[0].shape
  int32_max = np.iinfo(np.int32).max
  iotas = [
      broadcasted_iota(np.int32 if size < int32_max else np.int64, shape, dim)
      for dim, size in enumerate(shape)
  ]
  sorted_outs = sort_p.bind(*(primals + (iotas[dimension],)),
                            dimension=dimension, is_stable=is_stable,
                            num_keys=num_keys)
  permutation = sorted_outs[-1]
  idx = tuple(permutation if i == dimension else iota
              for i, iota in enumerate(iotas))
  tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx]
                       for t in tangents)
  return tuple(sorted_outs[:-1]), tangents_out
|
|
|
|
|
|
|
|
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
  """Batching rule for ``sort``.

  All operands are brought to the batch dimension of the first batched
  operand: unbatched ones are broadcast to its shape, batched ones have their
  batch axis moved. The sort dimension is shifted when the batch dimension
  lands at or before it.
  """
  prototype_arg, new_bdim = next(
      (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)

  def align(arg, bdim):
    if bdim is not None:
      return batching.moveaxis(arg, bdim, new_bdim)
    dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
    return broadcast_in_dim(arg, prototype_arg.shape, dims)

  new_args = [align(arg, bdim) for arg, bdim in zip(batched_args, batch_dims)]
  new_dimension = dimension + (new_bdim <= dimension)
  bdims = (new_bdim,) * len(new_args)
  outs = sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable,
                     num_keys=num_keys)
  return (outs, bdims)
|
|
|
|
|
|
|
|
|
|
|
|
# Sort primitive and its rule registrations.
sort_p = Primitive('sort')
sort_p.multiple_results = True
sort_p.def_impl(partial(xla.apply_primitive, sort_p))
sort_p.def_abstract_eval(_sort_abstract_eval)
xla.register_translation(sort_p, _sort_translation_rule)
ad.primitive_jvps[sort_p] = _sort_jvp
batching.primitive_batchers[sort_p] = _sort_batch_rule
|
|
|
|
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
  """MLIR lowering rule for ``sort_p``; emits an ``mhlo.SortOp``.

  Args:
    ctx: per-rule lowering context; ``avals_in``, ``avals_out`` and
      ``module_context`` are read from it.
    *operands: lowered operands, flattened before being handed to the op.
    dimension: forwarded to the op as an i64 attribute.
    is_stable: forwarded to the op as a bool attribute.
    num_keys: forwarded to ``_sort_lt_comparator`` when the comparator body is
      lowered.

  Returns:
    The results of the emitted sort op.
  """
  assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
  result_types = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
  sort = mhlo.SortOp(result_types,
                     mlir.flatten_lowering_ir_args(operands),
                     mlir.i64_attr(dimension), ir.BoolAttr.get(is_stable))

  # The comparator block takes two scalar arguments per sort operand, so each
  # scalar aval/type is duplicated pairwise.
  scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
  scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
  comparator = sort.comparator.blocks.append(
      *util.flatten(zip(scalar_types, scalar_types)))
  with ir.InsertionPoint(comparator):
    # `_sort_lt_comparator` is passed directly; wrapping it in an
    # argument-less `partial` adds nothing.
    lower_comparator = mlir.lower_fun(_sort_lt_comparator,
                                      multiple_results=False)
    sub_ctx = mlir.LoweringRuleContext(
        module_context=ctx.module_context,
        primitive=None,
        avals_in=util.flatten(zip(scalar_avals, scalar_avals)),
        avals_out=[core.ShapedArray((), np.bool_)])
    out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
                           num_keys=num_keys)
    mhlo.ReturnOp(util.flatten(out))
  return sort.results
|
|
|
|
|
|
|
|
# Register `_sort_lower` as the MLIR lowering rule for `sort_p`.
mlir.register_lowering(sort_p, _sort_lower)
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _top_k_abstract_eval(operand, *, k):
|
|
|
|
if k < 0:
|
|
|
|
raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
|
|
|
|
if len(operand.shape) == 0:
|
|
|
|
raise TypeError("top_k operand must have >= 1 dimension, got {}"
|
|
|
|
.format(operand.shape))
|
|
|
|
shape = list(operand.shape)
|
|
|
|
if shape[-1] < k:
|
|
|
|
msg = "k argument to top_k must be no larger than minor dimension; {} vs {}"
|
|
|
|
raise ValueError(msg.format(k, shape))
|
|
|
|
shape[-1] = k
|
2021-03-28 10:32:02 -07:00
|
|
|
return (operand.update(shape=shape, dtype=operand.dtype,
|
|
|
|
weak_type=operand.weak_type),
|
2021-01-27 15:13:30 -08:00
|
|
|
operand.update(shape=shape, dtype=np.dtype(np.int32)))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _top_k_jvp(primals, tangents, *, k):
  """JVP rule for ``top_k``.

  The output values' tangent is the input tangent gathered at the selected
  indices; the integer index output always gets a symbolic zero tangent.

  Args:
    primals: one-element sequence holding the operand.
    tangents: one-element sequence holding the operand's tangent (possibly an
      ``ad_util.Zero``).
    k: forwarded to ``top_k``.

  Returns:
    ``(primals_out, tangents_out)`` where ``primals_out`` is the result of
    ``top_k(operand, k)``.
  """
  operand, = primals
  tangent, = tangents
  primals_out = top_k(operand, k)
  if type(tangent) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(primals_out[0])
  else:
    _, k_idxs = primals_out
    idx_shape = k_idxs.shape
    rank = len(idx_shape)
    gather_index_shape = idx_shape + (1,)
    # Build one full coordinate per output element: an iota for each leading
    # dimension, then the selected index for the last dimension.
    gather_indices = []
    for i in range(rank-1):
      _iota = iota(k_idxs.dtype, idx_shape[i])
      _iota = broadcast_in_dim(_iota, gather_index_shape, (i,))
      gather_indices.append(_iota)
    gather_indices.append(reshape(k_idxs, gather_index_shape))
    gather_indices = concatenate(gather_indices, dimension=rank)
    slice_sizes = (1,) * rank
    dnums = slicing.GatherDimensionNumbers(
        offset_dims=(),
        collapsed_slice_dims=tuple(range(rank)),
        start_index_map=tuple(range(rank)))
    tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes)
  return primals_out, (tangent_out, ad_util.Zero.from_value(primals_out[1]))
|
|
|
|
|
|
|
|
def _top_k_batch_rule(batched_args, batch_dims, *, k):
  """Batching rule for ``top_k``.

  When the batch dimension is the operand's last dimension, it is swapped with
  the one before it, ``top_k`` is applied, and the outputs are swapped back.
  Otherwise ``top_k`` is applied directly.

  Returns:
    ``(outputs, (bdim, bdim))``: both outputs keep the input batch dimension.
  """
  operand, = batched_args
  bdim, = batch_dims
  if bdim == operand.ndim-1:
    # Swapping two axes is its own inverse, so the same `perm` undoes itself.
    perm = np.arange(operand.ndim)
    perm[bdim-1], perm[bdim] = perm[bdim], perm[bdim-1]
    top_k_v, top_k_i = top_k(transpose(operand, perm), k=k)
    return (transpose(top_k_v, perm),
            transpose(top_k_i, perm)), (bdim, bdim)
  else:
    return top_k(operand, k=k), (bdim, bdim)
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
  """XLA translation rule for ``top_k_p``: destructures ``xops.TopK(x, k)``."""
  return xla.xla_destructure(ctx.builder, xops.TopK(x, k))
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
# Define the `top_k` primitive and attach its rules: eager implementation,
# abstract evaluation, XLA translation, JVP and batching.
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
xla.register_translation(top_k_p, _top_k_translation_rule)
ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
|
|
|
|
|
|
|
|
def _stop_gradient_jvp_rule(primals, tangents):
  """JVP rule for ``stop_gradient_p``: the output tangent is a symbolic zero."""
  # if we don't call stop_gradient here, we'd only peel off one autodiff tracer
  x, = primals
  return stop_gradient(x), ad_util.Zero.from_value(x)
|
|
|
|
|
|
|
|
def _stop_gradient_batch_rule(batched_args, batch_dims):
  """Batching rule for ``stop_gradient_p``: the batch dimension is unchanged."""
  x, = batched_args
  dim, = batch_dims
  return stop_gradient(x), dim
|
|
|
|
|
|
|
|
# Attach the JVP and batching rules to the stop_gradient primitive, which is
# defined in `ad_util`.
ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule
|
|
|
|
|
|
|
|
|
2020-12-17 19:49:30 -08:00
|
|
|
def create_token(_=None):
  """Creates an XLA token value with no preconditions for sequencing effects.

  Experimental.

  The argument is ignored. It exists for backward compatibility.
  """
  return create_token_p.bind()
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Define the `create_token` primitive. It ignores any operands: abstract
# evaluation always yields `abstract_token`, and the XLA translation emits a
# single CreateToken op.
create_token_p = Primitive("create_token")
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
create_token_p.def_abstract_eval(lambda *_: abstract_token)
xla.register_translation(create_token_p,
                         lambda ctx, *_: [xops.CreateToken(ctx.builder)])
|
2020-10-17 14:33:26 -04:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _create_token_lowering(ctx, *operands):
  """MLIR lowering rule for ``create_token_p``; ``operands`` are ignored."""
  aval_out, = ctx.avals_out
  return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
|
|
|
|
|
|
|
|
# Register `_create_token_lowering` as the MLIR lowering rule for
# `create_token_p`.
mlir.register_lowering(create_token_p, _create_token_lowering)
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def after_all(*operands):
  """Merges one or more XLA token values. Experimental.

  Wraps the XLA AfterAll operator."""
  return after_all_p.bind(*operands)
|
|
|
|
|
|
|
|
def _after_all_abstract_eval(*operands):
  """Abstract evaluation rule for ``after_all_p``.

  Raises:
    TypeError: if any operand is not ``abstract_token``.
  """
  if any(x is not abstract_token for x in operands):
    raise TypeError("Arguments to after_all must be tokens")
  return abstract_token
|
|
|
|
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _after_all_translation_rule(ctx, avals_in, avals_out, *operands):
  """XLA translation rule for ``after_all_p``: emits one AfterAll op."""
  return [xops.AfterAll(ctx.builder, operands)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Define the `after_all` primitive and attach its eager implementation,
# abstract evaluation and XLA translation rules.
after_all_p = Primitive("after_all")
after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
after_all_p.def_abstract_eval(_after_all_abstract_eval)
xla.register_translation(after_all_p, _after_all_translation_rule)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _after_all_lowering(ctx, *operands):
  """MLIR lowering rule for ``after_all_p``: emits an ``mhlo.AfterAllOp``."""
  aval_out, = ctx.avals_out
  return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
|
|
|
|
|
|
|
|
# Register `_after_all_lowering` as the MLIR lowering rule for `after_all_p`.
mlir.register_lowering(after_all_p, _after_all_lowering)
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def infeed(token, shape=None, partitions=None):
  """Consumes an infeed value of `shape` from the host. Experimental.

  `token` is used to sequence infeed and outfeed effects.
  `partitions` may be specified inside a `sharded_jit` function.

  Args:
    token: token value passed to the infeed primitive.
    shape: pytree whose leaves are all `ShapedArray` values.
    partitions: optional tuple; an extra `None` entry is appended for the
      token before binding.

  Returns:
    A pair `(values, token)` where `values` has the pytree structure of
    `shape`.

  Raises:
    TypeError: if a leaf of `shape` is not a `ShapedArray`.
    ValueError: if `partitions` is given and is not exactly a `tuple`.
  """
  flat_shapes, treedef = pytree.flatten(shape)
  # A distinct loop name keeps the `shape` argument from being rebound.
  for flat_shape in flat_shapes:
    if not isinstance(flat_shape, ShapedArray):
      raise TypeError("shape argument to infeed must be a pytree of "
                      "ShapedArray values, got {}".format(flat_shape))
  if partitions is not None:
    # Always replicate token.
    # We specifically use type() to raise an error for PartitionSpecs.
    if type(partitions) != tuple:  # pylint: disable=unidiomatic-typecheck
      raise ValueError(f"'partitions' argument to infeed should be a tuple, "
                       f"got {partitions}")
    partitions = partitions + (None,)
  xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes),
                               partitions=partitions)
  return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1])
|
|
|
|
|
|
|
|
def _infeed_abstract_eval(token, *, shapes, partitions):
  """Abstract evaluation rule for ``infeed_p``.

  Returns:
    ``shapes`` followed by ``abstract_token``.

  Raises:
    TypeError: if ``token`` is not ``abstract_token``.
  """
  if token is not abstract_token:
    raise TypeError("First argument to infeed must be a token")
  return shapes + (abstract_token,)
|
|
|
|
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _infeed_translation_rule(ctx, avals_in, avals_out, token, *, shapes,
                             partitions):
  """XLA translation rule for ``infeed_p``.

  Builds an ``InfeedWithToken`` op whose tuple shape is derived from
  ``shapes``, optionally under the sharding given by ``partitions``, and
  unpacks it into one value per entry of ``shapes`` followed by the new token.
  """
  c = ctx.builder
  # Separate names for the element and the result avoid reusing `shape`
  # inside the expression that defines it.
  xla_shapes = tuple(xla_shape.with_major_to_minor_layout_if_absent()
                     for x in shapes
                     for xla_shape in xla.aval_to_xla_shapes(x))
  build_infeed = partial(xops.InfeedWithToken, token,
                         xla_client.Shape.tuple_shape(xla_shapes))
  if partitions:
    xs_and_token = xla.with_sharding(c, partitions, build_infeed)
  else:
    # Note that infeed will default to replication if inside a sharded
    # computation and no sharding is specified.
    xs_and_token = build_infeed()
  xs = xops.GetTupleElement(xs_and_token, 0)
  token = xops.GetTupleElement(xs_and_token, 1)
  return [xops.GetTupleElement(xs, i) for i in range(len(shapes))] + [token]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Define the multi-result `infeed` primitive and attach its eager
# implementation, abstract evaluation and XLA translation rules.
infeed_p = Primitive("infeed")
infeed_p.multiple_results = True
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
infeed_p.def_abstract_eval(_infeed_abstract_eval)
xla.register_translation(infeed_p, _infeed_translation_rule)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _infeed_lowering(ctx, token, *, shapes, partitions):
  """MLIR lowering rule for ``infeed_p``; emits an ``mhlo.InfeedOp``.

  The op's result is a tuple ``(outputs_tuple, token)``, which is unpacked
  into per-output values followed by the new token.

  Args:
    ctx: per-rule lowering context; all but the last entry of ``avals_out``
      describe the infed values.
    token: lowered input token.
    shapes: avals of the infed values, used to build the layout attribute.
    partitions: optional sharding applied to the infeed op.

  Returns:
    One list of values per output, followed by ``[token]``.
  """
  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out[:-1])
  flat_output_types = util.flatten(output_types)
  output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
  # TODO(phawkins): verify `shapes` have a major-to-minor layout.
  # For each aval, list its dimension indices from last to first; the trailing
  # unit attribute is presumably the layout entry for the token.
  layouts = ir.ArrayAttr.get([
      ir.ArrayAttr.get(
          [ir.ArrayAttr.get(
              [mlir.i64_attr(i) for i in range(len(aval.shape) - 1, -1, -1)])
           for aval in shapes]),
      ir.UnitAttr.get(),
  ])
  output_and_token_tuple_type = ir.TupleType.get_tuple(
      [output_tuple_type, mhlo.TokenType.get()])
  infeed = mhlo.InfeedOp(
      output_and_token_tuple_type, token, ir.StringAttr.get(""),
      layouts)
  if partitions is not None:
    mlir.set_sharding(infeed, xla.sharding_to_proto(partitions))
  outs_tuple = mhlo.GetTupleElementOp(output_tuple_type, infeed.result,
                                      mlir.i32_attr(0)).result
  token = mhlo.GetTupleElementOp(mhlo.TokenType.get(), infeed.result,
                                 mlir.i32_attr(1)).result
  outs = [mhlo.GetTupleElementOp(typ, outs_tuple, mlir.i32_attr(i)).result
          for i, typ in enumerate(flat_output_types)]
  return util.unflatten(outs, safe_map(len, output_types)) + [[token,]]
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
|
|
# Register `_infeed_lowering` as the MLIR lowering rule for `infeed_p`.
mlir.register_lowering(infeed_p, _infeed_lowering)
|
|
|
|
|
|
|
|
|
2021-07-01 11:59:13 -07:00
|
|
|
def outfeed(token, xs, partitions = None):
  """Outfeeds value `xs` to the host. Experimental.

  `token` is used to sequence infeed and outfeed effects.
  `partitions` may be specified inside a `sharded_jit` or `pjit` function.

  Args:
    token: token value passed to the outfeed primitive.
    xs: pytree of values; its leaves are passed to the primitive.
    partitions: optional tuple forwarded to the primitive unchanged.

  Returns:
    The result of binding the outfeed primitive.

  Raises:
    ValueError: if `partitions` is given and is not exactly a `tuple`.
  """
  if partitions is not None:
    # We specifically use type() to raise an error for PartitionSpecs.
    if type(partitions) != tuple:  # pylint: disable=unidiomatic-typecheck
      raise ValueError(f"'partitions' argument to outfeed should be a tuple, "
                       f"got {partitions}")
  flat_xs, _ = pytree.flatten(xs)
  return outfeed_p.bind(token, *flat_xs, partitions=partitions)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-07-01 11:59:13 -07:00
|
|
|
def _outfeed_abstract_eval(token, *xs, partitions):
  """Abstract evaluation rule for ``outfeed_p``.

  Returns:
    ``abstract_token``.

  Raises:
    TypeError: if ``token`` is not ``abstract_token``.
  """
  if token is not abstract_token:
    raise TypeError("First argument to outfeed must be a token")
  return abstract_token
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _outfeed_translation_rule(ctx, avals_in, avals_out, token, *xs, partitions):
  """XLA translation rule for ``outfeed_p``.

  Packs ``xs`` into a tuple and emits ``OutfeedWithToken``, under the sharding
  given by ``partitions`` when it is not ``None``.
  """
  c = ctx.builder
  t = xops.Tuple(c, xs)
  if partitions is not None:
    return [xla.with_sharding(c, partitions, xops.OutfeedWithToken,
                              t, token, c.get_shape(t))]
  else:
    return [xops.OutfeedWithToken(t, token, c.get_shape(t))]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Define the `outfeed` primitive and attach its eager implementation,
# abstract evaluation and XLA translation rules.
outfeed_p = Primitive("outfeed")
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
xla.register_translation(outfeed_p, _outfeed_translation_rule)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-11-23 18:57:45 -08:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _outfeed_lowering(ctx, token, *xs, partitions):
  """MLIR lowering rule for ``outfeed_p``; emits an ``mhlo.OutfeedOp``.

  Args:
    ctx: per-rule lowering context; ``avals_in[0]`` is the token's aval and
      the remaining entries describe ``xs``.
    token: lowered input token.
    *xs: lowered values to send; they are packed into a single tuple.
    partitions: optional sharding applied to the outfeed op.

  Returns:
    The results of the emitted outfeed op.
  """
  token_aval = ctx.avals_in[0]
  xs_avals = ctx.avals_in[1:]
  # A list comprehension gives the same result whether or not `map` is the
  # lazy builtin in this module.
  input_types = [mlir.aval_to_ir_types(aval) for aval in xs_avals]
  flat_input_types = util.flatten(input_types)
  input_tuple_type = ir.TupleType.get_tuple(flat_input_types)
  tup = mhlo.TupleOp(input_tuple_type, mlir.flatten_lowering_ir_args(xs)).result
  outfeed = mhlo.OutfeedOp(mlir.aval_to_ir_type(token_aval), tup, token,
                           ir.StringAttr.get(""))
  if partitions is not None:
    mlir.set_sharding(outfeed, xla.sharding_to_proto(partitions))
  return outfeed.results
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
|
|
# Register `_outfeed_lowering` as the MLIR lowering rule for `outfeed_p`.
mlir.register_lowering(outfeed_p, _outfeed_lowering)
|
|
|
|
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def rng_uniform(a, b, shape):
  """Stateful PRNG generator. Experimental and its use is discouraged.

  Returns uniformly distributed random numbers in the range [a, b)

  You should use jax.random for most purposes; this function exists only for
  niche use cases with special performance requirements.

  This API may be removed at any time.
  """
  return rng_uniform_p.bind(a, b, shape=tuple(shape))
|
|
|
|
|
|
|
|
def _rng_uniform_abstract_eval(a, b, *, shape):
|
|
|
|
if a.dtype != b.dtype:
|
|
|
|
raise ValueError(
|
|
|
|
"Arguments to rng_uniform must have identical dtypes, got {} "
|
|
|
|
"and {}.".format(a.dtype, b.dtype))
|
|
|
|
if a.shape != () or b.shape != ():
|
|
|
|
raise ValueError(
|
|
|
|
"Arguments to rng_uniform must be scalars; got shapes {} and {}."
|
|
|
|
.format(a.shape, b.shape))
|
2021-03-28 10:32:02 -07:00
|
|
|
return a.update(shape=shape, dtype=a.dtype,
|
|
|
|
weak_type=(a.weak_type and b.weak_type))
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _rng_uniform_translation_rule(ctx, avals_in, avals_out, a, b, *, shape):
  """XLA translation rule for ``rng_uniform_p``.

  The result shape uses ``a``'s XLA element type and the requested ``shape``.
  """
  c = ctx.builder
  xla_shape = xc.Shape.array_shape(c.get_shape(a).xla_element_type(), shape)
  return [xops.RngUniform(a, b, xla_shape)]
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Define the `rng_uniform` primitive and attach its eager implementation,
# abstract evaluation and XLA translation rules.
rng_uniform_p = Primitive("rng_uniform")
rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p))
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
xla.register_translation(rng_uniform_p, _rng_uniform_translation_rule)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _rng_uniform_lowering(ctx, a, b, *, shape):
  """MLIR lowering rule for ``rng_uniform_p``; emits an ``mhlo.RngUniformOp``.

  The shape operand is built as an int64 constant from the output aval's
  shape; the ``shape`` parameter itself is not read.
  """
  del shape  # The output aval's shape is used instead.
  aval_out, = ctx.avals_out
  shape_const, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
                                   canonicalize_types=False)
  return mhlo.RngUniformOp(a, b, shape_const).results
|
|
|
|
|
|
|
|
# Register `_rng_uniform_lowering` as the MLIR lowering rule for
# `rng_uniform_p`.
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
|
|
|
|
|
2020-10-23 07:34:32 -07:00
|
|
|
|
2021-03-16 12:13:41 -04:00
|
|
|
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
|
2021-10-01 18:15:00 -07:00
|
|
|
del dtype, algorithm
|
2021-03-16 12:13:41 -04:00
|
|
|
return (key.shape, tuple(shape))
|
2021-02-16 12:31:01 +00:00
|
|
|
|
2021-03-16 12:13:41 -04:00
|
|
|
def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
|
2021-10-01 18:15:00 -07:00
|
|
|
del shape, algorithm
|
2021-03-16 12:13:41 -04:00
|
|
|
return (key.dtype, dtype)
|
2021-02-16 12:31:01 +00:00
|
|
|
|
2021-03-16 12:13:41 -04:00
|
|
|
def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
|
2021-10-01 18:15:00 -07:00
|
|
|
del shape, dtype, algorithm
|
2021-03-16 12:13:41 -04:00
|
|
|
return (key.weak_type, False)
|
2021-02-16 12:31:01 +00:00
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _rng_bit_generator_translation_rule(
    ctx, avals_in, avals_out, key, *, shape, dtype, algorithm):
  """XLA translation rule for ``rng_bit_generator_p``.

  Args:
    ctx: translation context providing ``builder``.
    avals_in: unused.
    avals_out: unused.
    key: XLA op holding the key, either u32[4] or u64[2].
    shape: shape of the generated values.
    dtype: dtype of the generated values.
    algorithm: algorithm passed to ``xops.RngBitGenerator``.

  Returns:
    ``[out_key, out_vals]``; ``out_key`` is converted back to u32[4] when the
    input key was u32[4].
  """
  c = ctx.builder
  key_xla_shape = c.get_shape(key)
  key_shape, key_dtype = key_xla_shape.dimensions(), key_xla_shape.numpy_dtype()
  # While the RngBitGenerator HLO accepts a u64[2] key on all backends, we
  # typically represent the key argument to this primitive as a u32[4] so as to
  # sidestep issues with the jax_enable_x64=False configuration. As a result, we
  # need to convert u32[4] -> u64[2] here in the translation rule. However, we
  # also polymorphically allow a u64[2] for backward compatibility.
  assert ((key_shape == (4,) and key_dtype == np.dtype('uint32')) or
          (key_shape == (2,) and key_dtype == np.dtype('uint64'))), (key_shape, key_dtype)
  xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
  key_is_u32 = key_dtype == np.dtype('uint32')
  if key_is_u32:
    u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
    key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
  out_key, out_vals = xla.xla_destructure(
      c, xops.RngBitGenerator(algorithm, key, xla_shape))
  if key_is_u32:
    u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
    out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
  return [out_key, out_vals]
|
2021-10-05 13:46:57 -07:00
|
|
|
|
2021-02-16 12:31:01 +00:00
|
|
|
|
2021-03-16 12:13:41 -04:00
|
|
|
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
|
|
|
|
return [key.named_shape, key.named_shape]
|
2021-02-16 12:31:01 +00:00
|
|
|
|
2021-03-16 12:13:41 -04:00
|
|
|
# Define the multi-result `rng_bit_generator` primitive. Its abstract
# evaluation is assembled from the per-attribute shape, dtype, weak-type and
# named-shape rules.
rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
    partial(xla.apply_primitive, rng_bit_generator_p))
rng_bit_generator_p.def_abstract_eval(
    partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
            _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
            _rng_bit_generator_weak_type_rule,
            _rng_bit_generator_named_shape_rule))
xla.register_translation(rng_bit_generator_p,
                         _rng_bit_generator_translation_rule)
|
2021-02-16 12:31:01 +00:00
|
|
|
|
2021-03-16 12:13:41 -04:00
|
|
|
# Re-export XLA's RandomAlgorithm enum and make `str()` of a member return its
# name.
RandomAlgorithm = xops.RandomAlgorithm
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name  # type: ignore[assignment]
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
|
|
|
2021-10-05 13:46:57 -07:00
|
|
|
def rng_bit_generator(key, shape, dtype=np.uint32,
                      algorithm=RandomAlgorithm.RNG_DEFAULT):
  """Stateless PRNG bit generator. Experimental and its use is discouraged.

  Returns uniformly distributed random bits with the specified shape and dtype
  (which is required to be an integer type) using the platform specific
  default algorithm or the one specified.

  It provides direct access to the RngBitGenerator primitive exposed by XLA
  (https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
  level API access.

  Most users should use `jax.random` instead for a stable and more user
  friendly API.
  """
  shape = jax.core.canonicalize_shape(shape)
  return tuple(
      rng_bit_generator_p.bind(
          key, shape=shape, dtype=dtype, algorithm=algorithm))
|
2021-02-16 12:31:01 +00:00
|
|
|
|
|
|
|
|
2020-10-23 07:34:32 -07:00
|
|
|
def _iota_abstract_eval(*, dtype, shape, dimension):
  """Abstract evaluation rule for ``iota_p``.

  Returns:
    ``ShapedArray(shape, dtype)``.

  Raises:
    TypeError: if ``dtype`` is not a subtype of any entry of ``_num``.
    ValueError: if ``dimension`` is not in ``[0, len(shape))``.
  """
  _check_shapelike("iota", "shape", shape)
  if not any(dtypes.issubdtype(dtype, t) for t in _num):
    msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
    typename = str(np.dtype(dtype).name)
    accepted_typenames = (t.__name__ for t in _num)
    raise TypeError(msg.format(typename, ', '.join(accepted_typenames)))
  if not 0 <= dimension < len(shape):
    raise ValueError("iota dimension must be between 0 and len(shape), got "
                     f"dimension={dimension} for shape {shape}")
  return ShapedArray(shape, dtype)
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
def _iota_translation_rule(ctx, avals_in, avals_out, *, dtype, shape,
                           dimension):
  """XLA translation rule for ``iota_p``: emits ``xops.Iota``."""
  etype = xla.dtype_to_primitive_type(dtype)
  xla_shape = xc.Shape.array_shape(etype, shape)
  return [xops.Iota(ctx.builder, xla_shape, dimension)]
|
2020-10-23 07:34:32 -07:00
|
|
|
|
|
|
|
# Define the `iota` primitive and attach its eager implementation, abstract
# evaluation and XLA translation rules.
iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
xla.register_translation(iota_p, _iota_translation_rule)
|
2020-10-23 07:34:32 -07:00
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
2021-12-15 19:06:26 -08:00
|
|
|
def _iota_lower(ctx, *, dtype, shape, dimension):
  """MLIR lowering rule for ``iota_p``; emits an ``mhlo.IotaOp``.

  The result type comes from the output aval, so ``dtype`` and ``shape`` are
  discarded.
  """
  del dtype, shape
  aval_out, = ctx.avals_out
  return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out),
                     mlir.i64_attr(dimension)).results
|
|
|
|
# Register `_iota_lower` as the MLIR lowering rule for `iota_p`.
mlir.register_lowering(iota_p, _iota_lower)
|
|
|
|
|
2020-10-23 07:34:32 -07:00
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
### util
|
|
|
|
|
|
|
|
_ndim = np.ndim
|
|
|
|
|
|
|
|
|
|
|
|
def _dilate_shape(shape, dilation):
  """Compute the shape that results from dilating `shape` by `dilation`.

  Every dilation factor must be strictly positive, otherwise a TypeError is
  raised. When `dilation` is shorter than `shape` it is left-padded with 1s
  before being handed to `core.dilate_shape`.
  """
  all_positive = np.all(np.greater(dilation, 0))
  if not all_positive:
    raise TypeError(f"All dilations must be positive, got {dilation}.")
  missing = len(shape) - len(dilation)
  full_dilation = (1,) * missing + tuple(dilation)
  return core.dilate_shape(shape, full_dilation)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _ceil_divide(x1, x2):
  """Ceiling division, computed as the negated floor division of `-x1` by `x2`."""
  floored = np.floor_divide(np.negative(x1), x2)
  return -floored
|
|
|
|
|
|
|
|
def padtype_to_pads(in_shape, window_shape, window_strides, padding):
  """Convert a padding mode into a list of `(low, high)` pad pairs.

  `padding` may be the string 'VALID' or 'SAME' (case-insensitive) or an
  `xla_client.PaddingType` value.

  * SAME: the output size per dimension is `ceil(in / stride)`; the total
    padding needed to reach it is split so the low side gets the floor half
    and the high side gets the remainder.
  * VALID: no padding, one `(0, 0)` pair per input dimension.

  Raises:
    RuntimeError: `padding` is a string other than 'VALID' / 'SAME'.
    TypeError: `padding` is neither of the two known padding types.
  """
  PaddingType = xla_client.PaddingType

  if isinstance(padding, str):
    name_to_enum = {'VALID': PaddingType.VALID, 'SAME': PaddingType.SAME}
    try:
      padding = name_to_enum[padding.upper()]
    except KeyError as err:
      template = "Unrecognized padding type: expected 'VALID' or 'SAME', got {}."
      raise RuntimeError(template.format(padding)) from err

  if padding == PaddingType.SAME:
    out_shape = _ceil_divide(in_shape, window_strides)
    totals = np.maximum(0, (out_shape - 1) * window_strides +
                        window_shape - in_shape)
    pads = []
    for total in totals:
      low = total // 2
      pads.append((low, total - low))
    return pads

  if padding == PaddingType.VALID:
    return [(0, 0)] * len(in_shape)

  raise TypeError("Unknown padding type: {}.".format(padding))
|
|
|
|
|
|
|
|
|
2021-12-07 11:45:07 -08:00
|
|
|
# Map of lax function to equivalent jax.numpy function for use in error string below.
# Keys are `lax` function names and values are the names of the `jnp`
# functions suggested in the "Tip:" suffix of dtype-mismatch errors.
_JNP_FUNCTION_EQUIVALENTS = {
  'abs': 'fabs',
  'acos': 'arccos',
  'acosh': 'arccosh',
  'add': 'add',
  'asin': 'arcsin',
  'asinh': 'arcsinh',
  'atan': 'arctan',
  'atan2': 'arctan2',
  'atanh': 'arctanh',
  'bitwise_and': 'bitwise_and',
  'bitwise_not': 'bitwise_not',
  'bitwise_or': 'bitwise_or',
  'bitwise_xor': 'bitwise_xor',
  'cbrt': 'cbrt',
  'ceil': 'ceil',
  'concatenate': 'concatenate',
  'cos': 'cos',
  'cosh': 'cosh',
  'div': 'divide',
  'eq': 'equal',
  'exp': 'exp',
  'expm1': 'expm1',
  'floor': 'floor',
  'greater': 'greater',
  'greater_equal': 'greater_equal',
  'less': 'less',
  'less_equal': 'less_equal',
  'log': 'log',
  'logical_and': 'logical_and',
  'logical_not': 'logical_not',
  'logical_or': 'logical_or',
  'logical_xor': 'logical_xor',
  'log1p': 'log1p',
  'max': 'maximum',
  'min': 'minimum',
  'mul': 'multiply',
  'ne': 'not_equal',
  'neg': 'negative',
  'nextafter': 'nextafter',
  'pow': 'float_power',
  # Was misspelled 'rount', so `lax.round` never matched and the tip was lost.
  'round': 'round',
  'select': 'where',
  'shift_left': 'left_shift',
  'shift_right_logical': 'right_shift',
  'shift_right_arithmetic': 'right_shift',
  'sign': 'sign',
  'sin': 'sin',
  'sinh': 'sinh',
  'sqrt': 'sqrt',
  'sub': 'subtract',
  'tan': 'tan',
  'tanh': 'tanh'
}
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
  """Check that dtypes agree, possibly ignoring float precision.

  Args:
    name: lax function name, used in the error message and to look up a
      jax.numpy equivalent for the tip.
    ignore_fp_precision: if true, all real floating dtypes are treated as one
      class and all complex floating dtypes as another before comparing.
    *ttypes: the dtypes to compare.

  Raises:
    TypeError: the (canonicalized) dtypes are not all the same.
  """
  # the `ignore_fp_precision` flag exists because the XLA shape inference logic
  # allows mixed floating point precision, but the HLO verifier often rejects it
  def _coarsen(dt):
    # Collapse precision: real floats -> np.floating, complex -> np.complexfloating.
    if dtypes.issubdtype(dt, np.floating):
      return np.floating
    if dtypes.issubdtype(dt, np.complexfloating):
      return np.complexfloating
    return dt

  types = [np.dtype(t) for t in ttypes]  # canonicalize
  if ignore_fp_precision:
    types = [_coarsen(t) for t in types]

  distinct = {dtypes.canonicalize_dtype(t) for t in types}
  if len(distinct) == 1:
    return

  if ignore_fp_precision:
    msg = ("lax.{} requires arguments to have same dtypes up to floating point "
           "precision, got {}.")
  else:
    msg = "lax.{} requires arguments to have the same dtypes, got {}."
  if name in _JNP_FUNCTION_EQUIVALENTS:
    equiv = _JNP_FUNCTION_EQUIVALENTS[name]
    msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)."
  got = ", ".join(map(str, types))
  raise TypeError(msg.format(name, got))
|
|
|
|
|
|
|
|
|
|
|
|
def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
  """Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints).

  Args:
    fun_name: name of the calling function, used in error messages.
    arg_name: name of the argument being checked, used in error messages.
    obj: the value to validate; must be a tuple, list or ndarray.
    non_zero_shape: if true, every element must be >= 1 rather than >= 0.

  Returns:
    None. An empty `obj` is accepted without further checks.

  Raises:
    TypeError: `obj` has the wrong container type, is not rank 1, has an
      element rejected by `canonicalize_shape`, or has an element below the
      lower bound.
  """
  if not isinstance(obj, (tuple, list, np.ndarray)):
    msg = "{} {} must be of type tuple/list/ndarray, got {}."
    raise TypeError(msg.format(fun_name, arg_name, type(obj)))
  # bool(obj) for an ndarray raises an error, so we check len
  if not len(obj):  # pylint: disable=g-explicit-length-test
    return
  obj_arr = np.array(obj)
  if obj_arr.ndim != 1:
    msg = "{} {} must be rank 1, got {}."
    # All three placeholders must be filled; passing only the rank made
    # `format` raise IndexError instead of producing this TypeError.
    raise TypeError(msg.format(fun_name, arg_name, obj_arr.ndim))
  try:
    canonicalize_shape(obj_arr)
  except TypeError as err:
    msg = "{} {} must have every element be an integer type, got {}."
    raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) from err
  lower_bound, bound_error = (
      (1, "strictly positive") if non_zero_shape else (0, "nonnegative"))
  if not all(core.greater_equal_dim(d, lower_bound) for d in obj_arr):
    msg = "{} {} must have every element be {}, got {}."
    raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
|
|
|
|
|
|
|
|
|
|
|
|
def _const(example, val):
  """Build the constant `val` with a dtype matching `example`.

  When `example` is a Python scalar, `val` is first converted with the
  matching Python scalar type and returned as-is if its dtype already equals
  the dtype of `example`; otherwise a NumPy array of that dtype is returned.
  """
  target = _dtype(example)
  if not dtypes.is_python_scalar(example):
    return np.array(val, target)
  val = dtypes.scalar_type_of(example)(val)
  if target == _dtype(val):
    return val
  return np.array(val, target)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
# Shorthand constructors built on `full_like` (defined elsewhere in this file)
# with `fill_value` pre-bound to 0, 1 or 2. The plural names leave `shape`
# unset; the singular names additionally pre-bind `shape=()`.
_zeros: Callable = partial(full_like, fill_value=0)
_zero: Callable = partial(full_like, shape=(), fill_value=0)
_ones: Callable = partial(full_like, fill_value=1)
_one: Callable = partial(full_like, shape=(), fill_value=1)
_twos: Callable = partial(full_like, fill_value=2)
_two: Callable = partial(full_like, shape=(), fill_value=2)
|
|
|
|
|
2021-11-22 09:29:43 -08:00
|
|
|
# `dtype` and `_dtype` are the same partial application: `dtypes.dtype` with
# `canonicalize=True` pre-bound.
dtype: Callable = partial(dtypes.dtype, canonicalize=True)
_dtype: Callable = partial(dtypes.dtype, canonicalize=True)
|
2020-10-17 14:33:26 -04:00
|
|
|
|
2022-01-13 13:03:41 -08:00
|
|
|
def _isnan(x) -> bool:
  """Return `ne(x, x)`, i.e. where `x` compares unequal to itself."""
  differs_from_self = ne(x, x)
  return differs_from_self
|
|
|
|
|
2020-10-17 14:33:26 -04:00
|
|
|
def _iscomplex(x) -> bool:
  """Return whether the dtype of `x` is a complex floating-point type."""
  x_dtype = _dtype(x)
  return dtypes.issubdtype(x_dtype, np.complexfloating)
|
|
|
|
|
|
|
|
|
|
|
|
def ranges_like(*xs):
  """Yield consecutive `range` objects whose lengths match each of `xs`.

  The first range starts at 0 and each following range starts where the
  previous one stopped.
  """
  offset = 0
  for length in map(len, xs):
    stop = offset + length
    yield range(offset, stop)
    offset = stop
|
|
|
|
|
|
|
|
|
|
|
|
def remaining(original, *removed_lists):
  """Return the items of `original`, in order, that appear in none of `removed_lists`."""
  excluded = set(itertools.chain.from_iterable(removed_lists))
  kept = []
  for item in original:
    if item not in excluded:
      kept.append(item)
  return kept
|
|
|
|
|
|
|
|
|
2021-09-13 15:39:02 -07:00
|
|
|
def canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[PrecisionType, PrecisionType]]:
  """Turns an API precision specification into a pair of enumeration values.

  Accepted forms of `precision`:

  * `None`: falls back to the `jax_default_matmul_precision` config value;
    if that is also `None`, returns `None`.
  * a string that is a key of `Precision._strings`, or an
    `xla_client.PrecisionConfig.Precision` value: applied to both operands.
  * a list/tuple of exactly two `xla_client.PrecisionConfig.Precision`
    values: returned unchanged.
  * a list/tuple of exactly two strings: each is canonicalized on its own.

  Raises:
    ValueError: if `precision` matches none of the forms above (this includes
      a two-element sequence mixing a string with an enum value), or if the
      `jax_default_matmul_precision` config value is rejected by `Precision`.
  """
  if precision is None:
    if config.jax_default_matmul_precision is None:
      return None
    try:
      precision = Precision(config.jax_default_matmul_precision)
      return (precision, precision)
    except TypeError:
      # `from None` hides the TypeError so only the ValueError is reported.
      raise ValueError(
          "jax_default_matmul_precision flag must be set to None or a value in "
          f"{list(Precision._strings)}, but got {config.jax_default_matmul_precision}"
      ) from None
  elif isinstance(precision, str) and precision in Precision._strings:
    precision = Precision(precision)
    return (precision, precision)
  elif isinstance(precision, xla_client.PrecisionConfig.Precision):
    return (precision, precision)
  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
        all(isinstance(p, xla_client.PrecisionConfig.Precision) for p in precision)):
    return precision  # type: ignore[return-value]
  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
        all(isinstance(s, str) for s in precision)):
    s1, s2 = precision
    # Each recursive call returns a (p, p) pair for a valid string (and raises
    # ValueError for an unknown one), so take the first element of each.
    return (canonicalize_precision(s1)[0], canonicalize_precision(s2)[0])  # type: ignore
  else:
    raise ValueError(
        f"Precision argument must be None, a string in {list(Precision._strings)}, "
        "a lax.Precision value or a tuple of two lax.Precision values or "
        f"strings; got {precision}.")
|
2020-10-17 14:33:26 -04:00
|
|
|
|
|
|
|
def _balanced_eq(x, z, y):
  """Return an elementwise weight: (1 where x == z else 0) / (2 where y == z else 1).

  Equality is tested with `_eq_meet`; the constants are built like `z`.
  """
  numerator = select(_eq_meet(x, z), _ones(z), _zeros(z))
  denominator = select(_eq_meet(y, z), _twos(z), _ones(z))
  return div(numerator, denominator)
|
|
|
|
|
|
|
|
|
|
|
|
def _eq_meet(a, b):
  """Compare `a` and `b` with `eq` after bringing them to a common dtype.

  When the dtypes differ, the operand whose dtype equals the promoted dtype
  is converted to the other operand's dtype; otherwise `b` is converted to
  the dtype of `a`.
  """
  lhs_dtype, rhs_dtype = _dtype(a), _dtype(b)
  if lhs_dtype != rhs_dtype:
    promoted = dtypes.promote_types(lhs_dtype, rhs_dtype)
    a_is_promoted = promoted == lhs_dtype
    if a_is_promoted:
      a = convert_element_type(a, rhs_dtype)
    else:
      b = convert_element_type(b, lhs_dtype)
  return eq(a, b)
|
|
|
|
|
|
|
|
|
|
|
|
def _abstractify(x):
  """Return the abstract value of `x`, passed through `raise_to_shaped`."""
  aval = core.get_aval(x)
  return raise_to_shaped(aval)
|
|
|
|
|
|
|
|
|
|
|
|
def _check_user_dtype_supported(dtype, fun_name=None):
  """Validate a user-requested dtype, warning if it will be truncated.

  Args:
    dtype: the requested dtype. The Python types `bool`, `int`, `float` and
      `complex` are accepted without any further checks.
    fun_name: optional name of the calling function, included in messages.

  Raises:
    TypeError: the dtype is neither a NumPy bool/int/uint/float/complex kind
      nor bfloat16.

  Warns when `dtype` is not None and differs from its canonicalized form.
  """
  # Avoid using `dtype in [...]` because of numpy dtype equality overloading.
  if isinstance(dtype, type) and dtype in {bool, int, float, complex}:
    return

  np_dtype = np.dtype(dtype)
  if np_dtype.kind not in "biufc" and np_dtype.type != dtypes.bfloat16:
    where = f" in {fun_name}" if fun_name else ""
    raise TypeError(
        f"JAX only supports number and bool dtypes, got dtype {dtype}" + where)

  if dtype is not None:
    canonical = dtypes.canonicalize_dtype(dtype)
    if np_dtype != canonical:
      template = ("Explicitly requested dtype {} {} is not available, "
                  "and will be truncated to dtype {}. To enable more dtypes, set the "
                  "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
                  "environment variable. "
                  "See https://github.com/google/jax#current-gotchas for more.")
      requested_in = f"requested in {fun_name}" if fun_name else ""
      # stacklevel=2 attributes the warning to this function's caller.
      warnings.warn(template.format(dtype, requested_in, canonical.name),
                    stacklevel=2)
|