2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
2020-01-08 13:17:55 -05:00
|
|
|
import builtins
|
2019-10-22 19:53:59 -04:00
|
|
|
import functools
|
2018-11-17 18:03:33 -08:00
|
|
|
import itertools
|
|
|
|
import operator
|
2020-06-04 13:50:44 -07:00
|
|
|
from typing import (Any, Callable, List, NamedTuple, Optional, Sequence, Union, Tuple)
|
2019-01-11 14:49:42 -05:00
|
|
|
import warnings
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
import numpy as np
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-04-12 16:28:40 -07:00
|
|
|
from .. import core
|
|
|
|
from .. import ad_util
|
|
|
|
from .. import api
|
|
|
|
from .. import linear_util as lu
|
2019-11-15 10:02:51 -05:00
|
|
|
from .. import dtypes
|
implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
2020-01-03 15:46:19 -08:00
|
|
|
from .. import lazy
|
2020-07-30 12:59:36 -07:00
|
|
|
from ..config import flags, config
|
2020-06-03 22:40:48 +02:00
|
|
|
from ..core import Primitive, _canonicalize_dimension
|
2020-06-04 13:50:44 -07:00
|
|
|
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray, array_types,
|
2020-02-05 10:10:33 -08:00
|
|
|
raise_to_shaped, abstract_token, canonicalize_shape)
|
2019-04-12 16:28:40 -07:00
|
|
|
from ..interpreters import partial_eval as pe
|
|
|
|
from ..interpreters import xla
|
2019-07-06 10:00:08 -07:00
|
|
|
from ..interpreters import pxla
|
2019-04-12 16:28:40 -07:00
|
|
|
from ..interpreters import ad
|
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible functions. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of an "invertible" function is one that produces a jaxpr
that can be inverted correctly if only we iterate over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
|
|
|
from ..interpreters import invertible_ad as iad
|
2019-04-12 16:28:40 -07:00
|
|
|
from ..interpreters import batching
|
2019-09-03 17:09:27 -07:00
|
|
|
from ..interpreters import masking
|
2020-06-04 13:50:44 -07:00
|
|
|
from ..util import cache, safe_zip, partial, prod, safe_map
|
|
|
|
from ..tree_util import tree_map
|
2019-10-09 15:05:54 -04:00
|
|
|
from ..lib import pytree
|
2019-04-12 16:28:40 -07:00
|
|
|
from ..lib import xla_bridge
|
2019-07-29 15:06:05 -04:00
|
|
|
from ..lib import xla_client
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-23 18:30:47 -04:00
|
|
|
# Short module-level aliases for the XLA bridge module, the XLA client module,
# and the client's `ops` attribute.
xb = xla_bridge
xc = xla_client
xops = xla_client.ops
|
|
|
|
|
2019-01-28 15:10:58 -05:00
|
|
|
# Module-level shorthand for the `FLAGS` object exposed by `..config.flags`.
FLAGS = flags.FLAGS
|
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
# Keep handles on the Python builtins: this module defines its own elementwise
# `max` and `min` functions further down, which shadow the builtin names.
_max = builtins.max
# Must be the builtin *min*; binding it to `builtins.max` silently turns every
# internal `_min(...)` call into a maximum.
_min = builtins.min
# Private alias for `functools.reduce`.
_reduce = functools.reduce
|
2018-12-13 07:24:14 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
# Loose type aliases used by the annotations in this module. `Array` and
# `DType` are plain `Any`, so they act as documentation only and are not
# checked; `Shape` is a sequence of integer dimension sizes.
Array = Any
DType = Any
Shape = Sequence[int]
|
2019-02-01 11:07:45 -05:00
|
|
|
|
2020-05-28 00:15:01 +02:00
|
|
|
def _try_broadcast_shapes(shapes):
  """Attempts to broadcast the rows of `shapes` against each other.

  `shapes` is compared elementwise and reduced along axis 0, i.e. it is
  treated as one shape per row (`broadcast_shapes` passes a NumPy array of
  shapes left-padded with 1s to a common rank).

  Returns:
    The canonicalized broadcast shape, or `None` when some row has a
    dimension that is neither 1 nor equal to the candidate result.
  """
  unit = shapes == 1
  # Replace 1 with 0 to avoid inconclusive comparisons for polymorphic dims:
  candidate = np.max(np.where(unit, 0, shapes), axis=0)
  # A column made up solely of 1s would otherwise end up as 0; restore it.
  candidate = np.where(np.all(unit, axis=0), 1, candidate)
  compatible = np.all((shapes == candidate) | unit)
  if not compatible:
    return None
  return canonicalize_shape(candidate)
|
|
|
|
|
2019-08-09 13:12:44 -04:00
|
|
|
@cache()
def broadcast_shapes(*shapes):
  """Computes the result shape of NumPy-style broadcasting of `shapes`.

  A single shape is returned unchanged. Otherwise every shape is left-padded
  with 1s to the largest rank before the shapes are combined.

  Raises:
    ValueError: if the shapes cannot be broadcast together.
  """
  if len(shapes) == 1:
    return shapes[0]
  rank = _max(len(s) for s in shapes)
  padded = np.array([(1,) * (rank - len(s)) + s for s in shapes])
  broadcasted = _try_broadcast_shapes(padded)
  if broadcasted is None:
    raise ValueError("Incompatible shapes for broadcasting: {}"
                     .format(tuple(map(tuple, padded))))
  return broadcasted
|
2019-02-01 11:07:45 -05:00
|
|
|
|
2019-05-29 10:39:51 -07:00
|
|
|
def _identity(x):
  """Returns `x` unchanged."""
  return x
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
### traceables
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def neg(x: Array) -> Array:
  r"""Negates each element: :math:`-x`."""
  return neg_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def sign(x: Array) -> Array:
  r"""Computes the sign of each element.

  Floating-point inputs are mapped to
  :math:`\mathrm{sign}(x) = \begin{cases}
  -1 & x < 0\\
  -0 & x = -0\\
  \mathit{NaN} & x = \mathit{NaN}\\
  +0 & x = +0\\
  1 & x > 0
  \end{cases}`

  Signed integer inputs are mapped to
  :math:`\mathrm{sign}(x) = \begin{cases}
  -1 & x < 0\\
  0 & x = 0\\
  1 & x > 0
  \end{cases}`

  Complex inputs are mapped to their complex phase, i.e.
  :math:`\mathrm{sign}(x) = \frac{x}{|x|}`.
  """
  return sign_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def nextafter(x1: Array, x2: Array) -> Array:
  r"""Returns the next representable value after `x1`, moving towards `x2`.

  Each operand is passed through `_brcast` against the other before the
  primitive is bound.
  """
  start = _brcast(x1, x2)
  toward = _brcast(x2, x1)
  return nextafter_p.bind(start, toward)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def floor(x: Array) -> Array:
  r"""Takes the floor of each element: :math:`\left\lfloor x \right\rfloor`."""
  return floor_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def ceil(x: Array) -> Array:
  r"""Takes the ceiling of each element: :math:`\left\lceil x \right\rceil`."""
  return ceil_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def round(x: Array) -> Array:
  r"""Rounds each element to the nearest integer.

  Halfway values (e.g., `0.5`) are rounded away from zero.
  """
  return round_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def is_finite(x: Array) -> Array:
  r"""Tests each element for finiteness: :math:`\mathrm{isfinite}`.

  The result is `True` for an element x if and only if x is neither
  :math:`\pm\infty` nor :math:`\mathit{NaN}`.
  """
  return is_finite_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def exp(x: Array) -> Array:
  r"""Computes the exponential of each element: :math:`e^x`."""
  return exp_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def expm1(x: Array) -> Array:
  r"""Computes :math:`e^{x} - 1` for each element."""
  return expm1_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def log(x: Array) -> Array:
  r"""Computes the natural logarithm of each element: :math:`\mathrm{log}(x)`."""
  return log_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def log1p(x: Array) -> Array:
  r"""Computes :math:`\mathrm{log}(1 + x)` for each element."""
  return log1p_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def tanh(x: Array) -> Array:
  r"""Computes the hyperbolic tangent of each element: :math:`\mathrm{tanh}(x)`."""
  return tanh_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def sin(x: Array) -> Array:
  r"""Computes the sine of each element: :math:`\mathrm{sin}(x)`."""
  return sin_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def cos(x: Array) -> Array:
  r"""Computes the cosine of each element: :math:`\mathrm{cos}(x)`."""
  return cos_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def atan2(x: Array, y: Array) -> Array:
  r"""Computes the two-argument arc tangent elementwise:
  :math:`\mathrm{atan}({x \over y})`."""
  return atan2_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def betainc(a: Array, b: Array, x: Array) -> Array:
  r"""Computes the regularized incomplete beta integral elementwise."""
  return regularized_incomplete_beta_p.bind(a, b, x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def lgamma(x: Array) -> Array:
  r"""Computes the log gamma function elementwise: :math:`\mathrm{log}(\Gamma(x))`."""
  return lgamma_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def digamma(x: Array) -> Array:
  r"""Computes the digamma function elementwise: :math:`\psi(x)`."""
  return digamma_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def igamma(a: Array, x: Array) -> Array:
  r"""Computes the regularized incomplete gamma function elementwise."""
  return igamma_p.bind(a, x)
|
2020-01-29 08:25:21 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def igammac(a: Array, x: Array) -> Array:
  r"""Computes the complementary regularized incomplete gamma function elementwise."""
  return igammac_p.bind(a, x)
|
2020-01-29 08:25:21 -08:00
|
|
|
|
2020-05-05 17:10:31 -07:00
|
|
|
def igamma_grad_a(a: Array, x: Array) -> Array:
  r"""Computes the derivative of the regularized incomplete gamma function elementwise."""
  return igamma_grad_a_p.bind(a, x)
|
2020-05-05 17:10:31 -07:00
|
|
|
|
2020-06-19 06:34:18 -07:00
|
|
|
def random_gamma_grad(a: Array, x: Array) -> Array:
  r"""Computes the derivative of samples from `Gamma(a, 1)` elementwise."""
  return random_gamma_grad_p.bind(a, x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def bessel_i0e(x: Array) -> Array:
  r"""Computes the exponentially scaled modified Bessel function of order 0:
  :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)`
  """
  return bessel_i0e_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def bessel_i1e(x: Array) -> Array:
  r"""Computes the exponentially scaled modified Bessel function of order 1:
  :math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)`
  """
  return bessel_i1e_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def erf(x: Array) -> Array:
  r"""Computes the error function elementwise: :math:`\mathrm{erf}(x)`."""
  return erf_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def erfc(x: Array) -> Array:
  r"""Computes the complementary error function elementwise:
  :math:`\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)`."""
  return erfc_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def erf_inv(x: Array) -> Array:
  r"""Computes the inverse error function elementwise: :math:`\mathrm{erf}^{-1}(x)`."""
  return erf_inv_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def real(x: Array) -> Array:
  r"""Extracts the real part of each element: :math:`\mathrm{Re}(x)`.

  Given a complex number, returns its real part.
  """
  return real_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def imag(x: Array) -> Array:
  r"""Extracts the imaginary part of each element: :math:`\mathrm{Im}(x)`.

  Given a complex number, returns its imaginary part.
  """
  return imag_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def complex(x: Array, y: Array) -> Array:
  r"""Assembles a complex number from each element pair: :math:`x + jy`.

  `x` supplies the real parts and `y` the imaginary parts. Each operand is
  passed through `_brcast` against the other before the primitive is bound.
  """
  real_part = _brcast(x, y)
  imag_part = _brcast(y, x)
  return complex_p.bind(real_part, imag_part)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def conj(x: Array) -> Array:
  r"""Computes the complex conjugate of each element: :math:`\overline{x}`."""
  # The dtype of the input is recorded as a parameter of the primitive.
  return conj_p.bind(x, input_dtype=_dtype(x))
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def abs(x: Array) -> Array:
  r"""Computes the absolute value of each element: :math:`|x|`."""
  return abs_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def pow(x: Array, y: Array) -> Array:
  r"""Raises each element of `x` to the matching power in `y`: :math:`x^y`."""
  return pow_p.bind(x, y)
|
|
|
|
|
2020-05-18 17:54:20 -04:00
|
|
|
def integer_pow(x: Array, y: int) -> Array:
  r"""Raises each element to a fixed integer power: :math:`x^y`.

  Two exponents are handled without binding the primitive: `y == 0` yields
  `_ones(x)` and `y == 1` yields `x` itself.
  """
  if y == 0:
    return _ones(x)
  if y == 1:
    return x
  return integer_pow_p.bind(x, y=y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def sqrt(x: Array) -> Array:
  r"""Computes the square root of each element: :math:`\sqrt{x}`."""
  return sqrt_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def rsqrt(x: Array) -> Array:
  r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
  # Docstring fix only: the `:math:` role was missing its closing backtick, so
  # the formula was not valid reST markup.
  return rsqrt_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def bitwise_not(x: Array) -> Array:
  r"""Computes the elementwise NOT: :math:`\neg x`."""
  return not_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def bitwise_and(x: Array, y: Array) -> Array:
  r"""Computes the elementwise AND: :math:`x \wedge y`."""
  return and_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def bitwise_or(x: Array, y: Array) -> Array:
  r"""Computes the elementwise OR: :math:`x \vee y`."""
  return or_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def bitwise_xor(x: Array, y: Array) -> Array:
  r"""Computes the elementwise exclusive OR: :math:`x \oplus y`."""
  return xor_p.bind(x, y)
|
|
|
|
|
2020-04-28 06:32:52 +01:00
|
|
|
def population_count(x: Array) -> Array:
  r"""Computes the elementwise popcount: the number of set bits in each element."""
  return population_count_p.bind(x)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def add(x: Array, y: Array) -> Array:
  r"""Adds the operands elementwise: :math:`x + y`."""
  return add_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def sub(x: Array, y: Array) -> Array:
  r"""Subtracts `y` from `x` elementwise: :math:`x - y`."""
  return sub_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def mul(x: Array, y: Array) -> Array:
  r"""Multiplies the operands elementwise: :math:`x \times y`."""
  return mul_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def div(x: Array, y: Array) -> Array:
  r"""Divides `x` by `y` elementwise: :math:`x \over y`."""
  return div_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def rem(x: Array, y: Array) -> Array:
  r"""Computes the elementwise remainder: :math:`x \bmod y`."""
  return rem_p.bind(x, y)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def max(x: Array, y: Array) -> Array:
  r"""Takes the elementwise maximum: :math:`\mathrm{max}(x, y)`

  Complex numbers are compared lexicographically on their
  `(real, imaginary)` pairs."""
  return max_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def min(x: Array, y: Array) -> Array:
  r"""Takes the elementwise minimum: :math:`\mathrm{min}(x, y)`

  Complex numbers are compared lexicographically on their
  `(real, imaginary)` pairs."""
  return min_p.bind(x, y)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def shift_left(x: Array, y: Array) -> Array:
  r"""Shifts each element of `x` left by `y`: :math:`x \ll y`."""
  return shift_left_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def shift_right_arithmetic(x: Array, y: Array) -> Array:
  r"""Applies an arithmetic right shift elementwise: :math:`x \gg y`."""
  return shift_right_arithmetic_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def shift_right_logical(x: Array, y: Array) -> Array:
  r"""Applies a logical right shift elementwise: :math:`x \gg y`."""
  return shift_right_logical_p.bind(x, y)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def eq(x: Array, y: Array) -> Array:
  r"""Compares the operands elementwise for equality: :math:`x = y`."""
  return eq_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def ne(x: Array, y: Array) -> Array:
  r"""Compares the operands elementwise for inequality: :math:`x \neq y`."""
  return ne_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def ge(x: Array, y: Array) -> Array:
  r"""Compares elementwise for greater-than-or-equals: :math:`x \geq y`."""
  return ge_p.bind(x, y)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def gt(x: Array, y: Array) -> Array:
  r"""Compares elementwise for greater-than: :math:`x > y`."""
  return gt_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def le(x: Array, y: Array) -> Array:
  r"""Compares elementwise for less-than-or-equals: :math:`x \leq y`."""
  return le_p.bind(x, y)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def lt(x: Array, y: Array) -> Array:
  r"""Compares elementwise for less-than: :math:`x < y`."""
  return lt_p.bind(x, y)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def convert_element_type(operand: Array, new_dtype: DType) -> Array:
  """Casts every element of `operand` to `new_dtype`.

  Wraps XLA's `ConvertElementType
  <https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
  operator, which converts elementwise from one type to another, much like a
  C++ `static_cast`.

  A `np.ComplexWarning` is issued when a complex input is cast to a
  non-complex type, since the imaginary part is discarded. No conversion is
  bound when the canonicalized source and target dtypes already agree.

  Args:
    operand: the array or scalar value to cast.
    new_dtype: the target type. Should be a NumPy type.

  Returns:
    An array shaped like `operand`, with each element cast to `new_dtype`.
  """
  new_dtype = dtypes.canonicalize_dtype(new_dtype)
  # Python scalars are converted to `new_dtype` right away so no precision is
  # lost: handed directly to `bind` below, they would first be cast to the
  # default JAX type as part of the calling convention.
  if type(operand) in dtypes.python_scalar_dtypes:
    operand = np.asarray(operand, new_dtype)
  source_dtype = dtypes.canonicalize_dtype(_dtype(operand))
  if source_dtype == new_dtype:
    return operand
  drops_imaginary = (
      dtypes.issubdtype(source_dtype, np.complexfloating) and
      not dtypes.issubdtype(new_dtype, np.complexfloating))
  if drops_imaginary:
    warnings.warn("Casting complex values to real discards the imaginary part",
                  np.ComplexWarning, stacklevel=2)
  return convert_element_type_p.bind(
      operand, new_dtype=new_dtype, old_dtype=source_dtype)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def bitcast_convert_type(operand: Array, new_dtype: DType) -> Array:
  """Reinterprets the bits of every element of `operand` as `new_dtype`.

  Wraps XLA's `BitcastConvertType
  <https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
  operator, which performs a bit cast from one type to another. The source
  and destination types must have the same bitwidth.

  Args:
    operand: the array or scalar value to cast.
    new_dtype: the target type. Should be a NumPy type.

  Returns:
    An array shaped like `operand`, bitcast elementwise to `new_dtype`.
    `operand` itself is returned when its dtype already equals the
    canonicalized `new_dtype`.
  """
  new_dtype = dtypes.canonicalize_dtype(new_dtype)
  if _dtype(operand) == new_dtype:
    return operand
  return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def clamp(min: Array, x: Array, max: Array) -> Array:
  r"""Clamps each element of `x` between `min` and `max`.

  Returns :math:`\mathrm{clamp}(x) = \begin{cases}
  \mathit{min} & \text{if } x < \mathit{min},\\
  \mathit{max} & \text{if } x > \mathit{max},\\
  x & \text{otherwise}
  \end{cases}`.
  """
  return clamp_p.bind(min, x, max)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def concatenate(operands: Sequence[Array], dimension: int) -> Array:
  """Joins a sequence of arrays along `dimension`.

  Wraps XLA's `Concatenate
  <https://www.tensorflow.org/xla/operation_semantics#concatenate>`_
  operator.

  Args:
    operands: the arrays to concatenate, in order. Their shapes must agree
      in every axis other than `dimension`.
    dimension: the axis along which the arrays are joined.

  Returns:
    An array containing the concatenation.
  """
  return concatenate_p.bind(*operands, dimension=dimension)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-23 21:45:41 -04:00
|
|
|
# Expose the XLA client's precision enum under a shorter name.
Precision = xla_client.PrecisionConfig.Precision
# Make str(precision) yield just the member name. This is assigned onto the
# class itself, so it applies wherever that class is used.
Precision.__str__ = lambda precision: precision.name
# Annotation used for `precision` arguments; currently unconstrained.
PrecisionType = Any
|
2019-06-28 09:00:32 -04:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
class ConvDimensionNumbers(NamedTuple):
  """Identifies the batch, feature, and spatial dimensions of a convolution.

  Every field is a tuple of nonnegative integer dimension numbers.

  Args:
    lhs_spec: dimension numbers laid out as
      `(batch dimension, feature dimension, spatial dimensions...)`.
    rhs_spec: dimension numbers laid out as
      `(out feature dimension, in feature dimension, spatial dimensions...)`.
    out_spec: dimension numbers laid out as
      `(batch dimension, feature dimension, spatial dimensions...)`.
  """
  lhs_spec: Sequence[int]
  rhs_spec: Sequence[int]
  out_spec: Sequence[int]
|
|
|
|
|
|
|
|
# Accepted forms of the `dimension_numbers` argument of
# `conv_general_dilated`: `None`, a `ConvDimensionNumbers`, or a 3-tuple of
# strings.
ConvGeneralDilatedDimensionNumbers = Union[
  None, ConvDimensionNumbers, Tuple[str, str, str]]
|
|
|
|
|
|
|
|
def conv_general_dilated(
  lhs: Array, rhs: Array, window_strides: Sequence[int],
  padding: Union[str, Sequence[Tuple[int, int]]],
  lhs_dilation: Optional[Sequence[int]] = None,
  rhs_dilation: Optional[Sequence[int]] = None,
  dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
  feature_group_count: int = 1, batch_group_count: int = 1,
  precision: Optional[PrecisionType] = None) -> Array:
  """Applies a general n-dimensional convolution, optionally dilated.

  Wraps XLA's `Conv
  <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
  operator.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers giving the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
      of `n` `(low, high)` integer pairs giving the padding to apply before
      and after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers giving the dilation
      factor to apply in each spatial dimension of `lhs`. LHS dilation is
      also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers giving the dilation
      factor to apply in each spatial dimension of `rhs`. RHS dilation is
      also known as atrous convolution.
    dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or a
      3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a
      string of length `n+2`.
    feature_group_count: integer, default 1. See XLA HLO docs.
    batch_group_count: integer, default 1. See XLA HLO docs.
    precision: Optional. Either ``None``, which means the default precision
      for the backend, or a ``lax.Precision`` enum value
      (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    An array containing the convolution result.

  Raises:
    ValueError: if `padding` is a string while `lhs_dilation` contains a
      factor other than 1.

  When `dimension_numbers` is given as strings, each character identifies by
  position:

  - the batch dimensions in `lhs`, `rhs`, and the output with the character
    'N',
  - the feature dimensions in `lhs` and the output with the character 'C',
  - the input and output feature dimensions in rhs with the characters 'I'
    and 'O' respectively, and
  - spatial dimension correspondences between lhs, rhs, and the output using
    any distinct characters.

  For example, to indicate dimension numbers consistent with the `conv`
  function with two spatial dimensions, one could use
  `('NCHW', 'OIHW', 'NCHW')`. As another example, to indicate dimension
  numbers consistent with the TensorFlow Conv2D operation, one could use
  `('NHWC', 'HWIO', 'NHWC')`. When using the latter form of convolution
  dimension specification, window strides are associated with spatial
  dimension character labels according to the order in which the labels
  appear in the `rhs_spec` string, so that `window_strides[0]` is matched
  with the dimension corresponding to the first character appearing in
  rhs_spec that is not `'I'` or `'O'`.

  If `dimension_numbers` is `None`, the default is
  `('NCHW', 'OIHW', 'NCHW')` (for a 2D convolution).
  """
  dnums: ConvDimensionNumbers
  dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  string_padding = isinstance(padding, str)

  if lhs_dilation is None:
    lhs_dilation = (1,) * (lhs.ndim - 2)
  elif string_padding and len(lhs_dilation) != lhs_dilation.count(1):
    # A non-unit lhs dilation combined with 'SAME'/'VALID' is rejected.
    raise ValueError(
        "String padding is not implemented for transposed convolution "
        "using this op. Please either exactly specify the required padding or "
        "use conv_transpose.")

  if rhs_dilation is None:
    rhs_dilation = (1,) * (rhs.ndim - 2)

  if string_padding:
    # Resolve the padding name into explicit pads, using the spatial extents
    # of both operands (everything after the first two permuted dimensions)
    # and the kernel extent after rhs dilation.
    lhs_perm, rhs_perm, _ = dnums
    kernel_spatial = np.take(rhs.shape, rhs_perm)[2:]
    dilated_kernel = [(k - 1) * r + 1
                      for k, r in zip(kernel_spatial, rhs_dilation)]
    input_spatial = np.take(lhs.shape, lhs_perm)[2:]
    padding = padtype_to_pads(
        input_spatial, dilated_kernel, window_strides, padding)

  return conv_general_dilated_p.bind(
      lhs, rhs,
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      lhs_dilation=tuple(lhs_dilation),
      rhs_dilation=tuple(rhs_dilation),
      dimension_numbers=dnums,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      lhs_shape=lhs.shape,
      rhs_shape=rhs.shape,
      precision=_canonicalize_precision(precision))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def dot(lhs: Array, rhs: Array, precision: Optional[PrecisionType] = None) -> Array:
  """Vector/vector, matrix/vector, and matrix/matrix multiplication.

  Wraps XLA's `Dot
  <https://www.tensorflow.org/xla/operation_semantics#dot>`_
  operator.

  For more general contraction, see the `dot_general` operator.

  Args:
    lhs: an array of rank 1 or 2.
    rhs: an array of rank 1 or 2.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    An array containing the product.

  Raises:
    TypeError: if either operand does not have rank 1 or 2, or if the size of
      the last dimension of ``lhs`` differs from the size of the first
      dimension of ``rhs``.
  """
  ranks_ok = 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2
  # The shapes are only indexed once both ranks are known to be valid.
  if not (ranks_ok and lhs.shape[-1] == rhs.shape[0]):
    raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        lhs.shape, rhs.shape))
  # Contract the last axis of lhs with the first axis of rhs; no batch axes.
  contracting_dims = ((lhs.ndim - 1,), (0,))
  return dot_general(lhs, rhs, (contracting_dims, ((), ())),
                     precision=precision)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
|
|
|
|
# Type of the `dimension_numbers` argument of `dot_general`:
# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
DotDimensionNumbers = Tuple[
    Tuple[Sequence[int], Sequence[int]],
    Tuple[Sequence[int], Sequence[int]],
]
|
|
|
|
|
|
|
|
def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
                precision: Optional[PrecisionType] = None) -> Array:
  """More general contraction operator.

  Wraps XLA's `DotGeneral
  <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
  operator.

  Args:
    lhs: an array
    rhs: an array
    dimension_numbers: a tuple of tuples of the form
      `((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims))`
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    An array containing the result.
  """
  contracting, batch = dimension_numbers
  # Freeze every inner sequence into a tuple before handing the dimension
  # numbers to the primitive as a parameter.
  canonical_dnums = (tuple(tuple(dims) for dims in contracting),
                     tuple(tuple(dims) for dims in batch))
  return dot_general_p.bind(lhs, rhs,
                            dimension_numbers=canonical_dnums,
                            precision=_canonicalize_precision(precision))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def broadcast(operand: Array, sizes: Sequence[int]) -> Array:
  """Broadcasts an array, adding new major dimensions.

  Wraps XLA's `Broadcast
  <https://www.tensorflow.org/xla/operation_semantics#broadcast>`_
  operator.

  Args:
    operand: an array
    sizes: a sequence of integers, giving the sizes of new major dimensions
      to add.

  Returns:
    An array containing the result.
  """
  num_new_dims = len(sizes)
  # The operand's own axes land directly after the newly added leading axes.
  operand_dims = tuple(range(num_new_dims, num_new_dims + np.ndim(operand)))
  result_shape = tuple(sizes) + np.shape(operand)
  return broadcast_in_dim(operand, result_shape, operand_dims)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def broadcast_in_dim(operand: Array, shape: Shape,
                     broadcast_dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `BroadcastInDim
  <https://www.tensorflow.org/xla/operation_semantics#broadcastindim>`_
  operator.

  The requested ``shape`` is first passed through the primitive's shape rule.
  When the operand's rank already equals the rank of the resulting shape and
  ``broadcast_dimensions`` is empty, the operand is returned as-is and no
  primitive is bound.

  Args:
    operand: an array
    shape: the shape of the result.
    broadcast_dimensions: a sequence of integers, interpreted as in XLA's
      BroadcastInDim.

  Returns:
    An array containing the result.
  """
  shape = _broadcast_in_dim_shape_rule(
      operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
  is_noop = (np.ndim(operand) == len(shape)
             and len(broadcast_dimensions) == 0)
  if is_noop:
    return operand
  return broadcast_in_dim_p.bind(
      operand,
      shape=tuple(shape),
      broadcast_dimensions=tuple(broadcast_dimensions))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-05-01 18:00:38 +01:00
|
|
|
def broadcast_to_rank(x: Array, rank: int) -> Array:
  """Prepends size-``1`` dimensions to ``x`` until it has rank ``rank``."""
  num_missing = rank - x.ndim
  return broadcast(x, (1,) * num_missing)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def reshape(operand: Array, new_sizes: Shape,
            dimensions: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `Reshape
  <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
  operator.

  For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
  ``lax.expand_dims``. These preserve information about axis identity that may
  be useful for advanced transformation rules.

  An operand of rank >= 1 whose shape already equals ``new_sizes`` is returned
  unchanged, provided ``dimensions`` is ``None`` or the identity ordering
  ``range(ndim(operand))``.

  Args:
    operand: an array
    new_sizes: the shape of the result.
    dimensions: Optional. A sequence of integers, interpreted as in XLA's
      Reshape. ``None`` and the identity ordering are both forwarded to the
      primitive as ``None``.

  Returns:
    An array containing the result.
  """
  new_sizes = tuple(canonicalize_shape(new_sizes))  # TODO
  operand_shape = np.shape(operand)
  same_dims = (dimensions is None
               or tuple(dimensions) == tuple(range(np.ndim(operand))))
  if operand_shape and operand_shape == new_sizes and same_dims:
    return operand
  # `same_dims` is already true whenever `dimensions` is None.
  return reshape_p.bind(
      operand, new_sizes=new_sizes,
      dimensions=None if same_dims else tuple(dimensions))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def pad(operand: Array, padding_value: Array,
        padding_config: Sequence[Tuple[int, int, int]]) -> Array:
  """Wraps XLA's `Pad
  <https://www.tensorflow.org/xla/operation_semantics#pad>`_
  operator.

  Args:
    operand: the array to pad.
    padding_value: the value used for the padding.
    padding_config: a sequence of integer triples, interpreted as in XLA's
      Pad. It is converted to a tuple before being bound.

  Returns:
    An array containing the result.
  """
  config = tuple(padding_config)
  return pad_p.bind(operand, padding_value, padding_config=config)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def rev(operand: Array, dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `Rev
  <https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_
  operator.

  Args:
    operand: an array
    dimensions: a sequence of integers, interpreted as in XLA's Rev. It is
      converted to a tuple before being bound.

  Returns:
    An array containing the result.
  """
  dims = tuple(dimensions)
  return rev_p.bind(operand, dimensions=dims)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def select(pred: Array, on_true: Array, on_false: Array) -> Array:
  """Wraps XLA's `Select
  <https://www.tensorflow.org/xla/operation_semantics#select>`_
  operator.

  Args:
    pred: the predicate array.
    on_true: the array passed to Select as the "true" operand.
    on_false: the array passed to Select as the "false" operand.

  Returns:
    An array containing the result.
  """
  selected = select_p.bind(pred, on_true, on_false)
  return selected
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `Slice
  <https://www.tensorflow.org/xla/operation_semantics#slice>`_
  operator.

  When every start index is 0, the limit indices equal ``operand.shape`` and
  no strides are given, the operand itself is returned and no primitive is
  bound.

  Args:
    operand: an array to slice.
    start_indices: a sequence of integers, interpreted as in XLA's Slice.
    limit_indices: a sequence of integers, interpreted as in XLA's Slice.
    strides: Optional. A sequence of integers, interpreted as in XLA's Slice.

  Returns:
    An array containing the slice.
  """
  # Evaluated left to right with short-circuiting, so the limit comparison is
  # skipped whenever some start index is non-zero.
  selects_everything = (np.all(np.equal(start_indices, 0))
                        and np.all(np.equal(limit_indices, operand.shape))
                        and strides is None)
  if selects_everything:
    return operand
  return slice_p.bind(operand,
                      start_indices=tuple(start_indices),
                      limit_indices=tuple(limit_indices),
                      strides=None if strides is None else tuple(strides))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def dynamic_slice(operand: Array, start_indices: Sequence[Array],
                  slice_sizes: Shape) -> Array:
  """Wraps XLA's `DynamicSlice
  <https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
  operator.

  Args:
    operand: an array to slice.
    start_indices: a list of scalar indices, one per dimension. These values
      may be dynamic.
    slice_sizes: the size of the slice. Must be a sequence of non-negative
      integers with length equal to `ndim(operand)`. Inside a JIT compiled
      function, only static values are supported (all JAX arrays inside JIT
      must have statically known size).

  Returns:
    An array containing the slice.
  """
  # Normalise the indices first; they are then passed to the primitive as
  # separate positional operands.
  index_operands = _dynamic_slice_indices(operand, start_indices)
  return dynamic_slice_p.bind(operand, *index_operands,
                              slice_sizes=tuple(slice_sizes))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def dynamic_update_slice(operand: Array, update: Array,
                         start_indices: Array) -> Array:
  """Wraps XLA's `DynamicUpdateSlice
  <https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
  operator.

  Args:
    operand: an array to update.
    update: an array containing the new values to write onto `operand`.
    start_indices: a list of scalar indices, one per dimension.

  Returns:
    An array containing the values of `operand` with `update` written onto
    it starting at `start_indices`.
  """
  # Normalise the indices first; they are then passed to the primitive as
  # separate positional operands.
  index_operands = _dynamic_slice_indices(operand, start_indices)
  return dynamic_update_slice_p.bind(operand, update, *index_operands)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
|
|
|
|
class GatherDimensionNumbers(NamedTuple):
  """
  Dimension-number arguments for `XLA's Gather operator
  <https://www.tensorflow.org/xla/operation_semantics#gather>`_. The XLA
  documentation explains in detail what each dimension number means.

  Args:
    offset_dims: the set of dimensions in the `gather` output that offset into
      an array sliced from `operand`. Must be a tuple of integers in ascending
      order, each representing a dimension number of the output.
    collapsed_slice_dims: the set of dimensions `i` in `operand` that have
      `slice_sizes[i] == 1` and that should not have a corresponding dimension
      in the output of the gather. Must be a tuple of integers in ascending
      order.
    start_index_map: for each dimension in `start_indices`, gives the
      corresponding dimension in `operand` that is to be sliced. Must be a
      tuple of integers with size equal to `start_indices.shape[-1]`.

  In contrast to XLA's `GatherDimensionNumbers` structure, there is no explicit
  `index_vector_dim`: an index vector dimension always exists and is always the
  last dimension. To gather scalar indices, add a trailing dimension of size 1.
  """
  offset_dims: Sequence[int]
  collapsed_slice_dims: Sequence[int]
  start_index_map: Sequence[int]
|
|
|
|
|
|
|
|
|
|
|
|
def gather(operand: Array, start_indices: Array,
           dimension_numbers: GatherDimensionNumbers,
           slice_sizes: Shape) -> Array:
  """Gather operator.

  Wraps `XLA's Gather operator
  <https://www.tensorflow.org/xla/operation_semantics#gather>`_.

  The semantics of gather are complicated, and its API might change in the
  future. For most use cases, you should prefer `Numpy-style indexing
  <https://docs.scipy.org/doc/numpy-1.16.0/reference/arrays.indexing.html>`_
  (e.g., `x[:, (1,4,7), ...]`), rather than using `gather` directly.

  Args:
    operand: an array from which slices should be taken
    start_indices: the indices at which slices should be taken
    dimension_numbers: a `lax.GatherDimensionNumbers` object that describes
      how dimensions of `operand`, `start_indices` and the output relate.
    slice_sizes: the size of each slice. Must be a sequence of non-negative
      integers with length equal to `ndim(operand)`.

  Returns:
    An array containing the gather output.
  """
  canonical_sizes = canonicalize_shape(slice_sizes)
  return gather_p.bind(
      operand, start_indices,
      dimension_numbers=dimension_numbers,
      slice_sizes=canonical_sizes)
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
|
|
|
|
class ScatterDimensionNumbers(NamedTuple):
  """
  Dimension-number arguments for `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_. The XLA
  documentation explains in detail what each dimension number means.

  Args:
    update_window_dims: the set of dimensions in the `updates` that are window
      dimensions. Must be a tuple of integers in ascending
      order, each representing a dimension number.
    inserted_window_dims: the set of size 1 window dimensions that must be
      inserted into the shape of `updates`. Must be a tuple of integers in
      ascending order, each representing a dimension number of the output.
      These are the mirror image of `collapsed_slice_dims` in the case of
      `gather`.
    scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives
      the corresponding dimension in `operand`. Must be a sequence of integers
      with size equal to indices.shape[-1].

  In contrast to XLA's `ScatterDimensionNumbers` structure, there is no explicit
  `index_vector_dim`: an index vector dimension always exists and is always the
  last dimension. To scatter scalar indices, add a trailing dimension of size 1.
  """
  update_window_dims: Sequence[int]
  inserted_window_dims: Sequence[int]
  scatter_dims_to_operand_dims: Sequence[int]
|
|
|
|
|
|
|
|
def scatter_add(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-add operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  addition is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether `scatter_indices` is known to be free of duplicates.
      If true, may improve performance on some backends.

  Returns:
    An array containing the sum of `operand` and the scattered updates.
  """
  # Trace the combiner on an abstract scalar built from `operand`.
  scalar_aval = _abstractify(_const(operand, 0))
  update_jaxpr, update_consts = _reduction_jaxpr(add, scalar_aval)
  return scatter_add_p.bind(
      operand, scatter_indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
|
2019-03-01 15:41:49 -05:00
|
|
|
|
2020-04-13 16:16:34 -04:00
|
|
|
def scatter_mul(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-multiply operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  multiplication is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether `scatter_indices` is known to be free of duplicates.
      If true, may improve performance on some backends.

  Returns:
    An array containing the product of `operand` and the scattered updates.
  """
  # Trace the combiner on an abstract scalar built from `operand`.
  scalar_aval = _abstractify(_const(operand, 1))
  update_jaxpr, update_consts = _reduction_jaxpr(mul, scalar_aval)
  return scatter_mul_p.bind(
      operand, scatter_indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
|
2020-04-13 16:16:34 -04:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def scatter_min(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-min operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  the `min` function is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether `scatter_indices` is known to be free of duplicates.
      If true, may improve performance on some backends.

  Returns:
    An array containing the minimum of `operand` and the scattered updates.
  """
  # Trace the combiner on an abstract scalar built from `operand`.
  scalar_aval = _abstractify(_const(operand, 0))
  update_jaxpr, update_consts = _reduction_jaxpr(min, scalar_aval)
  return scatter_min_p.bind(
      operand, scatter_indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
|
2019-06-21 19:31:41 -07:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def scatter_max(operand: Array, scatter_indices: Array, updates: Array,
                dimension_numbers: ScatterDimensionNumbers, *,
                indices_are_sorted: bool = False,
                unique_indices: bool = False) -> Array:
  """Scatter-max operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
  the `max` function is used to combine updates and values from `operand`.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether `scatter_indices` is known to be free of duplicates.
      If true, may improve performance on some backends.

  Returns:
    An array containing the maximum of `operand` and the scattered updates.
  """
  # Trace the combiner on an abstract scalar built from `operand`.
  scalar_aval = _abstractify(_const(operand, 0))
  update_jaxpr, update_consts = _reduction_jaxpr(max, scalar_aval)
  return scatter_max_p.bind(
      operand, scatter_indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
|
2019-06-21 19:31:41 -07:00
|
|
|
|
2019-09-24 19:20:12 +02:00
|
|
|
# Define this outside of scatter to ensure cache hits.
|
|
|
|
_scatter_reduction_computation = lambda x, y: y  # Returns its second argument and ignores the first.
|
|
|
|
|
2020-07-21 23:16:27 -07:00
|
|
|
def scatter(operand: Array, scatter_indices: Array, updates: Array,
            dimension_numbers: ScatterDimensionNumbers, *,
            indices_are_sorted: bool = False,
            unique_indices: bool = False) -> Array:
  """Scatter-update operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where updates
  replace values from `operand`.

  If multiple updates are performed to the same index of operand, they may be
  applied in any order.

  The semantics of scatter are complicated and its API is subject to change.

  Args:
    operand: an array to which the scatter should be applied
    scatter_indices: an array that gives the indices in `operand` to which each
      update in `updates` should be applied.
    updates: the updates that should be scattered onto `operand`.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of `operand`, `scatter_indices`, `updates` and the output
      relate.
    indices_are_sorted: whether `scatter_indices` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether `scatter_indices` is known to be free of duplicates.
      If true, may improve performance on some backends.

  Returns:
    An array containing the values of `operand`, with the scattered updates
    replacing the values at the updated positions.
  """
  # The module-level combiner is reused on every call rather than building a
  # fresh lambda here, so that the reduction-jaxpr cache can hit.
  scalar_aval = _abstractify(_const(operand, 0))
  update_jaxpr, update_consts = _reduction_jaxpr(
      _scatter_reduction_computation, scalar_aval)
  return scatter_p.bind(
      operand, scatter_indices, updates,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
  """Gathers entries of ``src`` at coordinates given per-axis by ``idxs``.

  Each element of ``idxs`` is expanded with a new axis 1 and the results are
  concatenated along that axis, so row ``n`` of the combined index array holds
  one coordinate per entry of ``axes``. Coordinates are wrapped modulo the size
  of the corresponding ``src`` dimension before being handed to ``gather``.

  Args:
    src: the array to read from.
    idxs: an iterable of index arrays, one per entry of ``axes``.
    axes: the dimensions of ``src`` that the index arrays address. Any
      sequence is accepted; it is normalized to a tuple internally.

  Returns:
    The result of ``gather`` with slices of size 1 along every dimension in
    ``axes`` (those dimensions are collapsed) and full size along the others.
  """
  # Normalize once: ``axes`` is iterated several times below and is stored
  # inside the GatherDimensionNumbers tuple, which is only hashable if its
  # fields are themselves hashable (a list would not be).
  axes = tuple(axes)
  indices = concatenate([expand_dims(i, (1,)) for i in idxs], 1)
  # Wrap each coordinate into range for the dimension it addresses.
  indices = indices % np.array([src.shape[ax] for ax in axes])
  slice_sizes = list(src.shape)
  for ax in axes:
    slice_sizes[ax] = 1
  offset_dims = tuple(range(1, src.ndim - indices.shape[1] + 1))
  dnums = GatherDimensionNumbers(
      offset_dims=offset_dims,
      collapsed_slice_dims=axes,
      start_index_map=axes)
  return gather(src, indices, dimension_numbers=dnums,
                slice_sizes=tuple(slice_sizes))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def transpose(operand: Array, permutation: Sequence[int]) -> Array:
  """Wraps XLA's `Transpose
  <https://www.tensorflow.org/xla/operation_semantics#transpose>`_
  operator.

  Args:
    operand: the array whose axes are permuted.
    permutation: sequence of axis indices; converted to a tuple before use.

  Returns:
    ``operand`` itself when ``permutation`` is the identity ordering
    ``(0, 1, ..., len(permutation) - 1)``; otherwise the result of binding
    ``transpose_p``.
  """
  perm = tuple(permutation)
  is_identity = perm == tuple(range(len(perm)))
  # An identity permutation needs no primitive at all.
  return operand if is_identity else transpose_p.bind(operand, permutation=perm)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-01 11:01:22 -04:00
|
|
|
def argmin(operand: Array, axis: int,
           index_dtype: DType) -> Tuple[Array, Array]:
  """Computes the index of the minimum element along ``axis``.

  Args:
    operand: the array to search.
    axis: the axis to reduce over; forwarded to the primitive as the
      one-element tuple ``axes=(axis,)``.
    index_dtype: requested dtype of the result; canonicalized with
      ``dtypes.canonicalize_dtype`` before binding.

  Returns:
    The result of binding ``argmin_p``.
  """
  # NOTE(review): the return annotation advertises a pair of arrays while the
  # summary describes a single index result -- confirm against ``argmin_p``.
  canonical_index_dtype = dtypes.canonicalize_dtype(index_dtype)
  return argmin_p.bind(operand, axes=(axis,),
                       index_dtype=canonical_index_dtype)
|
|
|
|
|
|
|
|
def argmax(operand: Array, axis: int,
           index_dtype: DType) -> Tuple[Array, Array]:
  """Computes the index of the maximum element along ``axis``.

  Args:
    operand: the array to search.
    axis: the axis to reduce over; forwarded to the primitive as the
      one-element tuple ``axes=(axis,)``.
    index_dtype: requested dtype of the result; canonicalized with
      ``dtypes.canonicalize_dtype`` before binding.

  Returns:
    The result of binding ``argmax_p``.
  """
  # NOTE(review): the return annotation advertises a pair of arrays while the
  # summary describes a single index result -- confirm against ``argmax_p``.
  canonical_index_dtype = dtypes.canonicalize_dtype(index_dtype)
  return argmax_p.bind(operand, axes=(axis,),
                       index_dtype=canonical_index_dtype)
|
|
|
|
|
2020-06-09 13:09:50 -07:00
|
|
|
def reduce(operand: Array, init_value: Array, computation: Callable,
           dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `Reduce
  <https://www.tensorflow.org/xla/operation_semantics#reduce>`_
  operator.

  Args:
    operand: the array to reduce.
    init_value: the initial value of the reduction.
    computation: binary callable combining two values.
    dimensions: the dimensions of ``operand`` to reduce over.

  Returns:
    The reduced array. When ``_get_monoid_reducer`` recognizes the
    ``computation`` / ``init_value`` pair, the specialized reducer it returns
    is used; otherwise ``computation`` is traced to a jaxpr and ``reduce_p``
    is bound.
  """
  specialized = _get_monoid_reducer(computation, init_value)
  if specialized:
    return specialized(operand, dimensions)
  # General path: stage the user computation out to a jaxpr.
  jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value))
  return reduce_p.bind(operand, init_value, computation=computation,
                       jaxpr=jaxpr, consts=consts,
                       dimensions=tuple(dimensions))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
@cache()
def _reduction_jaxpr(computation, aval):
  """Traces the binary ``computation`` to a jaxpr.

  Both inputs are traced as unknown partial values with abstract value
  ``aval``. The computation's result is wrapped in a one-element tuple before
  tracing.

  Returns:
    A ``(jaxpr, consts)`` pair taken from ``pe.trace_to_jaxpr``.
  """
  unknown_input = pe.PartialVal.unknown(aval)
  wrapped = lu.wrap_init(lambda x, y: (computation(x, y),))
  in_pvals = (unknown_input, unknown_input)
  jaxpr, _, consts = pe.trace_to_jaxpr(wrapped, in_pvals, instantiate=False)
  return jaxpr, consts
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _get_monoid_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
  """Picks a specialized reducer for a recognized (operation, identity) pair.

  Args:
    monoid_op: the binary reduction callable supplied by the user; compared
      by identity against ``add``, ``mul``, ``bitwise_or``, ``bitwise_and``,
      ``max`` and ``min``.
    x: the initial value of the reduction.

  Returns:
    The matching ``_reduce_*`` function when ``x`` is a concrete scalar equal
    to the identity element of ``monoid_op`` (``bitwise_or`` / ``bitwise_and``
    are only recognized for boolean dtype); ``None`` in every other case.
  """
  aval = core.get_aval(x)
  dtype = _dtype(x)
  # Only a concrete scalar can be compared against the identity at trace time.
  if type(aval) is not ConcreteArray or aval.shape != ():
    return None

  if monoid_op is add:
    identity, reducer = 0, _reduce_sum
  elif monoid_op is mul:
    identity, reducer = 1, _reduce_prod
  elif monoid_op is bitwise_or and dtype == np.bool_:
    identity, reducer = _get_max_identity(dtype), _reduce_or
  elif monoid_op is bitwise_and and dtype == np.bool_:
    identity, reducer = _get_min_identity(dtype), _reduce_and
  elif monoid_op is max:
    identity, reducer = _get_max_identity(dtype), _reduce_max
  elif monoid_op is min:
    identity, reducer = _get_min_identity(dtype), _reduce_min
  else:
    return None

  # Return None (not a falsy NumPy bool) on mismatch, honoring the declared
  # Optional[Callable] return type; callers testing truthiness are unaffected.
  return reducer if np.equal(aval.val, identity) else None
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _get_max_identity(dtype: DType) -> Array:
  """Returns the identity element of ``max`` for ``dtype``.

  That is ``-inf`` for inexact dtypes, the minimum representable value for
  integer dtypes and ``False`` for booleans, each as a zero-dimensional NumPy
  array. Any other dtype yields ``None``.
  """
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(-np.inf, dtype)
  if dtypes.issubdtype(dtype, np.integer):
    lowest = dtypes.iinfo(dtype).min
    return np.array(lowest, dtype)
  if dtypes.issubdtype(dtype, np.bool_):
    return np.array(False, np.bool_)
  return None
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _get_min_identity(dtype: DType) -> Array:
  """Returns the identity element of ``min`` for ``dtype``.

  That is ``+inf`` for inexact dtypes, the maximum representable value for
  integer dtypes and ``True`` for booleans, each as a zero-dimensional NumPy
  array. Any other dtype yields ``None``.
  """
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(np.inf, dtype)
  if dtypes.issubdtype(dtype, np.integer):
    highest = dtypes.iinfo(dtype).max
    return np.array(highest, dtype)
  if dtypes.issubdtype(dtype, np.bool_):
    return np.array(True, np.bool_)
  return None
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_sum(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_sum_p`` on ``operand`` with ``axes`` passed as a tuple."""
  reduction_axes = tuple(axes)
  return reduce_sum_p.bind(operand, axes=reduction_axes)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_prod(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_prod_p`` on ``operand`` with ``axes`` passed as a tuple."""
  reduction_axes = tuple(axes)
  return reduce_prod_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_max(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_max_p`` on ``operand`` with ``axes`` passed as a tuple."""
  reduction_axes = tuple(axes)
  return reduce_max_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_min(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_min_p`` on ``operand`` with ``axes`` passed as a tuple."""
  reduction_axes = tuple(axes)
  return reduce_min_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_or(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_or_p`` on ``operand`` with ``axes`` passed as a tuple."""
  reduction_axes = tuple(axes)
  return reduce_or_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_and(operand: Array, axes: Sequence[int]) -> Array:
  """Binds ``reduce_and_p`` on ``operand`` with ``axes`` passed as a tuple."""
  reduction_axes = tuple(axes)
  return reduce_and_p.bind(operand, axes=reduction_axes)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def reduce_window(operand: Array, init_value: Array, computation: Callable,
                  window_dimensions: Shape, window_strides: Sequence[int],
                  padding: Union[str, Sequence[Tuple[int, int]]],
                  base_dilation: Optional[Sequence[int]] = None,
                  window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `ReduceWindowWithGeneralPadding
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator.

  Args:
    operand: the array to reduce over windows of.
    init_value: the initial value of the reduction.
    computation: binary callable combining two values.
    window_dimensions: the window shape.
    window_strides: the window strides.
    padding: either a string, which is converted to explicit pads with
      ``padtype_to_pads``, or a sequence of ``(low, high)`` pairs.
    base_dilation: optional base dilation; defaults to all ones, one entry
      per window dimension.
    window_dilation: optional window dilation; defaults to all ones, one
      entry per window dimension.

  Returns:
    The windowed reduction. A specialized reducer is used when
    ``_get_monoid_window_reducer`` recognizes ``computation`` and
    ``init_value``; otherwise ``computation`` is traced to a jaxpr and
    ``reduce_window_p`` is bound.
  """
  if isinstance(padding, str):
    explicit_pads = padtype_to_pads(operand.shape, window_dimensions,
                                    window_strides, padding)
  else:
    explicit_pads = padding
  padding = tuple(explicit_pads)

  # Missing dilations mean "no dilation" in every window dimension.
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)

  specialized = _get_monoid_window_reducer(computation, init_value)
  if specialized:
    return specialized(operand, window_dimensions, window_strides, padding,
                       base_dilation, window_dilation)

  jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value))
  return reduce_window_p.bind(
      operand, init_value, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=padding,
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _get_monoid_window_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
  """Picks a specialized window reducer for a recognized operation.

  Args:
    monoid_op: the binary reduction callable supplied by the user; compared
      by identity against ``add``, ``max`` and ``min``.
    x: the initial value of the reduction.

  Returns:
    ``_reduce_window_sum`` / ``_reduce_window_max`` / ``_reduce_window_min``
    when ``x`` is a concrete scalar equal to the identity element of
    ``monoid_op``; ``None`` in every other case.
  """
  aval = core.get_aval(x)
  # Only a concrete scalar can be compared against the identity at trace time.
  if type(aval) is not ConcreteArray or aval.shape != ():
    return None

  if monoid_op is add:
    identity, reducer = 0, _reduce_window_sum
  elif monoid_op is max:
    identity, reducer = _get_max_identity(aval.dtype), _reduce_window_max
  elif monoid_op is min:
    identity, reducer = _get_min_identity(aval.dtype), _reduce_window_min
  else:
    return None

  # Return None (not a falsy comparison result) on mismatch, honoring the
  # declared Optional[Callable] return type; truthiness is unchanged.
  return reducer if aval.val == identity else None
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_window_sum(operand: Array, window_dimensions: Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_sum_p`` with tuple-normalized window parameters.

  ``base_dilation`` and ``window_dilation`` default to all ones, one entry per
  window dimension.
  """
  base = (1,) * len(window_dimensions) if base_dilation is None else base_dilation
  window = (1,) * len(window_dimensions) if window_dilation is None else window_dilation
  return reduce_window_sum_p.bind(
      operand,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_window_prod(operand: Array, window_dimensions: Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]],
                        base_dilation: Optional[Sequence[int]] = None,
                        window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Windowed product, expressed through the general ``reduce_window_p``.

  The reduction computation is ``mul`` traced to a jaxpr, and the initial
  value is built with ``_const(operand, 1)``. ``base_dilation`` and
  ``window_dilation`` default to all ones, one entry per window dimension.
  """
  init_value = _const(operand, 1)
  jaxpr, consts = _reduction_jaxpr(mul, _abstractify(init_value))
  base = (1,) * len(window_dimensions) if base_dilation is None else base_dilation
  window = (1,) * len(window_dimensions) if window_dilation is None else window_dilation
  return reduce_window_p.bind(
      operand, init_value, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
|
2019-01-31 18:56:06 -05:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_window_max(operand: Array, window_dimensions: Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_max_p`` with tuple-normalized window parameters.

  ``base_dilation`` and ``window_dilation`` default to all ones, one entry per
  window dimension.
  """
  base = (1,) * len(window_dimensions) if base_dilation is None else base_dilation
  window = (1,) * len(window_dimensions) if window_dilation is None else window_dilation
  return reduce_window_max_p.bind(
      operand,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _reduce_window_min(operand: Array, window_dimensions: Shape,
                       window_strides: Sequence[int],
                       padding: Sequence[Tuple[int, int]],
                       base_dilation: Optional[Sequence[int]] = None,
                       window_dilation: Optional[Sequence[int]] = None) -> Array:
  """Binds ``reduce_window_min_p`` with tuple-normalized window parameters.

  ``base_dilation`` and ``window_dilation`` default to all ones, one entry per
  window dimension.
  """
  base = (1,) * len(window_dimensions) if base_dilation is None else base_dilation
  window = (1,) * len(window_dimensions) if window_dilation is None else window_dilation
  return reduce_window_min_p.bind(
      operand,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base),
      window_dilation=tuple(window))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _select_and_scatter(operand: Array, select: Callable,
                        window_dimensions: Shape, window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]], source: Array,
                        init_value: Array, scatter: Callable,
                        base_dilation: Sequence[int],
                        window_dilation: Sequence[int]) -> Array:
  """Binds ``select_and_scatter_p``.

  Both ``select`` and ``scatter`` are traced to jaxprs against the abstract
  value of ``init_value``; all window parameters are passed to the primitive
  as tuples.
  """
  init_aval = _abstractify(init_value)
  select_jaxpr, select_consts = _reduction_jaxpr(select, init_aval)
  scatter_jaxpr, scatter_consts = _reduction_jaxpr(scatter, _abstractify(init_value))
  return select_and_scatter_p.bind(
      operand, source, init_value,
      select_jaxpr=select_jaxpr, select_consts=select_consts,
      scatter_jaxpr=scatter_jaxpr, scatter_consts=scatter_consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _select_and_scatter_add(source: Array, operand: Array,
                            select_prim: core.Primitive,
                            window_dimensions: Shape,
                            window_strides: Sequence[int],
                            padding: Sequence[Tuple[int, int]]) -> Array:
  """Binds ``select_and_scatter_add_p`` with tuple-normalized window parameters."""
  window_params = dict(
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding))
  return select_and_scatter_add_p.bind(
      source, operand, select_prim=select_prim, **window_params)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _select_and_gather_add(tangents: Array, operand: Array,
                           select_prim: core.Primitive,
                           window_dimensions: Shape,
                           window_strides: Sequence[int],
                           padding: Sequence[Tuple[int, int]],
                           base_dilation: Sequence[int],
                           window_dilation: Sequence[int]) -> Array:
  """Binds ``select_and_gather_add_p`` with tuple-normalized window parameters."""
  window_params = dict(
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides),
      padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return select_and_gather_add_p.bind(
      tangents, operand, select_prim=select_prim, **window_params)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def cumsum(operand: Array, axis: int) -> Array:
  """Computes a cumulative sum along `axis`.

  ``axis`` is coerced with ``int`` before ``cumsum_p`` is bound.
  """
  axis_index = int(axis)
  return cumsum_p.bind(operand, axis=axis_index)
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def cumprod(operand: Array, axis: int) -> Array:
  """Computes a cumulative product along `axis`.

  ``axis`` is coerced with ``int`` before ``cumprod_p`` is bound.
  """
  axis_index = int(axis)
  return cumprod_p.bind(operand, axis=axis_index)
|
|
|
|
|
2020-06-28 19:33:20 +01:00
|
|
|
def cummax(operand: Array, axis: int) -> Array:
  """Computes a cumulative maximum along `axis`.

  ``axis`` is coerced with ``int`` before ``cummax_p`` is bound.
  """
  axis_index = int(axis)
  return cummax_p.bind(operand, axis=axis_index)
|
2020-06-28 18:21:09 +01:00
|
|
|
|
2020-06-28 19:33:20 +01:00
|
|
|
def cummin(operand: Array, axis: int) -> Array:
  """Computes a cumulative minimum along `axis`.

  ``axis`` is coerced with ``int`` before ``cummin_p`` is bound.
  """
  axis_index = int(axis)
  return cummin_p.bind(operand, axis=axis_index)
|
2020-06-28 18:21:09 +01:00
|
|
|
|
2020-06-26 18:40:00 +01:00
|
|
|
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
         is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
  """Wraps XLA's `Sort
  <https://www.tensorflow.org/xla/operation_semantics#sort>`_
  operator.

  Args:
    operand: a single array, or a sequence of arrays.
    dimension: integer dimension along which to sort. Default: -1.
    is_stable: whether to use a stable sort. Default: True.
    num_keys: number of operands to treat as sort keys. Default: 1.
      For ``num_keys > 1`` the order is determined lexicographically from the
      first ``num_keys`` arrays, the first key being primary. The remaining
      operands are returned with the same permutation.

  Returns:
    The sorted array for a single operand, or a tuple of sorted arrays when a
    sequence was passed.

  Raises:
    TypeError: if ``operand`` is an empty sequence.
    ValueError: if ``num_keys`` is outside ``[1, len(operand)]`` for a
      sequence, or is not 1 for a single operand.
  """
  if not isinstance(operand, Sequence):
    # Single-array case: exactly one key, unwrap the one-element result.
    if num_keys != 1:
      raise ValueError(f"num_keys={num_keys} must equal 1 for a single operand.")
    dimension = _canonicalize_axis(dimension, len(operand.shape))
    outputs = sort_p.bind(operand, dimension=dimension, is_stable=is_stable,
                          num_keys=1)
    return outputs[0]

  if len(operand) == 0:
    raise TypeError("Sort requires at least one operand")
  if not (1 <= num_keys <= len(operand)):
    raise ValueError(f"num_keys={num_keys} must be between 1 and len(operand)={len(operand)}")
  # The rank of the first operand is used to resolve a negative dimension.
  dimension = _canonicalize_axis(dimension, len(operand[0].shape))
  outputs = sort_p.bind(*operand, dimension=dimension, is_stable=is_stable,
                        num_keys=num_keys)
  return tuple(outputs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-26 18:40:00 +01:00
|
|
|
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
                 is_stable: bool = True) -> Tuple[Array, Array]:
  """Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``.

  Returns:
    A ``(sorted_keys, permuted_values)`` pair produced by binding ``sort_p``
    with ``num_keys=1``.
  """
  resolved_dimension = _canonicalize_axis(dimension, len(keys.shape))
  sorted_keys, permuted_values = sort_p.bind(
      keys, values, dimension=resolved_dimension, is_stable=is_stable,
      num_keys=1)
  return sorted_keys, permuted_values
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-19 11:49:15 -07:00
|
|
|
def top_k(operand: Array, k: int) -> Tuple[Array, Array]:
  """Returns top ``k`` values and their indices along the last axis of ``operand``.

  Raises:
    ValueError: if ``k`` (after coercion with ``int``) is negative.
  """
  k = int(k)
  if not k >= 0:
    raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
  return top_k_p.bind(operand, k=k)
|
2019-03-13 10:41:42 -04:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def tie_in(x: Array, y: Array) -> Array:
  """Returns the value of ``y`` but with a fake data dependence on ``x``.

  When staging to XLA (e.g. running under jit or pmap), values that don't
  depend on computation inputs are computed op-by-op and folded into the XLA
  computation as constants.

  ``tie_in`` provides a way to explicitly stage values into the computation.
  When staging to XLA and ``x`` is already staged, the result of ``tie_in`` is
  ``y``, but staged to XLA. Downstream use of the result will also be staged
  to XLA.

  For example, ``lax.sin(const)`` would be constant-folded if ``const`` is a
  constant array, but ``lax.sin(lax.tie_in(x, const))`` will be staged to XLA
  as long as ``x`` is staged to XLA.

  When ``config.omnistaging_enabled`` is true, ``y`` is returned unchanged and
  ``x`` is ignored.
  """
  return y if config.omnistaging_enabled else tie_in_p.bind(x, y)
|
|
|
|
|
|
|
|
# def tie_in(x: Array, y: Array) -> Array:
|
|
|
|
# """Deprecated. Ignores ``x`` and returns ``y``."""
|
|
|
|
# return y
|
2018-12-13 07:24:14 -08:00
|
|
|
|
2019-02-23 20:34:14 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array:
  """Returns an array of `shape` filled with `fill_value`.

  Args:
    shape: sequence of integers, describing the shape of the output array.
    fill_value: the value to fill the new array with. Must be a scalar.
    dtype: the type of the output array, or `None`. If not `None`,
      `fill_value` will be cast to `dtype`.

  Raises:
    TypeError: if ``fill_value`` has a non-empty shape.
  """
  shape = canonicalize_shape(shape)
  fill_shape = np.shape(fill_value)
  if fill_shape:
    msg = "full must be called with scalar fill_value, got fill_value.shape {}."
    raise TypeError(msg.format(fill_shape))
  dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
  if not config.omnistaging_enabled:
    fill_value = xla.device_put_p.bind(convert_element_type(fill_value, dtype))
  else:
    fill_value = convert_element_type(fill_value, dtype)
    # Device arrays and tracers are used as-is; anything else is put on
    # device first.
    if not isinstance(fill_value, (xla.DeviceArray, core.Tracer)):
      fill_value = _device_put_raw(fill_value)
  return broadcast(fill_value, shape)
|
2018-12-13 07:24:14 -08:00
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def _device_put_raw(x):
  """Returns ``x`` as a device value.

  An ``xla.DeviceValue`` is returned unchanged. Anything else is passed through
  ``xla.device_put`` and wrapped by the array result handler built for its
  shaped abstract value.
  """
  if isinstance(x, xla.DeviceValue):
    return x
  shaped_aval = raise_to_shaped(core.get_aval(x))
  wrap_result = xla.array_result_handler(None, shaped_aval)
  return wrap_result(xla.device_put(x))
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def iota(dtype: DType, size: int) -> Array:
  """Wraps XLA's `Iota
  <https://www.tensorflow.org/xla/operation_semantics#iota>`_
  operator.

  Args:
    dtype: element type of the result; canonicalized before use.
    size: length of the result. Coerced with ``int`` unless it is exactly a
      ``masking.Poly``.

  Returns:
    An ``xla.DeviceArray`` of shape ``(size,)`` constructed from a
    ``lazy.iota`` expression with no device buffer (``None``).
  """
  if type(size) is not masking.Poly:
    size = int(size)
  shape = canonicalize_shape((size,))
  dtype = dtypes.canonicalize_dtype(dtype)
  lazy_expr = lazy.iota(dtype, shape[0])
  aval = ShapedArray(shape, dtype)
  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
|
2018-12-15 17:49:00 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
  """Convenience wrapper around ``iota``.

  Builds ``iota(dtype, shape[dimension])`` and broadcasts it to ``shape`` with
  ``broadcast_in_dim``, mapping the iota's only axis to ``dimension``.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = canonicalize_shape(shape)
  dimension = int(dimension)
  ramp = iota(dtype, shape[dimension])
  return broadcast_in_dim(ramp, shape, [dimension])
|
2018-12-13 11:12:11 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
  """Like numpy.eye, create a 2D array with ones on a diagonal.

  The point of this helper is laziness: the returned array carries a lazy
  ``eye`` expression instead of a device buffer, so materialization is
  deferred and may be fused into consumers to avoid materialization at all.

  Args:
    dtype: element type of the result; canonicalized before use.
    shape: a two-element shape; each entry is coerced with ``int``.
    offset: diagonal offset forwarded to ``lazy.eye``; coerced with ``int``.

  Returns:
    An ``xla.DeviceArray`` built from the lazy expression, with no buffer
    attached (``None`` is passed in the buffer position).
  """
  n_rows, n_cols = [int(d) for d in shape]
  diag_offset = int(offset)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  dims = (n_rows, n_cols)
  expr = lazy.eye(canonical_dtype, dims, diag_offset)
  abstract_value = ShapedArray(dims, canonical_dtype)
  return xla.DeviceArray(abstract_value, None, expr, xla.DeviceConstant())
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
  """Create a lazy Kronecker delta array.

  This function exists for creating lazy Kronecker delta arrays, particularly
  for use in jax.numpy.einsum to express traces. It differs from ``eye`` in
  that it can create arrays of any rank, but doesn't allow offsets.

  Args:
    dtype: element type of the result; canonicalized before use.
    shape: shape of the result; each entry is coerced with ``int``.
    axes: the axes of ``shape`` that the delta is defined over; the delta's
      own shape is ``shape`` restricted to these axes, and it is then lazily
      broadcast to ``shape`` along them.

  Returns:
    An ``xla.DeviceArray`` built from the lazy expression, with no buffer
    attached (``None`` is passed in the buffer position).
  """
  full_shape = tuple(int(d) for d in shape)
  delta_axes = tuple(int(a) for a in axes)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  # Sizes of the selected axes form the shape of the un-broadcast delta.
  delta_shape = tuple(np.take(full_shape, delta_axes))
  core_expr = lazy.delta(canonical_dtype, delta_shape)
  expr = lazy.broadcast(core_expr, full_shape, delta_axes)
  abstract_value = ShapedArray(full_shape, canonical_dtype)
  return xla.DeviceArray(abstract_value, None, expr, xla.DeviceConstant())
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
  """Like numpy.tri, create a 2D array with ones below a diagonal.

  This helper produces the triangular matrix lazily (the result carries a lazy
  ``tri`` expression rather than a device buffer); it exists particularly for
  use in jax.numpy.tri.

  Args:
    dtype: element type of the result; canonicalized before use.
    shape: a two-element shape; each entry is coerced with ``int``.
    offset: diagonal offset forwarded to ``lazy.tri``; coerced with ``int``.

  Returns:
    An ``xla.DeviceArray`` built from the lazy expression, with no buffer
    attached (``None`` is passed in the buffer position).
  """
  n_rows, n_cols = [int(d) for d in shape]
  diag_offset = int(offset)
  canonical_dtype = dtypes.canonicalize_dtype(dtype)
  dims = (n_rows, n_cols)
  expr = lazy.tri(canonical_dtype, dims, diag_offset)
  abstract_value = ShapedArray(dims, canonical_dtype)
  return xla.DeviceArray(abstract_value, None, expr, xla.DeviceConstant())
|
2018-12-13 07:24:14 -08:00
|
|
|
|
2019-01-30 10:39:35 -08:00
|
|
|
def stop_gradient(x):
  """Stops gradient computation.

  Operationally ``stop_gradient`` is the identity function: it hands back its
  argument ``x`` unchanged. What it affects is differentiation -- it prevents
  the flow of gradients during forward or reverse-mode automatic
  differentiation. If there are multiple nested gradient computations,
  ``stop_gradient`` stops gradients for all of them.

  For example:

  >>> jax.grad(lambda x: x**2)(3.)
  array(6., dtype=float32)
  >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
  array(0., dtype=float32)
  >>> jax.grad(jax.grad(lambda x: x**2))(3.)
  array(2., dtype=float32)
  >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
  array(0., dtype=float32)
  """
  # The primitive is bound through ``tree_map`` rather than applied to ``x``
  # directly (presumably so containers of arrays are handled element-wise).
  bind_stop_gradient = ad_util.stop_gradient_p.bind
  return tree_map(bind_stop_gradient, x)
|
2019-01-30 10:39:35 -08:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
### convenience wrappers around traceables
|
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
         padding: str, precision: Optional[PrecisionType] = None) -> Array:
  """Convenience wrapper around `conv_general_dilated`.

  Forwards its arguments unchanged, leaving every other option of
  `conv_general_dilated` at its default.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'` or the string `'VALID'`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    An array containing the convolution result.
  """
  return conv_general_dilated(
      lhs, rhs, window_strides, padding, precision=precision)
|
2018-12-10 17:18:56 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def conv_with_general_padding(lhs: Array, rhs: Array,
                              window_strides: Sequence[int],
                              padding: Union[str, Sequence[Tuple[int, int]]],
                              lhs_dilation: Optional[Sequence[int]],
                              rhs_dilation: Optional[Sequence[int]],
                              precision: Optional[PrecisionType] = None) -> Array:
  """Convenience wrapper around `conv_general_dilated`.

  Unlike `conv`, this wrapper also forwards explicit padding and the two
  dilation arguments.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    An array containing the convolution result.
  """
  return conv_general_dilated(
      lhs, rhs, window_strides, padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      precision=precision)
|
2018-12-10 17:18:56 -08:00
|
|
|
|
|
|
|
|
2019-04-09 15:06:46 -07:00
|
|
|
def _conv_transpose_padding(k, s, padding):
  """Calculate before and after padding for a dim of transposed convolution.

  Args:
    k: int: kernel dimension.
    s: int: dimension stride value.
    padding: 'SAME' or 'VALID' (upper-case; the comparison is exact) padding
      mode for original forward conv.

  Returns:
    2-tuple: ints: before and after padding for transposed convolution.

  Raises:
    ValueError: if `padding` is neither 'SAME' nor 'VALID'.
  """
  if padding == 'SAME':
    pad_len = k + s - 2
    if s > k - 1:
      pad_a = k - 1
    else:
      # ceil(pad_len / 2) in exact integer arithmetic (no float round-trip).
      pad_a = int(-(-pad_len // 2))
  elif padding == 'VALID':
    pad_len = k + s - 2 + _max(k - s, 0)
    pad_a = k - 1
  else:
    raise ValueError('Padding mode must be `SAME` or `VALID`.')
  # Whatever is not placed before the dimension goes after it.
  pad_b = pad_len - pad_a
  return pad_a, pad_b
|
|
|
|
|
|
|
|
|
|
|
|
def _flip_axes(x, axes):
  """Flip ndarray 'x' along each axis specified in axes tuple.

  The flips are applied one axis at a time, in the order given; with no axes
  `x` itself is returned.
  """
  # Fold ``np.flip`` over the axes, threading the array through as accumulator.
  return functools.reduce(np.flip, axes, x)
|
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
                   padding: Union[str, Sequence[Tuple[int, int]]],
                   rhs_dilation: Optional[Sequence[int]] = None,
                   dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
                   transpose_kernel: bool = False,
                   precision: Optional[PrecisionType] = None) -> Array:
  """Convenience wrapper for calculating the N-d convolution "transpose".

  This function directly calculates a fractionally strided conv rather than
  indirectly calculating the gradient (transpose) of a forward convolution.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    strides: sequence of `n` integers, sets fractional stride.
    padding: 'SAME', 'VALID' will set as transpose of corresponding forward
      conv, or a sequence of `n` integer 2-tuples describing before-and-after
      padding for each `n` spatial dimension. Any other value (including an
      explicit sequence given as a list) is forwarded unchanged to
      `conv_general_dilated`.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: tuple of dimension descriptors as in
      lax.conv_general_dilated. Defaults to tensorflow convention.
    transpose_kernel: if True flips spatial axes and swaps the input/output
      channel axes of the kernel. This makes the output of this function identical
      to the gradient-derived functions like keras.layers.Conv2DTranspose
      applied to the same kernel. For typical use in neural nets this is completely
      pointless and just makes input/output channel specification confusing.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    Transposed N-d convolution, with output padding following the conventions of
    keras.layers.Conv2DTranspose.

  Raises:
    ValueError: if `dimension_numbers` is None and the operands' rank has no
      default layout (defaults exist for ranks 2 through 5).
  """
  assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
  ndims = len(lhs.shape)
  # Unit window strides: the transposed conv is expressed through lhs dilation.
  one = (1,) * (ndims - 2)
  # Set dimensional layout defaults if not specified.
  if dimension_numbers is None:
    if ndims == 2:
      dimension_numbers = ('NC', 'IO', 'NC')
    elif ndims == 3:
      dimension_numbers = ('NHC', 'HIO', 'NHC')
    elif ndims == 4:
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    elif ndims == 5:
      dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
    else:
      raise ValueError('No 4+ dimensional dimension_number defaults.')
  dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  k_shape = np.take(rhs.shape, dn.rhs_spec)
  k_sdims = k_shape[2:]
  # Calculate correct output shape given padding and strides.
  pads: Union[str, Sequence[Tuple[int, int]]]
  # Check the type first: testing membership in a set hashes `padding`, which
  # raises TypeError for an explicit padding passed as a list or array.
  if isinstance(padding, str) and padding in {'SAME', 'VALID'}:
    if rhs_dilation is None:
      rhs_dilation = (1,) * (rhs.ndim - 2)
    # A dilated kernel of size k with rate r spans (k - 1) * r + 1 positions.
    effective_k_size = map(lambda k, r: (k-1) * r + 1, k_sdims, rhs_dilation)
    pads = [_conv_transpose_padding(k, s, padding)
            for k, s in zip(effective_k_size, strides)]
  else:
    pads = padding
  if transpose_kernel:
    # flip spatial dims and swap input / output channel axes
    rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
    rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
  # `strides` is passed in the lhs-dilation position; window strides are `one`.
  return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn,
                              precision=precision)
|
2019-04-09 15:06:46 -07:00
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
              shape: Optional[Shape] = None) -> Array:
  """Create a full array like np.full based on the example array `x`.

  Args:
    x: example array-like, used for shape and dtype information.
    fill_value: a scalar value to fill the entries of the output array.
    dtype: optional, a dtype parameter for the output ndarray. When ``None``
      the dtype of `x` is used.
    shape: optional, a shape parameter for the output ndarray. When ``None``
      the shape of `x` is used.

  Returns:
    An ndarray with shape `shape` if it was given, otherwise the same shape as
    `x`, with its entries set equal to `fill_value`, similar to the output of
    np.full.
  """
  fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape)
  if not config.omnistaging_enabled:
    # Without omnistaging the fill value is explicitly tied to `x`.
    fill_value = tie_in(x, fill_value)
  # Test against None explicitly: a truthiness test would silently discard an
  # explicitly supplied dtype whose object happens to evaluate as falsy.
  fill_dtype = dtype if dtype is not None else _dtype(x)
  return full(fill_shape, fill_value, fill_dtype)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def collapse(operand: Array, start_dimension: int, stop_dimension: int) -> Array:
  """Merge a contiguous run of dimensions of ``operand`` into one.

  Dimensions in the half-open range ``[start_dimension, stop_dimension)`` are
  replaced by a single dimension whose size is the product of their sizes; the
  dimensions before and after the range are kept as they are.

  Args:
    operand: the array to reshape.
    start_dimension: first dimension of the run to merge.
    stop_dimension: one past the last dimension of the run to merge.

  Returns:
    ``operand`` reshaped to the collapsed shape.
  """
  dims = operand.shape
  merged_size = prod(dims[start_dimension:stop_dimension])
  collapsed = dims[:start_dimension] + (merged_size,) + dims[stop_dimension:]
  return reshape(operand, collapsed)
|
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def slice_in_dim(operand: Array, start_index: Optional[int],
                 limit_index: Optional[int],
                 stride: int = 1, axis: int = 0)-> Array:
  """Convenience wrapper around slice applying to only one dimension.

  Every dimension other than ``axis`` is taken in full with stride 1.

  Args:
    operand: the array to slice.
    start_index: first index along ``axis``; ``None`` means 0. A negative
      value has the size of ``axis`` added to it.
    limit_index: end index along ``axis``; ``None`` means the size of
      ``axis``. A negative value has the size of ``axis`` added to it.
    stride: step along ``axis``.
    axis: the dimension to slice.

  Returns:
    The result of ``slice`` with the per-dimension arguments built here.
  """
  rank = operand.ndim
  starts = [0] * rank
  limits = list(operand.shape)
  steps = [1] * rank

  axis_len = operand.shape[axis]

  # Substitute defaults for `None`, canonicalizing anything given explicitly.
  if start_index is None:
    lo = 0
  else:
    lo = _canonicalize_dimension(start_index)
  if limit_index is None:
    hi = axis_len
  else:
    hi = _canonicalize_dimension(limit_index)

  # Negative indices count from the end of the axis.
  if lo < 0:
    lo = lo + axis_len
  if hi < 0:
    hi = hi + axis_len

  axis = int(axis)
  starts[axis], limits[axis] = lo, hi
  steps[axis] = int(stride)

  return slice(operand, starts, limits, steps)
|
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def index_in_dim(operand: Array, index: int, axis: int = 0,
                 keepdims: bool = True) -> Array:
  """Convenience wrapper around slice to perform int indexing.

  Args:
    operand: the array to index.
    index: position along ``axis``; a negative value counts from the end.
    axis: the dimension to index into.
    keepdims: if True the indexed dimension is kept (with size 1); otherwise
      it is squeezed out of the result.

  Returns:
    The length-1 slice of ``operand`` at ``index`` along ``axis``, squeezed
    along ``axis`` when ``keepdims`` is False.

  Raises:
    IndexError: if ``index`` (after wrapping a negative value) falls outside
      ``[0, operand.shape[axis])``.
  """
  index = int(index)
  axis = int(axis)
  axis_size = operand.shape[axis]
  position = index if index >= 0 else index + axis_size
  if not 0 <= position < axis_size:
    raise IndexError(
        'index {} is out of bounds for axis {} with size {}'.format(
            index, axis, axis_size))
  picked = slice_in_dim(operand, position, position + 1, 1, axis)
  return picked if keepdims else squeeze(picked, (axis,))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def dynamic_slice_in_dim(operand: Array, start_index: Array,
                         slice_size: int, axis: int = 0) -> Array:
  """Convenience wrapper around dynamic_slice applying to one dimension.

  Every dimension other than ``axis`` is taken in full, starting from a zero
  derived from ``start_index`` via ``_zero``.

  Args:
    operand: the array to slice.
    start_index: start position along ``axis``.
    slice_size: number of elements to take along ``axis``; coerced with
      ``int``.
    axis: the dimension to slice.

  Returns:
    The result of ``dynamic_slice`` with the per-dimension arguments built
    here.
  """
  # One shared zero is used for all the untouched dimensions.
  zero_start = _zero(start_index)
  starts = [zero_start] * operand.ndim
  sizes = list(operand.shape)

  axis = int(axis)
  starts[axis] = start_index
  sizes[axis] = int(slice_size)
  return dynamic_slice(operand, starts, sizes)
|
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
                         keepdims: bool = True) -> Array:
  """Convenience wrapper around dynamic_slice to perform int indexing.

  Args:
    operand: the array to index.
    index: position along ``axis``.
    axis: the dimension to index into.
    keepdims: if True the indexed dimension is kept (with size 1); otherwise
      it is squeezed out of the result.

  Returns:
    The size-1 dynamic slice of ``operand`` at ``index`` along ``axis``,
    squeezed along ``axis`` when ``keepdims`` is False.
  """
  picked = dynamic_slice_in_dim(operand, index, 1, axis)
  return picked if keepdims else squeeze(picked, (axis,))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def dynamic_update_slice_in_dim(operand: Array, update: Array,
                                start_index: Array, axis: int) -> Array:
  """Convenience wrapper around dynamic_update_slice for one dimension.

  The update is placed at ``start_index`` along ``axis``; the start position
  for every other dimension is a zero derived from ``start_index`` via
  ``_zero``.

  Args:
    operand: the array to update.
    update: the values to write into ``operand``.
    start_index: start position of the update along ``axis``.
    axis: the dimension along which ``start_index`` applies; coerced with
      ``int``.

  Returns:
    The result of ``dynamic_update_slice`` with the start indices built here.
  """
  axis = int(axis)
  # One shared zero is used for all the untouched dimensions.
  zero_start = _zero(start_index)
  starts = [zero_start] * _ndim(operand)
  starts[axis] = start_index
  return dynamic_update_slice(operand, update, starts)
|
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
                                axis: int) -> Array:
  """Write ``update`` at position ``index`` along ``axis`` of ``operand``.

  ``update`` may either have the same rank as ``operand`` or exactly one
  dimension fewer; in the latter case a dimension is inserted at ``axis``
  via ``expand_dims`` before delegating to ``dynamic_update_slice_in_dim``.
  """
  axis = int(axis)
  update_rank = _ndim(update)
  operand_rank = _ndim(operand)
  if update_rank != operand_rank:
    # Only a rank deficit of exactly one is supported.
    assert update_rank + 1 == operand_rank
    update = expand_dims(update, (axis,))
  return dynamic_update_slice_in_dim(operand, update, index, axis)
|
|
|
|
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def batch_matmul(lhs: Array, rhs: Array,
                 precision: Optional[PrecisionType] = None) -> Array:
  """Batch matrix multiplication.

  Both arguments must have the same rank, and that rank must be at least 2.
  All leading dimensions are treated as batch dimensions; the last dimension
  of ``lhs`` is contracted against the second-to-last dimension of ``rhs``
  through ``dot_general``.

  Raises:
    ValueError: if either argument has fewer than two dimensions, or if the
      two ranks differ.
  """
  lhs_rank, rhs_rank = lhs.ndim, rhs.ndim
  if _min(lhs_rank, rhs_rank) < 2:
    raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}'
                     .format(lhs_rank, rhs_rank))
  if lhs_rank != rhs_rank:
    raise ValueError('Arguments to batch_matmul must have same ndim, got {}, {}'
                     .format(lhs_rank, rhs_rank))
  batch_dims = tuple(range(lhs_rank - 2))
  contracting_dims = ((lhs_rank - 1,), (rhs_rank - 2,))
  dimension_numbers = (contracting_dims, (batch_dims, batch_dims))
  return dot_general(lhs, rhs, dimension_numbers, precision=precision)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-09-04 15:06:46 -07:00
|
|
|
# These functions also exist in the XLA client library, but we treat them
|
2018-11-17 18:03:33 -08:00
|
|
|
# as non-primitive to maintain a smaller set of autodiff primitives.
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def square(x: Array) -> Array:
  r"""Elementwise square: :math:`x^2`.

  Implemented as ``integer_pow`` with exponent 2.
  """
  exponent = 2
  return integer_pow(x, exponent)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def reciprocal(x: Array) -> Array:
  r"""Elementwise reciprocal: :math:`1 \over x`.

  Implemented as ``integer_pow`` with exponent -1.
  """
  exponent = -1
  return integer_pow(x, exponent)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-11-20 22:43:46 -05:00
|
|
|
def _upcast_fp16_for_computation(f):
  """Decorator: run a unary function in float32 for 16-bit float inputs.

  When the argument's dtype is ``np.float16`` or ``dtypes.bfloat16`` the
  argument is converted to ``np.float32``, ``f`` is applied, and the result
  is converted back to the original dtype. Any other dtype is passed to
  ``f`` unchanged.
  """
  @functools.wraps(f)
  def f_wrapped(x):
    in_dtype = _dtype(x)
    is_half_precision = in_dtype == np.float16 or in_dtype == dtypes.bfloat16
    if not is_half_precision:
      return f(x)
    wide_result = f(convert_element_type(x, np.float32))
    return convert_element_type(wide_result, in_dtype)

  return f_wrapped
|
|
|
|
|
2019-10-21 10:56:54 -04:00
|
|
|
@api.jit
@_upcast_fp16_for_computation
def tan(x: Array) -> Array:
  r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.

  Computed as the quotient ``sin(x) / cos(x)``.
  """
  numerator = sin(x)
  denominator = cos(x)
  return div(numerator, denominator)
|
|
|
|
|
2019-10-21 10:56:54 -04:00
|
|
|
@api.jit
def asin(x: Array) -> Array:
  r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.

  Computed as ``2 * atan2(x, 1 + sqrt(1 - x**2))``.
  """
  two = _const(x, 2)
  denominator = add(_const(x, 1), sqrt(sub(_const(x, 1), square(x))))
  half_angle = atan2(x, denominator)
  return mul(two, half_angle)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-10-21 10:56:54 -04:00
|
|
|
@api.jit
def acos(x: Array) -> Array:
  r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.

  For ``x != -1`` the result is ``2 * atan2(sqrt(1 - x**2), 1 + x)``.
  Elements equal to -1 are selected from a constant array filled with
  ``np.pi`` instead.
  """
  not_minus_one = ne(x, _const(x, -1.0))
  two = _const(x, 2)
  numerator = sqrt(sub(_const(x, 1), square(x)))
  denominator = add(_const(x, 1), x)
  general_case = mul(two, atan2(numerator, denominator))
  at_minus_one = full_like(x, np.pi)
  return select(not_minus_one, general_case, at_minus_one)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def atan(x: Array) -> Array:
  r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.

  Computed as ``atan2(x, 1)``.
  """
  one = _const(x, 1)
  return atan2(x, one)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def sinh(x: Array) -> Array:
  r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`.

  Binds ``x`` to the ``sinh_p`` primitive.
  """
  result = sinh_p.bind(x)
  return result
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def cosh(x: Array) -> Array:
  r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`.

  Binds ``x`` to the ``cosh_p`` primitive.
  """
  result = cosh_p.bind(x)
  return result
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def asinh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic sine: :math:`\mathrm{asinh}(x)`.

  Binds ``x`` to the ``asinh_p`` primitive.
  """
  result = asinh_p.bind(x)
  return result
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def acosh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic cosine: :math:`\mathrm{acosh}(x)`.

  Binds ``x`` to the ``acosh_p`` primitive.
  """
  result = acosh_p.bind(x)
  return result
|
|
|
|
|
2020-04-08 14:13:15 -04:00
|
|
|
def atanh(x: Array) -> Array:
  r"""Elementwise inverse hyperbolic tangent: :math:`\mathrm{atanh}(x)`.

  Binds ``x`` to the ``atanh_p`` primitive.
  """
  result = atanh_p.bind(x)
  return result
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-12-12 15:30:41 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# Add some methods to ShapedArray that rely on lax primitives
|
|
|
|
|
|
|
|
ShapedArray.broadcast = core.aval_method(broadcast)
|
|
|
|
ShapedArray.transpose = core.aval_method(transpose) # clobbered by lax_numpy
|
|
|
|
ShapedArray.reshape = core.aval_method(reshape) # clobbered by lax_numpy
|
|
|
|
|
|
|
|
def _iter(tracer):
  """Return an iterator over the slices of ``tracer`` along its first axis.

  Each element is ``index_in_dim(tracer, i, keepdims=False)``. All slices are
  built eagerly into a list before the iterator is returned.

  Raises:
    TypeError: if ``tracer`` is zero-dimensional.
  """
  if tracer.ndim == 0:
    raise TypeError("iteration over a 0-d array")  # same as numpy error
  length = int(tracer.shape[0])
  slices = [index_in_dim(tracer, i, keepdims=False) for i in range(length)]
  return iter(slices)
|
2018-11-17 18:03:33 -08:00
|
|
|
ShapedArray._iter = staticmethod(_iter)
|
|
|
|
|
|
|
|
# Add some ad handlers that use (or could use) lax primitives
|
|
|
|
|
|
|
|
def zeros_like_array(x):
  """Return ``full_like(x, 0)``, i.e. a zero-filled counterpart of ``x``."""
  fill_value = 0
  return full_like(x, fill_value)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
Change scalar promotion rules to prefer array types over scalar types. (#1709)
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
|
|
|
for t in itertools.chain(dtypes.python_scalar_dtypes.keys(), array_types,
|
2020-03-28 11:56:12 -07:00
|
|
|
[xla.DeviceArray, pxla.ShardedDeviceArray]):
|
2018-11-17 18:03:33 -08:00
|
|
|
ad_util.jaxval_adders[t] = add
|
2019-01-06 11:59:33 -08:00
|
|
|
ad_util.jaxval_zeros_likers[xla.DeviceArray] = zeros_like_array
|
2020-03-28 11:56:12 -07:00
|
|
|
ad_util.jaxval_zeros_likers[pxla.ShardedDeviceArray] = zeros_like_array
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
### primitives
|
|
|
|
|
|
|
|
|
2019-11-15 10:02:51 -05:00
|
|
|
_input_dtype = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype)
|
|
|
|
_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
|
2020-07-14 13:05:31 -07:00
|
|
|
_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-08-21 00:22:53 -07:00
|
|
|
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
  """Create a ``Primitive`` called ``name`` wired up with the standard rules.

  The new primitive gets:
    * ``xla.apply_primitive`` as its implementation,
    * ``standard_abstract_eval`` (closed over ``shape_rule`` and
      ``dtype_rule``) as its abstract evaluation rule,
    * an entry in ``xla.translations``: ``translation_rule`` if one is given
      (truthy), otherwise ``standard_translate`` bound to ``name``.

  Returns the primitive.
  """
  prim = Primitive(name)
  impl = partial(xla.apply_primitive, prim)
  abstract_eval = partial(standard_abstract_eval, prim, shape_rule, dtype_rule)
  prim.def_impl(impl)
  prim.def_abstract_eval(abstract_eval)
  if translation_rule:
    xla.translations[prim] = translation_rule
  else:
    xla.translations[prim] = partial(standard_translate, name)
  return prim
|
|
|
|
|
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
  """Abstract evaluation shared by primitives built via ``standard_primitive``.

  The argument whose type has the highest ``array_abstraction_level`` decides
  what is produced:
    * ``ConcreteArray``: the primitive's ``impl`` is run on the arguments'
      ``val`` attributes and the result is wrapped in a ``ConcreteArray``;
    * ``ShapedArray``: built from ``shape_rule`` and ``dtype_rule``;
    * ``UnshapedArray``: built from ``dtype_rule`` only.

  Raises:
    TypeError: if the least specialized type is none of the above.
  """
  assert all(isinstance(arg, UnshapedArray) for arg in args), args
  abstraction_level = operator.attrgetter('array_abstraction_level')
  least_specialized = _max((type(arg) for arg in args), key=abstraction_level)
  if least_specialized is ConcreteArray:
    concrete_values = [arg.val for arg in args]
    return ConcreteArray(prim.impl(*concrete_values, **kwargs))
  if least_specialized is ShapedArray:
    return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
  if least_specialized is UnshapedArray:
    return UnshapedArray(dtype_rule(*args, **kwargs))
  raise TypeError(args, least_specialized)
|
|
|
|
|
|
|
|
|
|
|
|
def standard_translate(name, c, *args, **kwargs):
  """Translate a primitive by looking up the same-named op on ``xops``.

  The snake_case ``name`` is converted to CamelCase (each ``_``-separated
  term capitalized and joined) and the attribute of that name on ``xops`` is
  called with ``*args`` and ``**kwargs``. The builder ``c`` is accepted for
  signature compatibility but is not passed on.
  """
  camel_case_name = ''.join(map(str.capitalize, name.split('_')))
  xla_op = getattr(xops, camel_case_name)
  return xla_op(*args, **kwargs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2018-12-12 09:00:39 -08:00
|
|
|
def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
  """Dtype rule for unary primitives.

  Returns ``result_dtype(aval.dtype)`` when ``aval.dtype`` is a subtype of at
  least one entry in ``accepted_dtypes``.

  Raises:
    TypeError: if ``aval.dtype`` is not a subtype of any accepted dtype; the
      message names the primitive, the offending dtype and the accepted ones.
  """
  if any(dtypes.issubdtype(aval.dtype, t) for t in accepted_dtypes):
    return result_dtype(aval.dtype)
  msg = '{} does not accept dtype {}. Accepted dtypes are subtypes of {}.'
  typename = str(np.dtype(aval.dtype).name)
  accepted_typenames = ', '.join(t.__name__ for t in accepted_dtypes)
  raise TypeError(msg.format(name, typename, accepted_typenames))
|
|
|
|
|
|
|
|
|
2020-01-09 11:16:52 -05:00
|
|
|
def unop(result_dtype, accepted_dtypes, name, translation_rule=None):
  """Define a unary elementwise primitive called ``name``.

  The primitive's shape rule returns the operand's ``shape`` attribute and
  its dtype rule is ``unop_dtype_rule`` bound to the given arguments. The
  primitive is also registered as vectorized with both ``batching`` and
  ``masking``.
  """
  shape_rule = _attrgetter('shape')
  dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
  prim = standard_primitive(shape_rule, dtype_rule, name,
                            translation_rule=translation_rule)
  for registry in (batching, masking):
    registry.defvectorized(prim)
  return prim
|
2019-05-29 10:39:51 -07:00
|
|
|
standard_unop = partial(unop, _identity)
|
2018-12-12 09:00:39 -08:00
|
|
|
_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
  """Dtype rule for n-ary primitives.

  Each operand's dtype is checked against the set of accepted dtypes at the
  same position in ``accepted_dtypes``; afterwards all operand dtypes are
  passed through ``_check_same_dtypes``. The result is
  ``result_dtype(*avals)``.

  Raises:
    TypeError: if an operand's dtype is not a subtype of any dtype accepted
      at its position.
  """
  aval_dtypes = [aval.dtype for aval in avals]
  for position, (operand_dtype, allowed) in enumerate(
      zip(aval_dtypes, accepted_dtypes)):
    if any(dtypes.issubdtype(operand_dtype, t) for t in allowed):
      continue
    msg = ('{} does not accept dtype {} at position {}. '
           'Accepted dtypes at position {} are subtypes of {}.')
    typename = str(np.dtype(operand_dtype).name)
    typenames = ', '.join(t.__name__ for t in allowed)
    raise TypeError(msg.format(name, typename, position, position, typenames))
  _check_same_dtypes(name, False, *aval_dtypes)
  return result_dtype(*avals)
|
|
|
|
|
|
|
|
|
2019-02-19 11:30:31 -05:00
|
|
|
def _broadcasting_shape_rule(name, *avals):
  """Shape rule for n-ary primitives with same-rank broadcasting.

  Operands with an empty shape (scalars) are ignored. The remaining operands
  must all have the same rank, and their shapes must be accepted by
  ``_try_broadcast_shapes``.

  Args:
    name: primitive name, used only in error messages.
    *avals: abstract values exposing a ``shape`` attribute.

  Returns:
    ``()`` if every operand is a scalar, otherwise the shape returned by
    ``_try_broadcast_shapes``.

  Raises:
    TypeError: if the non-scalar operands differ in rank, or if
      ``_try_broadcast_shapes`` returns ``None`` for their shapes.
  """
  shapes = [aval.shape for aval in avals if aval.shape]
  if not shapes:
    return ()
  # Check ranks on the plain list *before* building an ndarray: stacking
  # shapes of different lengths is a ragged array, which recent NumPy
  # versions reject with a ValueError and would mask the TypeError below.
  if len({len(shape) for shape in shapes}) != 1:
    msg = '{} got arrays of different rank: {}.'
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
  result_shape = _try_broadcast_shapes(np.array(shapes))
  if result_shape is None:
    msg = '{} got incompatible shapes for broadcasting: {}.'
    raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
  return result_shape
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
def naryop(result_dtype, accepted_dtypes, name, translation_rule=None):
  """Define an n-ary broadcasting primitive called ``name``.

  The primitive uses ``_broadcasting_shape_rule`` and ``naryop_dtype_rule``
  (both bound to ``name``) as its shape and dtype rules, and is registered
  with ``batching.defbroadcasting`` and ``masking.defnaryop``.
  """
  rules = dict(
      shape_rule=partial(_broadcasting_shape_rule, name),
      dtype_rule=partial(naryop_dtype_rule, result_dtype, accepted_dtypes,
                         name))
  prim = standard_primitive(rules['shape_rule'], rules['dtype_rule'], name,
                            translation_rule=translation_rule)
  batching.defbroadcasting(prim)
  masking.defnaryop(prim)
  return prim
|
2020-01-15 13:13:11 -08:00
|
|
|
standard_naryop = partial(naryop, _input_dtype)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-05-06 16:15:17 +02:00
|
|
|
def _broadcast_translate(translate: Callable):
  """Wrap a translation rule so its positional operands are broadcast first.

  Needed only for the handful of primitives whose XLA implementations do not
  support broadcasting themselves. The returned rule queries every operand's
  dimensions from the builder, computes the common shape with
  ``broadcast_shapes``, broadcasts each operand whose shape differs, and then
  calls ``translate`` with the broadcast operands.
  """
  def _expand_operand(operand, operand_shape, target_shape):
    if operand_shape == target_shape:
      return operand
    # Map the operand's dimensions onto the trailing dimensions of the target.
    target_rank = len(target_shape)
    trailing_dims = tuple(range(target_rank - len(operand_shape), target_rank))
    return xops.BroadcastInDim(operand, target_shape, trailing_dims)

  def _broadcasted_translation_rule(c, *args, **kwargs):
    operand_shapes = [c.get_shape(arg).dimensions() for arg in args]
    target_shape = broadcast_shapes(*operand_shapes)
    expanded = []
    for operand, operand_shape in zip(args, operand_shapes):
      expanded.append(_expand_operand(operand, operand_shape, target_shape))
    return translate(c, *expanded, **kwargs)

  return _broadcasted_translation_rule
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# NOTE(mattjj): this isn't great for forward mode because it means JVPs
|
|
|
|
# get two extra ops in them: a reshape and a broadcast_in_dim (or sometimes just
|
|
|
|
# a broadcast). but saving the shape info with the primitives isn't great either
|
|
|
|
# because then we can't trace these ops without shape data.
|
|
|
|
def _brcast(x, *others):
  """Broadcast ``x`` to the common shape of ``x`` and ``others``.

  Used in jvp rules to make naryop broadcasting explicit for transposability.
  Requires shape info during jvp tracing, which isn't strictly necessary.
  Full numpy broadcasting is not needed, but the logic is otherwise the same,
  so ``broadcast_shapes`` is reused after scalar (empty) shapes are filtered
  out. ``x`` is returned unchanged when it already has the target shape.
  """
  nonscalar_shapes = tuple(
      shape for shape in map(np.shape, (x,) + others) if shape)
  if nonscalar_shapes:
    target_shape = broadcast_shapes(*nonscalar_shapes)
  else:
    # Everything is a scalar: the target is the empty shape.
    target_shape = nonscalar_shapes
  if np.shape(x) == target_shape:
    return x
  return _brcast_to(x, target_shape)
|
|
|
|
|
|
|
|
|
|
|
|
def _brcast_to(x, shape):
  """Broadcast ``x`` to ``shape``; ``x`` must not already have that shape.

  A scalar ``x`` is expanded with ``broadcast``. Otherwise ``x`` must have the
  same rank as ``shape``: the dimensions where the two shapes disagree are
  squeezed out of ``x``, and the remainder is mapped back onto the agreeing
  dimensions with ``broadcast_in_dim``.
  """
  x_shape = np.shape(x)
  assert x_shape != shape
  if not x_shape:
    return broadcast(x, shape)
  assert len(x_shape) == len(shape)
  matching_dims, = np.where(np.equal(x_shape, shape))
  mismatched_dims, = np.where(np.not_equal(x_shape, shape))
  squeezed = squeeze(x, mismatched_dims)
  return broadcast_in_dim(squeezed, shape, matching_dims)
|
|
|
|
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
_float = {np.floating}
|
|
|
|
_complex = {np.complexfloating}
|
|
|
|
_complex_elem_types = {np.float32, np.float64}
|
|
|
|
_int = {np.integer}
|
|
|
|
_bool = {np.bool_}
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
_num = _int | _float | _complex
|
|
|
|
_any = _int | _float | _complex | _bool
|
2020-01-23 11:53:55 -05:00
|
|
|
_bool_or_int = _int | _bool
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
neg_p = standard_unop(_num, 'neg')
|
|
|
|
ad.deflinear(neg_p, lambda t: [neg(t)])
|
|
|
|
|
2020-01-09 11:16:52 -05:00
|
|
|
def _sign_translation_rule(c, x):
  """XLA translation rule for ``sign``.

  For unsigned integer operands the result is built explicitly: 0 where the
  operand equals 0 and 1 everywhere else. Every other dtype is lowered to
  ``xops.Sign``.

  Args:
    c: the XLA builder; used to query the operand's shape and to create
      constants.
    x: the XLA operand.
  """
  shape = c.get_shape(x)
  dtype = shape.numpy_dtype()
  if not dtypes.issubdtype(dtype, np.unsignedinteger):
    return xops.Sign(x)
  zero = xb.constant(c, np.array(0, dtype=dtype))
  # Reuse the shape fetched above rather than querying the builder again.
  dims = shape.dimensions()
  is_zero = xops.Eq(x, zero)
  zeros = xops.Broadcast(zero, dims)
  ones = xops.Broadcast(xb.constant(c, np.array(1, dtype=dtype)), dims)
  return xops.Select(is_zero, zeros, ones)
|
2020-01-09 11:16:52 -05:00
|
|
|
|
|
|
|
sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(sign_p)
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
nextafter_p = standard_naryop(
|
2019-12-11 16:41:24 -05:00
|
|
|
[_float, _float], 'nextafter',
|
2020-04-23 18:30:47 -04:00
|
|
|
translation_rule=lambda c, x1, x2: xops.NextAfter(x1, x2))
|
2019-12-11 16:41:24 -05:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
floor_p = standard_unop(_float, 'floor')
|
|
|
|
ad.defjvp_zero(floor_p)
|
|
|
|
|
|
|
|
ceil_p = standard_unop(_float, 'ceil')
|
|
|
|
ad.defjvp_zero(ceil_p)
|
|
|
|
|
|
|
|
round_p = standard_unop(_float, 'round')
|
|
|
|
ad.defjvp_zero(round_p)
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(is_finite_p)
|
|
|
|
|
|
|
|
exp_p = standard_unop(_float | _complex, 'exp')
|
2020-03-17 22:07:53 -07:00
|
|
|
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
|
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
that can be inverted correctly if only we iterate over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
|
|
|
iad.definverse(exp_p, lambda r, x: log(r))
|
|
|
|
# For exp_p it is more efficient to use the reconstructed output for the vjp
|
|
|
|
# rule instead of computing it again from the input.
|
|
|
|
iad.primitive_ivjps[exp_p] = lambda x, y, ct: [[log(y[0])], [ct[0] * y[0]]]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
log_p = standard_unop(_float | _complex, 'log')
|
|
|
|
ad.defjvp(log_p, lambda g, x: div(g, x))
|
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
that can be inverted correctly if only we iterate over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
|
|
|
iad.definverse(log_p, lambda r, x: exp(r))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
expm1_p = standard_unop(_float | _complex, 'expm1')
|
|
|
|
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
|
|
|
|
|
|
|
|
log1p_p = standard_unop(_float | _complex, 'log1p')
|
|
|
|
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
|
|
|
|
|
|
|
|
tanh_p = standard_unop(_float | _complex, 'tanh')
|
2019-05-24 11:07:08 -04:00
|
|
|
ad.defjvp2(tanh_p, lambda g, ans, x: mul(g, sub(_one(x), mul(ans, ans))))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
sin_p = standard_unop(_float | _complex, 'sin')
|
|
|
|
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
|
|
|
|
|
|
|
cos_p = standard_unop(_float | _complex, 'cos')
|
|
|
|
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
atan2_p = standard_naryop([_float, _float], 'atan2')
|
2019-04-17 19:53:06 -04:00
|
|
|
ad.defjvp(atan2_p,
|
2019-04-17 20:54:01 -04:00
|
|
|
lambda g, x, y: _brcast(g, y) * (y / (square(x) + square(y))),
|
|
|
|
lambda g, x, y: _brcast(g, x) * -x / (square(x) + square(y)))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-19 07:29:37 -07:00
|
|
|
sinh_p = standard_unop(_float | _complex, 'sinh')
|
|
|
|
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
|
|
|
|
|
|
|
|
cosh_p = standard_unop(_float | _complex, 'cosh')
|
|
|
|
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
|
|
|
|
|
|
|
|
asinh_p = standard_unop(_float | _complex, 'asinh')
|
|
|
|
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
|
|
|
|
|
|
|
|
acosh_p = standard_unop(_float | _complex, 'acosh')
|
|
|
|
ad.defjvp(acosh_p,
|
|
|
|
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
|
|
|
|
|
|
|
|
atanh_p = standard_unop(_float | _complex, 'atanh')
|
|
|
|
ad.defjvp(atanh_p,
|
|
|
|
lambda g, x: mul(g, reciprocal((_one(x) - x) * (_one(x) + x))))
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
regularized_incomplete_beta_p = standard_naryop(
|
2020-05-06 16:15:17 +02:00
|
|
|
[_float, _float, _float], 'regularized_incomplete_beta',
|
|
|
|
translation_rule=_broadcast_translate(
|
|
|
|
partial(standard_translate, 'regularized_incomplete_beta')))
|
2020-01-15 13:13:11 -08:00
|
|
|
|
|
|
|
def betainc_gradx(g, a, b, x):
  """JVP of the regularized incomplete beta function with respect to ``x``.

  Scales the tangent ``g`` by
  ``exp((b - 1) * log1p(-x) + (a - 1) * log(x) - lbeta)``, where
  ``lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)``.
  """
  lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)
  log_partial_x = (b - 1) * log1p(-x) + (a - 1) * log(x) - lbeta
  return exp(log_partial_x) * g
|
|
|
|
|
|
|
|
def betainc_grad_not_implemented(g, a, b, x):
  """JVP placeholder for the ``a`` and ``b`` arguments; always raises.

  Raises:
    ValueError: unconditionally.
  """
  message = "Betainc gradient with respect to a and b not supported."
  raise ValueError(message)
|
|
|
|
|
|
|
|
ad.defjvp(regularized_incomplete_beta_p,
|
|
|
|
betainc_grad_not_implemented,
|
|
|
|
betainc_grad_not_implemented,
|
|
|
|
betainc_gradx)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
lgamma_p = standard_unop(_float, 'lgamma')
|
|
|
|
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
|
|
|
|
|
|
|
|
digamma_p = standard_unop(_float, 'digamma')
|
|
|
|
|
2020-05-06 16:15:17 +02:00
|
|
|
igamma_p = standard_naryop(
|
|
|
|
[_float, _float], 'igamma',
|
|
|
|
translation_rule=_broadcast_translate(partial(standard_translate, 'igamma')))
|
|
|
|
igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a',
|
|
|
|
translation_rule=_broadcast_translate(partial(standard_translate,
|
|
|
|
'igamma_grad_a')))
|
2020-01-29 08:25:21 -08:00
|
|
|
|
|
|
|
def igamma_gradx(g, a, x):
  """JVP of ``igamma`` with respect to ``x``.

  The tangent ``g`` is broadcast against ``a`` and ``x`` and scaled by
  ``exp(-x + (a - 1) * log(x) - lgamma(a))``.
  """
  tangent = _brcast(g, a, x)
  log_derivative = -x + (a - _ones(a)) * log(x) - lgamma(a)
  return tangent * exp(log_derivative)
|
2020-01-29 08:25:21 -08:00
|
|
|
|
2020-05-05 17:10:31 -07:00
|
|
|
def igamma_grada(g, a, x):
  """JVP of ``igamma`` with respect to ``a``.

  The tangent ``g`` is broadcast against ``a`` and ``x`` and scaled by
  ``igamma_grad_a(a, x)``.
  """
  tangent = _brcast(g, a, x)
  derivative = igamma_grad_a(a, x)
  return tangent * derivative
|
2020-01-29 08:25:21 -08:00
|
|
|
|
2020-05-05 17:10:31 -07:00
|
|
|
ad.defjvp(igamma_p, igamma_grada, igamma_gradx)
|
2020-01-29 08:25:21 -08:00
|
|
|
|
2020-05-06 16:15:17 +02:00
|
|
|
igammac_p = standard_naryop(
|
|
|
|
[_float, _float], 'igammac',
|
|
|
|
translation_rule=_broadcast_translate(partial(standard_translate, 'igammac')))
|
2020-01-29 08:25:21 -08:00
|
|
|
|
|
|
|
def igammac_gradx(g, a, x):
  """JVP of ``igammac`` with respect to ``x``: the negation of ``igamma_gradx``."""
  igamma_tangent = igamma_gradx(g, a, x)
  return -igamma_tangent
|
|
|
|
|
2020-05-05 17:10:31 -07:00
|
|
|
def igammac_grada(g, a, x):
  """JVP of ``igammac`` with respect to ``a``: the negation of ``igamma_grada``."""
  igamma_tangent = igamma_grada(g, a, x)
  return -igamma_tangent
|
|
|
|
|
|
|
|
ad.defjvp(igammac_p, igammac_grada, igammac_gradx)
|
2020-01-29 08:25:21 -08:00
|
|
|
|
2020-06-19 06:34:18 -07:00
|
|
|
random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad',
|
|
|
|
translation_rule=_broadcast_translate(partial(standard_translate,
|
|
|
|
'random_gamma_grad')))
|
|
|
|
|
2019-10-21 10:30:55 -04:00
|
|
|
bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
|
|
|
|
ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
|
|
|
|
|
|
|
|
bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
|
|
|
|
def _bessel_i1e_jvp(g, y, x):
  """JVP rule for ``bessel_i1e``, given tangent ``g``, output ``y``, input ``x``.

  Where ``abs(x)`` exceeds the machine epsilon of ``x``'s dtype, the
  derivative is ``bessel_i0e(x) - y * (sign(x) + 1 / x)``. Elsewhere ``x`` is
  replaced by epsilon before the formula is evaluated (so the reciprocal is
  never taken at a tiny value) and the derivative is then set to 0.5.
  """
  eps = dtypes.finfo(_dtype(x)).eps
  away_from_zero = abs(x) > eps
  # Substitute eps for tiny inputs so the reciprocal below is evaluated safely.
  guarded_x = select(away_from_zero, x, full_like(x, eps))
  derivative = bessel_i0e(guarded_x) - y * (sign(guarded_x)
                                            + reciprocal(guarded_x))
  derivative = select(away_from_zero, derivative, full_like(x, 0.5))
  return g * derivative
|
|
|
|
ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
erf_p = standard_unop(_float, 'erf')
|
2020-07-14 13:05:31 -07:00
|
|
|
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
|
2018-11-17 18:03:33 -08:00
|
|
|
mul(g, exp(neg(square(x))))))
|
|
|
|
|
|
|
|
erfc_p = standard_unop(_float, 'erfc')
|
2020-07-14 13:05:31 -07:00
|
|
|
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
|
2018-11-17 18:03:33 -08:00
|
|
|
mul(neg(g), exp(neg(square(x))))))
|
|
|
|
|
|
|
|
erf_inv_p = standard_unop(_float, 'erf_inv')
|
2020-07-14 13:05:31 -07:00
|
|
|
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
|
2018-11-17 18:03:33 -08:00
|
|
|
mul(g, exp(square(ans)))))
|
|
|
|
|
2019-01-11 18:22:43 -05:00
|
|
|
real_p = unop(_complex_basetype, _complex, 'real')
|
2020-07-14 13:05:31 -07:00
|
|
|
ad.deflinear(real_p, lambda t: [complex(t, np.zeros((), _dtype(t)))])
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-01-11 18:22:43 -05:00
|
|
|
imag_p = unop(_complex_basetype, _complex, 'imag')
|
2019-07-05 14:32:04 -07:00
|
|
|
ad.defjvp(imag_p, lambda g, _: real(mul(_const(g, -1j), g)))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
_complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype
|
2020-01-15 13:13:11 -08:00
|
|
|
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
|
2019-01-11 18:22:43 -05:00
|
|
|
'complex')
|
2019-07-05 14:39:32 -07:00
|
|
|
ad.deflinear(complex_p, lambda t: [real(t), imag(neg(t))])
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-10-22 19:53:59 -04:00
|
|
|
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
|
2018-12-12 09:00:39 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _conj_transpose_rule(t, x, *, input_dtype):
  """Transpose rule for ``conj``.

  Returns a one-element list holding ``conj(t)`` when ``input_dtype`` is a
  complex floating type, and ``real(t)`` otherwise. ``x`` must be an
  undefined primal.
  """
  assert ad.is_undefined_primal(x)
  input_is_complex = dtypes.issubdtype(input_dtype, np.complexfloating)
  cotangent = conj(t) if input_is_complex else real(t)
  return [cotangent]
|
|
|
|
|
2020-04-23 18:30:47 -04:00
|
|
|
xla.translations[conj_p] = lambda c, x, **kwargs: xops.Conj(x)
|
2018-12-12 09:00:39 -08:00
|
|
|
ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
|
2019-02-01 13:42:16 -05:00
|
|
|
ad.primitive_transposes[conj_p] = _conj_transpose_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
abs_p = unop(_complex_basetype, _num, 'abs')
|
2019-09-18 23:55:31 -07:00
|
|
|
|
|
|
|
def _abs_jvp_rule(g, ans, x):
  """JVP rule for `abs_p`.

  Args:
    g: tangent of the operand.
    ans: the primal output, `abs(x)`.
    x: the primal operand.

  Returns:
    The output tangent. For real operands this is `g` where `x >= 0` and
    `-g` elsewhere. For complex operands it is the real part of
    `g * conj(x) / abs(x)`, with zeros in the denominator replaced by ones.
  """
  if not _iscomplex(x):
    return select(ge(x, _zero(x)), g, neg(g))
  conj_x = _maybe_conj(x)
  # `ans` is cast to the operand dtype before dividing; zeros are replaced
  # by ones so the division never sees a zero denominator.
  safe_abs = _replace_zero(convert_element_type(ans, _dtype(x)))
  return _maybe_real(mul(g, div(conj_x, safe_abs)))
|
|
|
|
ad.defjvp2(abs_p, _abs_jvp_rule)
|
2018-11-17 18:03:33 -08:00
|
|
|
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
|
|
|
|
_maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
|
|
|
|
2019-05-29 12:51:24 -04:00
|
|
|
sqrt_p = standard_unop(_float | _complex, 'sqrt')
|
2020-03-17 22:07:53 -07:00
|
|
|
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
|
2019-05-29 12:51:24 -04:00
|
|
|
|
2019-09-04 15:06:46 -07:00
|
|
|
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
|
|
|
|
ad.defjvp2(rsqrt_p,
|
|
|
|
lambda g, ans, x:
|
2020-03-17 22:07:53 -07:00
|
|
|
mul(g, mul(_const(x, -0.5), pow(x, _const(x, -1.5)))))
|
2019-09-04 15:06:46 -07:00
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
|
2019-02-15 07:04:57 -08:00
|
|
|
|
2019-10-28 22:37:01 +08:00
|
|
|
def _pow_jvp_lhs(g, ans, x, y):
  """JVP of `pow(x, y)` with respect to the base `x`.

  Computes `g * y * x ** (y - 1)`. Where `y == 0` the exponent `y - 1` is
  replaced by 1; the factor `y` is zero there, so the result is still zero.
  """
  exponent_is_zero = eq(y, _zeros(y))
  safe_exponent = select(exponent_is_zero, _ones(y), sub(y, _ones(y)))
  jac = mul(y, pow(x, safe_exponent))
  return mul(_brcast(g, y), jac)
|
2019-02-15 07:04:57 -08:00
|
|
|
|
2019-10-28 22:37:01 +08:00
|
|
|
def _pow_jvp_rhs(g, ans, x, y):
  """JVP of `pow(x, y)` with respect to the exponent `y`.

  Computes `g * log(x) * ans`, where zeros in `x` are replaced by ones
  before taking the logarithm.
  """
  g_bcast = _brcast(g, x)
  log_x = log(_replace_zero(x))
  return mul(g_bcast, mul(log_x, ans))
|
2019-02-15 07:04:57 -08:00
|
|
|
|
2019-10-28 22:37:01 +08:00
|
|
|
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
2020-05-18 17:54:20 -04:00
|
|
|
|
|
|
|
|
|
|
|
def _integer_pow_dtype_rule(x, *, y):
  """Dtype rule for `integer_pow_p`.

  Args:
    x: abstract operand.
    y: the static Python-integer exponent.

  Returns:
    The result dtype as computed by `unop_dtype_rule` for `x`.

  Raises:
    TypeError: if `y` is negative and the operand dtype is an integer type.
  """
  dtype = unop_dtype_rule(_identity, _int | _float | _complex, 'integer_pow', x)
  negative_power_of_integer = y < 0 and dtypes.issubdtype(dtype, np.integer)
  if negative_power_of_integer:
    raise TypeError("Integers cannot be raised to negative powers, got "
                    f"integer_pow({x}, {y})")
  return dtype
|
|
|
|
|
|
|
|
def _integer_pow_translation_rule(c, x, *, y):
|
|
|
|
if y == 0:
|
|
|
|
shape = c.get_shape(x)
|
2020-07-14 13:05:31 -07:00
|
|
|
return xb.constant(c, np.array(1, dtype=shape.numpy_dtype()))
|
2020-05-18 17:54:20 -04:00
|
|
|
is_reciprocal = y < 0
|
|
|
|
if is_reciprocal:
|
|
|
|
y = -y
|
|
|
|
acc = None
|
|
|
|
while y > 0:
|
|
|
|
if y & 1:
|
|
|
|
acc = x if acc is None else xops.Mul(acc, x)
|
|
|
|
y >>= 1
|
|
|
|
if y > 0:
|
|
|
|
x = xops.Mul(x, x)
|
|
|
|
return xops.Reciprocal(acc) if is_reciprocal else acc
|
|
|
|
|
|
|
|
def _integer_pow_jvp(g, x, *, y):
|
|
|
|
return g if y == 0 else mul(g, mul(_const(x, y), integer_pow(x, y - 1)))
|
|
|
|
|
|
|
|
integer_pow_p = standard_primitive(
|
|
|
|
_attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow',
|
|
|
|
translation_rule=_integer_pow_translation_rule)
|
|
|
|
batching.defvectorized(integer_pow_p)
|
|
|
|
masking.defvectorized(integer_pow_p)
|
|
|
|
ad.defjvp(integer_pow_p, _integer_pow_jvp)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
|
|
|
|
|
2020-01-23 11:53:55 -05:00
|
|
|
not_p = standard_unop(_bool_or_int, 'not')
|
2020-06-16 22:48:25 -04:00
|
|
|
ad.defjvp_zero(not_p)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-23 11:53:55 -05:00
|
|
|
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(and_p)
|
|
|
|
|
2020-01-23 11:53:55 -05:00
|
|
|
or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(or_p)
|
|
|
|
|
2020-01-23 11:53:55 -05:00
|
|
|
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(xor_p)
|
|
|
|
|
2020-07-28 19:46:00 -07:00
|
|
|
population_count_p = standard_unop(_int, 'population_count')
|
2020-04-28 06:32:52 +01:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _add_transpose(t, x, y):
|
2020-01-15 15:00:38 -08:00
|
|
|
# The following linearity assertion is morally true, but because in some cases we
|
|
|
|
# instantiate zeros for convenience, it doesn't always hold.
|
|
|
|
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
|
Add Cholesky, QR, and Triangular solve implementations.
* Adds lax.{cholesky,triangular_solve,qr}. Adds a JVP for Cholesky.
* Adds a transpose rule for add_p, needed by the Cholesky JVP.
* Adds np.linalg.{cholesky,qr,dot,matmul,trace}.
* Adds scipy.linalg.{cholesky,qr,solve_triangular,tril,triu}.
Pair programmed with mattjj.
2018-12-13 13:03:08 -05:00
|
|
|
return [t, t]
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
add_p = standard_naryop([_num, _num], 'add')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp(add_p, lambda g, x, y: _brcast(g, y), lambda g, x, y: _brcast(g, x))
|
2019-02-01 13:42:16 -05:00
|
|
|
ad.primitive_transposes[add_p] = _add_transpose
|
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
that can be inverted correctly if only we iterate over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
|
|
|
def _add_inverse(r, x, y):
|
|
|
|
xr = r - y
|
|
|
|
yr = r - x
|
|
|
|
return xr, yr
|
|
|
|
iad.definverse(add_p, _add_inverse)
|
2019-02-19 11:45:16 -05:00
|
|
|
|
|
|
|
def _sub_transpose(t, x, y):
  """Transpose rule for `sub_p`.

  Args:
    t: cotangent of the output of `x - y`.
    x: first operand (unused).
    y: second operand (unused).

  Returns:
    `[t, -t]`. When `t` is a symbolic zero (an `ad_util.Zero` instance) it is
    passed through unchanged for both operands instead of being negated.
  """
  # The following linearity assertion is morally true, but because in some cases
  # we instantiate zeros for convenience, it doesn't always hold.
  # assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
  if type(t) is ad_util.Zero:
    # Return the Zero *instance*; returning the `ad_util.Zero` class itself
    # would not be recognised by `type(ct) is ad_util.Zero` checks.
    return [t, t]
  return [t, neg(t)]
|
2019-02-19 11:45:16 -05:00
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
sub_p = standard_naryop([_num, _num], 'sub')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp(sub_p,
|
|
|
|
lambda g, x, y: _brcast(g, y),
|
|
|
|
lambda g, x, y: _brcast(neg(g), x))
|
2019-02-19 11:45:16 -05:00
|
|
|
ad.primitive_transposes[sub_p] = _sub_transpose
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
mul_p = standard_naryop([_num, _num], 'mul')
|
2019-06-17 11:49:54 -07:00
|
|
|
ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul)
|
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
that can be inverted correctly if only we iterate over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
|
|
|
def _mul_inverse(r, x, y):
|
|
|
|
xr = r / y
|
|
|
|
yr = r / x
|
|
|
|
return xr, yr
|
|
|
|
iad.definverse(mul_p, _mul_inverse)
|
2019-02-15 18:32:50 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _div_transpose_rule(cotangent, x, y):
  """Transpose rule for `div_p`, which is linear in the numerator only.

  Args:
    cotangent: cotangent of the output of `x / y`.
    x: the numerator; must be an undefined (linear) primal.
    y: the denominator; must be a known primal value.

  Returns:
    A pair `(x_cotangent, None)`, where `x_cotangent` is `cotangent / y`, or
    the incoming symbolic zero when `cotangent` is an `ad_util.Zero`.
  """
  assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
  if type(cotangent) is ad_util.Zero:
    # Propagate the Zero *instance*; the bare `ad_util.Zero` class would not
    # be recognised by `type(ct) is ad_util.Zero` checks.
    res = cotangent
  else:
    res = div(cotangent, y)
  return res, None
|
2020-01-15 13:13:11 -08:00
|
|
|
div_p = standard_naryop([_num, _num], 'div')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp(div_p,
|
|
|
|
lambda g, x, y: div(_brcast(g, y), y),
|
2020-05-18 17:54:20 -04:00
|
|
|
lambda g, x, y: mul(mul(neg(_brcast(g, x)), x), integer_pow(y, -2)))
|
2019-02-01 13:42:16 -05:00
|
|
|
ad.primitive_transposes[div_p] = _div_transpose_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
rem_p = standard_naryop([_num, _num], 'rem')
# rem is a truncating remainder (the result takes the sign of the dividend):
#   rem(x, y) = x - y * trunc(x / y)
# so d/dx = 1 and d/dy = -trunc(x / y). The quotient must be rounded toward
# zero, not toward -inf; `floor` alone is wrong whenever x / y is negative and
# non-integral. trunc(q) is expressed here as sign(q) * floor(abs(q)).
ad.defjvp(
    rem_p,
    lambda g, x, y: _brcast(g, y),
    lambda g, x, y: mul(_brcast(neg(g), x),
                        mul(sign(div(x, y)), floor(abs(div(x, y))))))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-01 11:07:45 -05:00
|
|
|
def _broadcasting_select(c, which, x, y):
  """Wrapper around XLA `Select` that broadcasts its arguments.

  Each of `which`, `x` and `y` is broadcast to their common broadcast shape,
  with its own dimensions mapped onto the trailing dimensions of that shape,
  before `Select` is applied.
  """
  operands = (which, x, y)
  in_shapes = [c.get_shape(operand).dimensions() for operand in operands]
  out_shape = broadcast_shapes(*in_shapes)
  out_rank = len(out_shape)

  def _broadcast(operand, shape):
    # Align the operand's dimensions with the trailing output dimensions.
    trailing_dims = tuple(range(out_rank - len(shape), out_rank))
    return xops.BroadcastInDim(operand, out_shape, trailing_dims)

  which, x, y = [_broadcast(operand, shape)
                 for operand, shape in zip(operands, in_shapes)]
  return xops.Select(which, x, y)
|
2019-02-01 11:07:45 -05:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
  """Shared XLA translation rule for `max_p` and `min_p`.

  Args:
    c: the XLA computation builder.
    x: first XLA operand.
    y: second XLA operand.
    minmax: the XLA elementwise op used for non-complex dtypes.
    cmp: the XLA comparison op used to order complex values.

  Returns:
    For non-complex operands, `minmax(x, y)`. For complex operands, a select
    between `x` and `y` that compares real parts with `cmp` and falls back to
    comparing imaginary parts where the real parts are equal.
  """
  dtype = c.get_shape(x).numpy_dtype()
  if not dtypes.issubdtype(dtype, np.complexfloating):
    return minmax(x, y)
  rx = xops.Real(x)
  ry = xops.Real(y)
  real_parts_equal = xops.Eq(rx, ry)
  imag_cmp = cmp(xops.Imag(x), xops.Imag(y))
  real_cmp = cmp(rx, ry)
  pick_x = xops.Select(real_parts_equal, imag_cmp, real_cmp)
  return _broadcasting_select(c, pick_x, x, y)
|
2019-02-01 11:07:45 -05:00
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
max_p = standard_naryop([_any, _any], 'max', translation_rule=partial(
|
2020-04-23 18:30:47 -04:00
|
|
|
_minmax_translation_rule, minmax=xops.Max, cmp=xops.Gt))
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp2(max_p,
|
|
|
|
lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
|
|
|
|
lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
min_p = standard_naryop([_any, _any], 'min', translation_rule=partial(
|
2020-04-23 18:30:47 -04:00
|
|
|
_minmax_translation_rule, minmax=xops.Min, cmp=xops.Lt))
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp2(min_p,
|
|
|
|
lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
|
|
|
|
lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))
|
|
|
|
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
shift_left_p = standard_naryop([_int, _int], 'shift_left')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(shift_left_p)
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(shift_right_arithmetic_p)
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(shift_right_logical_p)
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(eq_p)
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(ne_p)
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
ge_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ge')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(ge_p)
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
gt_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'gt')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(gt_p)
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
le_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'le')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(le_p)
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
lt_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'lt')
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(lt_p)
|
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _convert_element_type_shape_rule(operand, *, new_dtype, old_dtype):
|
2018-11-17 18:03:33 -08:00
|
|
|
return operand.shape
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _convert_element_type_dtype_rule(operand, *, new_dtype, old_dtype):
|
2018-11-17 18:03:33 -08:00
|
|
|
return new_dtype
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _convert_element_type_translation_rule(c, operand, *, new_dtype, old_dtype):
  """XLA translation rule for `convert_element_type`.

  When converting from a complex dtype to a non-complex dtype, the real part
  of the operand is taken first; the (possibly reduced) operand is then
  converted to the XLA element type corresponding to `new_dtype`.
  """
  drops_imaginary_part = (
      dtypes.issubdtype(old_dtype, np.complexfloating) and
      not dtypes.issubdtype(new_dtype, np.complexfloating))
  if drops_imaginary_part:
    operand = xops.Real(operand)
  return xops.ConvertElementType(
      operand, new_element_type=xla_client.dtype_to_etype(new_dtype))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-12 15:35:35 -04:00
|
|
|
def _convert_element_type_transpose_rule(t, *, new_dtype, old_dtype):
  """Transpose rule for `convert_element_type`.

  The cotangent `t` (which must have dtype `new_dtype`) is converted back to
  `old_dtype` by binding the same primitive with the two dtypes swapped.
  """
  assert t.dtype == new_dtype, (t.dtype, new_dtype)
  swapped_params = dict(new_dtype=old_dtype, old_dtype=new_dtype)
  return [convert_element_type_p.bind(t, **swapped_params)]
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
convert_element_type_p = standard_primitive(
|
2019-02-01 13:42:16 -05:00
|
|
|
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
|
|
|
|
'convert_element_type', _convert_element_type_translation_rule)
|
2020-04-12 15:35:35 -04:00
|
|
|
ad.deflinear(convert_element_type_p, _convert_element_type_transpose_rule)
|
2018-11-17 18:03:33 -08:00
|
|
|
batching.defvectorized(convert_element_type_p)
|
2019-09-03 17:09:27 -07:00
|
|
|
masking.defvectorized(convert_element_type_p)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
|
2018-11-17 18:03:33 -08:00
|
|
|
return operand.shape
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):
|
2018-11-17 18:03:33 -08:00
|
|
|
return new_dtype
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _bitcast_convert_type_translation_rule(c, operand, *, new_dtype):
  """XLA translation rule for `bitcast_convert_type`.

  Emits an XLA `BitcastConvertType` of `operand` to the XLA element type
  corresponding to `new_dtype`.
  """
  return xops.BitcastConvertType(
      operand, new_element_type=xla_bridge.dtype_to_etype(new_dtype))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
bitcast_convert_type_p = standard_primitive(
|
2019-02-01 13:42:16 -05:00
|
|
|
_bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule,
|
|
|
|
'bitcast_convert_type', _bitcast_convert_type_translation_rule)
|
2018-11-17 18:03:33 -08:00
|
|
|
ad.defjvp_zero(bitcast_convert_type_p)
|
|
|
|
batching.defvectorized(bitcast_convert_type_p)
|
2019-09-03 17:09:27 -07:00
|
|
|
masking.defvectorized(bitcast_convert_type_p)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _conv_general_dilated_shape_rule(
    lhs: ShapedArray, rhs: ShapedArray, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, **unused_kwargs) -> Tuple[int, ...]:
  """Validates grouping parameters and computes the conv output shape.

  Args:
    lhs: abstract value of the convolution input.
    rhs: abstract value of the convolution kernel.
    window_strides: strides, forwarded to `conv_shape_tuple`.
    padding: padding, forwarded to `conv_shape_tuple`.
    lhs_dilation: dilation applied to the permuted lhs shape.
    rhs_dilation: dilation applied to the permuted rhs shape.
    dimension_numbers: a `ConvDimensionNumbers` of (lhs, rhs, out) specs.
    feature_group_count: number of feature groups; must be positive.
    batch_group_count: number of batch groups; must be positive.
    **unused_kwargs: ignored.

  Returns:
    The output shape, laid out according to the output spec of
    `dimension_numbers`.

  Raises:
    ValueError: if either group count is non-positive, does not divide the
      relevant lhs/rhs dimension sizes, or if both counts are greater than 1.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  lhs_spec = dimension_numbers.lhs_spec
  rhs_spec = dimension_numbers.rhs_spec

  # --- feature_group_count checks ---
  if not feature_group_count > 0:
    raise ValueError(
        ("conv_general_dilated feature_group_count "
         "must be a positive integer, got {}.").format(feature_group_count))
  lhs_feature_count = lhs.shape[lhs_spec[1]]
  quot, rem = divmod(lhs_feature_count, feature_group_count)
  if rem:
    raise ValueError(
        ("conv_general_dilated feature_group_count must divide lhs feature "
         "dimension size, but {} does not divide {}.").format(
             feature_group_count, lhs_feature_count))
  if quot != rhs.shape[rhs_spec[1]]:
    raise ValueError(
        ("conv_general_dilated lhs feature dimension size divided by "
         "feature_group_count must equal the rhs input feature dimension "
         "size, but {} // {} != {}.").format(
             lhs_feature_count, feature_group_count, rhs.shape[rhs_spec[1]]))
  if rhs.shape[rhs_spec[0]] % feature_group_count:
    raise ValueError(
        ("conv_general_dilated rhs output feature dimension size must be a "
         "multiple of feature_group_count, but {} is not a multiple of {}."
        ).format(rhs.shape[rhs_spec[0]], feature_group_count))

  # --- batch_group_count checks ---
  if not batch_group_count > 0:
    raise ValueError(
        ("conv_general_dilated batch_group_count "
         "must be a positive integer, got {}.").format(batch_group_count))
  lhs_batch_count = lhs.shape[lhs_spec[0]]
  if lhs_batch_count % batch_group_count != 0:
    raise ValueError(
        ("conv_general_dilated batch_group_count must divide lhs batch "
         "dimension size, but {} does not divide {}.").format(
             batch_group_count, lhs_batch_count))
  if rhs.shape[rhs_spec[0]] % batch_group_count:
    raise ValueError(
        ("conv_general_dilated rhs output feature dimension size must be a "
         "multiple of batch_group_count, but {} is not a multiple of {}."
        ).format(rhs.shape[rhs_spec[0]], batch_group_count))

  # The two kinds of grouping are mutually exclusive.
  if batch_group_count > 1 and feature_group_count > 1:
    raise ValueError(
        ("At most one of batch_group_count and feature_group_count may be > "
         "1, got batch_group_count={} and feature_group_count={}").format(
             batch_group_count, feature_group_count))

  # Permute the shapes with the specs, dilate, compute the output shape in
  # that layout, then undo the output permutation.
  lhs_perm, rhs_perm, out_perm = dimension_numbers
  lhs_trans = _dilate_shape(np.take(lhs.shape, lhs_perm), lhs_dilation)
  rhs_trans = _dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
                               batch_group_count)
  return tuple(np.take(out_trans, np.argsort(out_perm)))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _conv_general_dilated_dtype_rule(
    lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers, **unused_kwargs):
  """Dtype rule for `conv_general_dilated`.

  Delegates to `naryop_dtype_rule`, accepting float or complex dtypes for
  both operands; the convolution parameters do not affect the result.
  """
  accepted_dtypes = [_float | _complex, _float | _complex]
  return naryop_dtype_rule(_input_dtype, accepted_dtypes,
                           'conv_general_dilated', lhs, rhs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-06-15 13:38:55 -07:00
|
|
|
_conv_spec_transpose = lambda spec: (spec[1], spec[0]) + spec[2:]
|
2018-12-10 17:18:56 -08:00
|
|
|
_conv_sdims = lambda spec: spec[2:]
|
|
|
|
|
2020-04-09 16:21:30 -04:00
|
|
|
# Understanding the convolution transpose rules:
|
|
|
|
# Ignoring the spatial dimensions, let m = batch, j = input feature,
|
|
|
|
# k = output feature.
|
|
|
|
#
|
|
|
|
# Convolution computes the following contraction:
|
|
|
|
# Forward: [m, j] [j, k] -> [m, k]
|
|
|
|
#
|
|
|
|
# The transposes are similar to the rules for transposing a matmul:
|
|
|
|
# LHS transpose: [m, k] [k, j] -> [m, j]
|
|
|
|
# RHS transpose: [j, m] [m, k] -> [j, k]
|
|
|
|
#
|
|
|
|
# With feature grouping, we have the following signatures:
|
|
|
|
# Forward: [m, gj] [j, gk] -> [m, gk]
|
|
|
|
# LHS transpose: [m, gk] [k, gj] -> [m, gj]
|
|
|
|
# --> implemented as feature grouping after transposing the group from the
|
|
|
|
# kernel input features to the kernel output features.
|
|
|
|
# RHS transpose: [gj, m] [m, gk] -> [j, gk]
|
|
|
|
# --> which is batch grouping.
|
|
|
|
#
|
|
|
|
# With batch grouping, we have the following signatures:
|
|
|
|
# Forward: [gm,j] [j,gk]->[m,gk]
|
|
|
|
# LHS transpose: [m, gk][gk, j] -> [gm, j]
|
|
|
|
# --> implemented as feature grouping with transposing the group on the kernel
|
|
|
|
# and the output.
|
|
|
|
# RHS transpose: [j, gm][m, gk] -> [j, gk]
|
|
|
|
# --> which is feature grouping.
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _conv_general_dilated_transpose_lhs(
    g, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers, feature_group_count, batch_group_count,
    lhs_shape, rhs_shape, precision):
  """Transpose of `conv_general_dilated` with respect to its lhs operand.

  The lhs cotangent is itself a convolution of the output cotangent `g` with
  the spatially reversed kernel, with strides and lhs dilation swapped and
  the kernel's input/output feature dimensions exchanged.

  Args:
    g: cotangent of the convolution output.
    rhs: the kernel.
    window_strides, padding, lhs_dilation, rhs_dilation: parameters of the
      forward convolution.
    dimension_numbers: `ConvDimensionNumbers` of the forward convolution.
    feature_group_count: forward feature group count.
    batch_group_count: forward batch group count (at most one of the two
      group counts may differ from 1).
    lhs_shape: shape of the forward lhs operand.
    rhs_shape: shape of the forward rhs operand.
    precision: precision forwarded to the transposed convolution.

  Returns:
    The cotangent for the lhs operand.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  assert batch_group_count == 1 or feature_group_count == 1
  lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  t_rhs_spec = _conv_spec_transpose(rhs_spec)

  # Both kinds of grouping are transposed as a feature-grouped convolution;
  # with batch grouping the batch group count takes the role of the feature
  # group count.
  if not feature_group_count > 1 and batch_group_count > 1:
    feature_group_count = batch_group_count
  if feature_group_count > 1:
    # In addition to switching the dims in the spec, move the group axis into
    # the transposed rhs's output feature dim.
    rhs = _reshape_axis_out_of(rhs_spec[0], feature_group_count, rhs)
    rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)

  trans_dimension_numbers = ConvDimensionNumbers(out_spec, t_rhs_spec, lhs_spec)
  padding = _conv_general_vjp_lhs_padding(
      np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims),
      window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation,
      rhs_dilation)
  revd_weights = rev(rhs, rhs_sdims)
  out = conv_general_dilated(
      g, revd_weights, window_strides=lhs_dilation, padding=padding,
      lhs_dilation=window_strides, rhs_dilation=rhs_dilation,
      dimension_numbers=trans_dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=1, precision=precision)
  if batch_group_count > 1:
    # Move the group axis from the feature dim of the result back into its
    # batch dim.
    out = _reshape_axis_out_of(lhs_spec[1], batch_group_count, out)
    out = _reshape_axis_into(lhs_spec[1], lhs_spec[0], out)
  return out
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _conv_general_dilated_transpose_rhs(
    g, lhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers: ConvDimensionNumbers, feature_group_count: int,
    batch_group_count: int, lhs_shape, rhs_shape, precision):
  """Transpose of `conv_general_dilated` with respect to its rhs operand.

  The kernel cotangent is a convolution of the forward input `lhs` with the
  output cotangent `g`, with strides and rhs dilation swapped, the batch and
  feature dimensions of every spec exchanged, and the roles of feature and
  batch grouping swapped.

  Args:
    g: cotangent of the convolution output.
    lhs: the forward convolution input.
    window_strides, padding, lhs_dilation, rhs_dilation: parameters of the
      forward convolution.
    dimension_numbers: `ConvDimensionNumbers` of the forward convolution.
    feature_group_count: forward feature group count.
    batch_group_count: forward batch group count (at most one of the two
      group counts may differ from 1).
    lhs_shape: shape of the forward lhs operand.
    rhs_shape: shape of the forward rhs operand.
    precision: precision forwarded to the transposed convolution.

  Returns:
    The cotangent for the rhs operand, or `ad_util.Zero` when `g` is empty.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  if np.size(g) == 0:
    # Avoids forming degenerate convolutions where the RHS has spatial size 0.
    # NOTE(review): this returns the `ad_util.Zero` class itself, whereas
    # other rules in this file test `type(x) is ad_util.Zero` (i.e. expect
    # instances) -- confirm what callers accept here.
    return ad_util.Zero
  lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
  lhs_trans, rhs_trans, out_trans = map(_conv_spec_transpose, dimension_numbers)
  assert batch_group_count == 1 or feature_group_count == 1

  # Batch grouping transposes to feature grouping and vice versa.
  if batch_group_count > 1:
    feature_group_count, batch_group_count = batch_group_count, 1
  elif feature_group_count > 1:
    batch_group_count, feature_group_count = feature_group_count, 1

  trans_dimension_numbers = ConvDimensionNumbers(lhs_trans, out_trans, rhs_trans)
  padding = _conv_general_vjp_rhs_padding(
      np.take(lhs_shape, lhs_sdims), np.take(rhs_shape, rhs_sdims),
      window_strides, np.take(g.shape, out_sdims), padding, lhs_dilation,
      rhs_dilation)
  return conv_general_dilated(
      lhs, g, window_strides=rhs_dilation, padding=padding,
      lhs_dilation=lhs_dilation, rhs_dilation=window_strides,
      dimension_numbers=trans_dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count, precision=precision)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-13 14:44:24 -04:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _conv_general_dilated_translation_rule(
    c, lhs, rhs, *, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
    batch_group_count, precision, expand_complex_convolutions, **unused_kwargs):
  """XLA translation rule for `conv_general_dilated`.

  Args:
    c: the XLA computation builder.
    lhs: XLA op for the convolution input.
    rhs: XLA op for the convolution kernel.
    window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count,
      batch_group_count: forwarded to XLA's `ConvGeneralDilated`.
    dimension_numbers: a `ConvDimensionNumbers`, converted to its proto form.
    precision: converted to an XLA precision config.
    expand_complex_convolutions: if true and the lhs dtype is complex, the
      convolution is expanded into three real convolutions.
    **unused_kwargs: ignored.

  Returns:
    The XLA op computing the convolution.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  dimension_numbers = _conv_general_proto(dimension_numbers)
  precision_config = _precision_config(precision)
  dtype = c.get_shape(lhs).numpy_dtype()

  def conv(x, y):
    return xops.ConvGeneralDilated(
        x, y, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, batch_group_count,
        precision_config=precision_config)

  if not (expand_complex_convolutions and
          np.issubdtype(dtype, np.complexfloating)):
    return conv(lhs, rhs)

  # We use a trick for complex multiplication due to Gauss which uses three
  # multiplications and five additions; instead of the naive method of four
  # multiplications and two additions.
  # https://en.wikipedia.org/wiki/Multiplication_algorithm#Complex_multiplication_algorithm
  #
  # This performance win comes with a trade-off in accuracy; especially in
  # cases when the real and imaginary differ hugely in magnitude. The relative
  # error bound (e.g. 1p-24 in case of float32) would be relative to the
  # maximum of real and imaginary parts of the result instead of being
  # satisfied by the real and imaginary parts independently of each other.
  lhs_real, lhs_imag = xops.Real(lhs), xops.Imag(lhs)
  rhs_real, rhs_imag = xops.Real(rhs), xops.Imag(rhs)
  k1 = conv(xops.Add(lhs_real, lhs_imag), rhs_real)
  k2 = conv(lhs_real, xops.Sub(rhs_imag, rhs_real))
  k3 = conv(lhs_imag, xops.Add(rhs_real, rhs_imag))
  # real = k1 - k3, imag = k1 + k2.
  real_part = xops.Sub(k1, k3)
  imag_part = xops.Add(k1, k2)
  return xops.Complex(real_part, imag_part)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _conv_general_dilated_batch_rule(
|
2020-04-07 09:38:10 -04:00
|
|
|
batched_args, batch_dims, *, window_strides, padding,
|
2019-06-15 13:38:55 -07:00
|
|
|
lhs_dilation, rhs_dilation, dimension_numbers,
|
2020-04-09 16:21:30 -04:00
|
|
|
feature_group_count, batch_group_count, precision, **unused_kwargs):
|
|
|
|
assert batch_group_count == 1 or feature_group_count == 1
|
2019-01-28 14:33:57 -08:00
|
|
|
lhs, rhs = batched_args
|
|
|
|
lhs_bdim, rhs_bdim = batch_dims
|
2019-06-15 13:38:55 -07:00
|
|
|
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
2019-01-28 14:33:57 -08:00
|
|
|
|
|
|
|
if lhs_bdim is not None and rhs_bdim is not None:
|
2019-06-15 13:38:55 -07:00
|
|
|
assert lhs.shape[lhs_bdim] == rhs.shape[rhs_bdim]
|
2020-04-09 16:21:30 -04:00
|
|
|
if batch_group_count > 1:
|
|
|
|
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
|
|
|
|
batch_group_count *= lhs.shape[lhs_bdim]
|
|
|
|
else:
|
|
|
|
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[1], lhs)
|
|
|
|
feature_group_count *= lhs.shape[lhs_bdim]
|
2019-06-15 13:38:55 -07:00
|
|
|
new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
|
2019-06-28 09:00:32 -04:00
|
|
|
out = conv_general_dilated(
|
|
|
|
new_lhs, new_rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
2020-04-09 16:21:30 -04:00
|
|
|
dimension_numbers, feature_group_count=feature_group_count,
|
|
|
|
batch_group_count=batch_group_count,
|
2019-06-28 09:00:32 -04:00
|
|
|
precision=precision)
|
2019-06-15 13:38:55 -07:00
|
|
|
out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out)
|
|
|
|
return out, out_spec[1]
|
2019-01-28 14:33:57 -08:00
|
|
|
|
|
|
|
elif lhs_bdim is not None:
|
2020-04-09 16:21:30 -04:00
|
|
|
if batch_group_count == 1:
|
|
|
|
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
|
|
|
|
out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
|
|
|
|
lhs_dilation, rhs_dilation, dimension_numbers,
|
|
|
|
feature_group_count, precision=precision)
|
|
|
|
out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
|
|
|
|
return out, out_spec[0]
|
|
|
|
else:
|
|
|
|
new_lhs = _reshape_axis_out_of(lhs_spec[0] + int(lhs_bdim <= lhs_spec[0]),
|
|
|
|
batch_group_count, lhs)
|
|
|
|
new_lhs = _reshape_axis_into(lhs_bdim + int(lhs_spec[0] < lhs_bdim),
|
|
|
|
lhs_spec[0] + 1,
|
|
|
|
new_lhs)
|
|
|
|
new_lhs = _reshape_axis_into(lhs_spec[0], lhs_spec[0], new_lhs)
|
|
|
|
out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
|
|
|
|
lhs_dilation, rhs_dilation, dimension_numbers,
|
|
|
|
feature_group_count, batch_group_count,
|
|
|
|
precision=precision)
|
|
|
|
out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
|
|
|
|
return out, out_spec[0]
|
2019-06-15 13:38:55 -07:00
|
|
|
|
2019-01-28 14:33:57 -08:00
|
|
|
elif rhs_bdim is not None:
|
2020-04-09 16:21:30 -04:00
|
|
|
if feature_group_count == 1 and batch_group_count == 1:
|
2019-06-15 13:38:55 -07:00
|
|
|
new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
|
|
|
|
out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
|
2020-04-09 16:21:30 -04:00
|
|
|
lhs_dilation, rhs_dilation, dimension_numbers,
|
|
|
|
feature_group_count, batch_group_count,
|
|
|
|
precision=precision)
|
2019-06-15 13:38:55 -07:00
|
|
|
out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out)
|
|
|
|
return out, out_spec[1]
|
|
|
|
else:
|
2020-04-09 16:21:30 -04:00
|
|
|
# groups need to be outermost, so we need to factor them out of the
|
2019-06-15 13:38:55 -07:00
|
|
|
# rhs output feature dim, then factor the batch dim into the remaining rhs
|
2020-04-09 16:21:30 -04:00
|
|
|
# output feature dim, then put groups back in. We do something
|
|
|
|
# similar on the output. An alternative which would require more FLOPs but
|
2019-06-15 13:38:55 -07:00
|
|
|
# fewer reshapes would be to broadcast lhs.
|
2020-04-09 16:21:30 -04:00
|
|
|
group_count = (feature_group_count if feature_group_count > 1
|
|
|
|
else batch_group_count)
|
2019-06-15 13:38:55 -07:00
|
|
|
new_rhs = _reshape_axis_out_of(rhs_spec[0] + int(rhs_bdim <= rhs_spec[0]),
|
2020-04-09 16:21:30 -04:00
|
|
|
group_count, rhs)
|
2019-06-15 13:38:55 -07:00
|
|
|
new_rhs = _reshape_axis_into(rhs_bdim + int(rhs_spec[0] < rhs_bdim),
|
|
|
|
rhs_spec[0] + 1,
|
|
|
|
new_rhs)
|
|
|
|
new_rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[0], new_rhs)
|
|
|
|
out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
|
2020-04-09 16:21:30 -04:00
|
|
|
lhs_dilation, rhs_dilation, dimension_numbers,
|
|
|
|
feature_group_count, batch_group_count,
|
|
|
|
precision=precision)
|
|
|
|
out = _reshape_axis_out_of(out_spec[1], group_count, out)
|
2019-06-15 13:38:55 -07:00
|
|
|
out = _reshape_axis_out_of(out_spec[1] + 1, rhs.shape[rhs_bdim], out)
|
|
|
|
out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
|
|
|
|
return out, out_spec[1]
|
|
|
|
|
2020-06-03 22:40:48 +02:00
|
|
|
def _masked(padded_value, logical_shape, dimensions, value=0):
|
|
|
|
"""
|
|
|
|
Sets all padding to the given value (default is 0) in the given dimensions.
|
|
|
|
All values outside the logical shape are considered padding.
|
|
|
|
"""
|
|
|
|
if len(dimensions) == 0:
|
|
|
|
return padded_value
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
masks = [broadcasted_iota(np.int32, padded_value.shape, d) < logical_shape[d]
|
2020-06-03 22:40:48 +02:00
|
|
|
for d in dimensions]
|
|
|
|
mask_intersection = masks[0]
|
|
|
|
for mask in masks[1:]:
|
|
|
|
mask_intersection &= mask
|
|
|
|
return select(mask_intersection, padded_value, full_like(padded_value, value))
|
|
|
|
|
|
|
|
def _conv_general_dilated_masking_rule(
        padded_vals, logical_shapes, window_strides, padding, lhs_dilation,
        rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
        lhs_shape, rhs_shape, precision):
  """Masking rule for ``conv_general_dilated_p``.

  Zeroes the padding of the lhs along every axis of ``lhs_spec`` after the
  first two, and the padding of the rhs along the second axis of ``rhs_spec``,
  then runs the convolution on the masked operands. Padding along the
  remaining rhs axes of ``rhs_spec`` (the window dimensions) is not supported:
  the padded and logical sizes of those axes must already agree.

  ``lhs_shape`` and ``rhs_shape`` are accepted but not used.
  """
  lhs, rhs = padded_vals
  logical_lhs_shape, logical_rhs_shape = logical_shapes

  _, rhs_in_dim, *window_dims = dimension_numbers.rhs_spec
  assert np.all(
      np.take(rhs.shape, window_dims)
      == np.take(logical_rhs_shape, window_dims)), \
      "Conv filter masking not yet implemented."

  _, _, *lhs_trailing_dims = dimension_numbers.lhs_spec

  masked_lhs = _masked(lhs, logical_lhs_shape, lhs_trailing_dims)
  masked_rhs = _masked(rhs, logical_rhs_shape, (rhs_in_dim,))
  return conv_general_dilated(
      masked_lhs,
      masked_rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      precision=precision)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# Default translation: complex convolutions are handed to XLA unexpanded.
conv_general_dilated_p = standard_primitive(
    _conv_general_dilated_shape_rule,
    _conv_general_dilated_dtype_rule,
    'conv_general_dilated',
    partial(_conv_general_dilated_translation_rule,
            expand_complex_convolutions=False))

# TODO(b/161124619, b/161126248): XLA does not support complex convolution on
# CPU or GPU; on these backends, lower complex convolutions away.
xla.backend_specific_translations['cpu'][conv_general_dilated_p] = partial(
    _conv_general_dilated_translation_rule, expand_complex_convolutions=True)
xla.backend_specific_translations['gpu'][conv_general_dilated_p] = partial(
    _conv_general_dilated_translation_rule, expand_complex_convolutions=True)

# Differentiation, batching and masking rules for the primitive.
ad.defbilinear(
    conv_general_dilated_p,
    _conv_general_dilated_transpose_lhs,
    _conv_general_dilated_transpose_rhs)
batching.primitive_batchers[conv_general_dilated_p] = (
    _conv_general_dilated_batch_rule)
masking.masking_rules[conv_general_dilated_p] = (
    _conv_general_dilated_masking_rule)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-06-15 12:01:20 -07:00
|
|
|
def _reshape_axis_into(src, dst, x):
  """Removes axis `src` from `x` by merging it into axis `dst`.

  `dst` is an index into the shape of `x` with axis `src` already removed; the
  size of that axis is multiplied by the size of `src`. The permutation passed
  to `reshape` places `src` immediately before the target axis
  (presumably so that `src` becomes the major part of the merged axis --
  confirm against `reshape`).
  """
  axis_order = list(filter(lambda axis: axis != src, range(x.ndim)))
  axis_order.insert(dst, src)
  merged_shape = list(np.delete(x.shape, src))
  merged_shape[dst] = merged_shape[dst] * x.shape[src]
  return reshape(x, merged_shape, axis_order)
|
2019-06-15 12:01:20 -07:00
|
|
|
|
|
|
|
def _reshape_axis_out_of(src, size1, x):
  """Splits axis `src` of `x` into two adjacent axes of sizes (size1, rest).

  `rest` is the size of axis `src` divided by `size1`; the division must be
  exact (checked with an assertion).
  """
  dims = list(x.shape)
  quotient, remainder = divmod(dims[src], size1)
  assert not remainder
  # Replace the single entry at `src` by the pair (size1, quotient).
  dims[src:src+1] = [size1, quotient]
  return reshape(x, dims)
|
|
|
|
|
2019-07-23 21:45:41 -04:00
|
|
|
def _precision_config(precision):
|
2019-06-28 09:00:32 -04:00
|
|
|
if precision is not None:
|
|
|
|
config = xla_client.PrecisionConfig()
|
2019-06-28 12:48:44 -04:00
|
|
|
config.operand_precision.extend((precision, precision))
|
2019-07-23 21:45:41 -04:00
|
|
|
return config
|
|
|
|
return None
|
2019-06-28 09:00:32 -04:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision):
|
2018-11-17 18:03:33 -08:00
|
|
|
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
2020-07-14 13:05:31 -07:00
|
|
|
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
|
2020-07-11 20:47:22 -07:00
|
|
|
for d in (lhs_contracting, lhs_batch)):
|
|
|
|
msg = ("dot_general requires lhs dimension numbers to be nonnegative and "
|
|
|
|
"less than the number of axes of the lhs value, got "
|
|
|
|
f"lhs_batch of {lhs_batch} and lhs_contracting of {lhs_contracting} "
|
|
|
|
f"for lhs of rank {lhs.ndim}")
|
|
|
|
raise TypeError(msg)
|
2020-07-14 13:05:31 -07:00
|
|
|
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, rhs.ndim))
|
2020-07-11 20:47:22 -07:00
|
|
|
for d in (rhs_contracting, rhs_batch)):
|
|
|
|
msg = ("dot_general requires rhs dimension numbers to be nonnegative and "
|
|
|
|
"less than the number of axes of the rhs value, got "
|
|
|
|
f"rhs_batch of {rhs_batch} and rhs_contracting of {rhs_contracting} "
|
|
|
|
f"for rhs of rank {rhs.ndim}")
|
|
|
|
raise TypeError(msg)
|
2018-11-17 18:03:33 -08:00
|
|
|
if len(lhs_batch) != len(rhs_batch):
|
|
|
|
msg = ("dot_general requires equal numbers of lhs_batch and rhs_batch "
|
|
|
|
"dimensions, got lhs_batch {} and rhs_batch {}.")
|
|
|
|
raise TypeError(msg.format(lhs_batch, rhs_batch))
|
2020-07-16 16:23:27 -04:00
|
|
|
lhs_contracting_set, lhs_batch_set = set(lhs_contracting), set(lhs_batch)
|
|
|
|
rhs_contracting_set, rhs_batch_set = set(rhs_contracting), set(rhs_batch)
|
|
|
|
if len(lhs_batch_set) != len(lhs_batch):
|
|
|
|
msg = ("dot_general requires lhs batch dimensions to be distinct, got "
|
|
|
|
f"lhs_batch {lhs_batch}.")
|
|
|
|
raise TypeError(msg)
|
|
|
|
if len(rhs_batch_set) != len(rhs_batch):
|
|
|
|
msg = ("dot_general requires rhs batch dimensions to be distinct, got "
|
|
|
|
f"rhs_batch {rhs_batch}.")
|
|
|
|
raise TypeError(msg)
|
|
|
|
if len(lhs_contracting_set) != len(lhs_contracting):
|
|
|
|
msg = ("dot_general requires lhs contracting dimensions to be distinct, "
|
|
|
|
f"got lhs_contracting {lhs_contracting}.")
|
|
|
|
raise TypeError(msg)
|
|
|
|
if len(rhs_contracting_set) != len(rhs_contracting):
|
|
|
|
msg = ("dot_general requires rhs contracting dimensions to be distinct, "
|
|
|
|
f"got rhs_contracting {rhs_contracting}.")
|
|
|
|
raise TypeError(msg)
|
|
|
|
if lhs_contracting_set & lhs_batch_set:
|
|
|
|
msg = ("dot_general requires lhs batch dimensions to be disjoint from "
|
|
|
|
"contracting dimensions, got lhs_batch {} and lhs_contracting {}.")
|
|
|
|
raise TypeError(msg.format(lhs_batch, lhs_contracting))
|
|
|
|
if rhs_contracting_set & rhs_batch_set:
|
|
|
|
msg = ("dot_general requires rhs batch dimensions to be disjoint from "
|
|
|
|
"contracting dimensions, got rhs_batch {} and rhs_contracting {}.")
|
|
|
|
raise TypeError(msg.format(rhs_batch, rhs_contracting))
|
2020-07-14 13:05:31 -07:00
|
|
|
lhs_batch_shape = np.take(lhs.shape, lhs_batch)
|
|
|
|
rhs_batch_shape = np.take(rhs.shape, rhs_batch)
|
|
|
|
if not np.all(np.equal(lhs_batch_shape, rhs_batch_shape)):
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
|
|
|
|
"to have the same shape, got {} and {}.")
|
|
|
|
raise TypeError(msg.format(lhs_batch_shape, rhs_batch_shape))
|
2020-07-14 13:05:31 -07:00
|
|
|
lhs_contracting_shape = np.take(lhs.shape, lhs_contracting)
|
|
|
|
rhs_contracting_shape = np.take(rhs.shape, rhs_contracting)
|
|
|
|
if not np.all(np.equal(lhs_contracting_shape, rhs_contracting_shape)):
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = ("dot_general requires contracting dimensions to have the same "
|
|
|
|
"shape, got {} and {}.")
|
|
|
|
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
|
|
|
|
|
2020-07-16 16:23:27 -04:00
|
|
|
batch_shape = tuple(lhs_batch_shape)
|
|
|
|
lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
|
2020-07-14 13:05:31 -07:00
|
|
|
lhs_tensored_shape = tuple(np.delete(lhs.shape, lhs_contract_or_batch))
|
2020-07-16 16:23:27 -04:00
|
|
|
rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
|
2020-07-14 13:05:31 -07:00
|
|
|
rhs_tensored_shape = tuple(np.delete(rhs.shape, rhs_contract_or_batch))
|
2018-11-17 18:03:33 -08:00
|
|
|
return batch_shape + lhs_tensored_shape + rhs_tensored_shape
|
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision):
  """Dtype rule for ``dot_general``, delegated to ``naryop_dtype_rule``.

  ``dimension_numbers`` and ``precision`` do not influence the result.
  """
  accepted_dtypes = [_any, _any]  # one entry per operand
  return naryop_dtype_rule(
      _input_dtype, accepted_dtypes, 'dot_general', lhs, rhs)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
                               swap_ans=False):
  """Transpose rule of ``dot_general`` with respect to its first operand.

  Contracts `g` (presumably the cotangent of the product -- confirm with the
  ``ad.defbilinear`` registration) with `y` over the dimensions of `y` that
  were neither contracted nor batched, then permutes the result into the
  axis order of the first operand.

  Args:
    g: array with the layout of the ``dot_general`` result.
    y: the other operand of the original product.
    dimension_numbers: ``((x_contract, y_contract), (x_batch, y_batch))``.
    precision: forwarded to ``dot_general``.
    swap_ans: when true, the result layout is taken to list the kept
      dimensions of `y` before those of `x` (used by the rhs transpose rule).
  """
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  # Rank of x recovered from the ranks of the result and of y.
  x_rank = g.ndim - y.ndim + len(x_batch) + 2 * len(x_contract)
  x_free = remaining(range(x_rank), x_contract, x_batch)
  y_free = remaining(range(y.ndim), y_contract, y_batch)
  if swap_ans:
    out_batch, out_y, _ = ranges_like(x_batch, y_free, x_free)
  else:
    out_batch, _, out_y = ranges_like(x_batch, x_free, y_free)
  contraction = ((out_y, y_free), (out_batch, y_batch))
  # The product comes out as (batch, x_free, y's contracted dims in y order);
  # argsort of that listing is the permutation restoring x's own axis order.
  y_order = np.argsort(y_contract)
  x_contract_in_y_order = list(np.take(x_contract, y_order))
  inverse_perm = np.argsort(list(x_batch) + x_free + x_contract_in_y_order)
  product = dot_general(g, y, contraction, precision=precision)
  return transpose(product, tuple(inverse_perm))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dot_general_transpose_rhs(g, x, *, dimension_numbers, precision):
  """Transpose rule of ``dot_general`` with respect to its second operand.

  Implemented by exchanging the roles of the two operands in
  `dimension_numbers` and reusing the lhs rule with ``swap_ans=True``.
  """
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  flipped = ((y_contract, x_contract), (y_batch, x_batch))
  return _dot_general_transpose_lhs(
      g, x, dimension_numbers=flipped, precision=precision, swap_ans=True)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
                            precision):
  """Batching rule for ``dot_general_p``.

  Rewrites `dimension_numbers` so that they refer to the batched operands and
  issues a single ``dot_general``.

  Args:
    batched_args: the ``(lhs, rhs)`` operands.
    batch_dims: ``(lbd, rbd)``, the mapped axis of each operand or None; at
      least one must not be None.
    dimension_numbers: ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``
      expressed for the unbatched operands.
    precision: forwarded to ``dot_general``.

  Returns:
    ``(batched_out, result_batch_dim)``.
  """
  # there are three kinds of dimensions in a dot_general:
  # - contraction dimensions appear in lhs and rhs but not the result
  # - batch dimensions appear in lhs, rhs, and result
  # - tensor product dimensions appear in the result and one of lhs or rhs
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs, rhs = batched_args
  lbd, rbd = batch_dims
  assert lbd is not None or rbd is not None

  def shift_past(dims, inserted):
    # Dimension numbers at or beyond the inserted axis move up by one.
    return tuple(np.add(dims, np.greater_equal(dims, inserted)))

  def free_dims(ndim, batch, contract):
    return tuple(d for d in range(ndim)
                 if d not in batch and d not in contract)

  if lbd is not None and rbd is not None:
    # adding a batch dimension
    lhs_batch = (lbd,) + shift_past(lhs_batch, lbd)
    rhs_batch = (rbd,) + shift_past(rhs_batch, rbd)
    lhs_contract = shift_past(lhs_contract, lbd)
    rhs_contract = shift_past(rhs_contract, rbd)
    out_bdim = 0
  elif lbd is not None:
    # adding a tensor product dimension (on the lhs): it lands after the batch
    # dimensions and after the lhs tensor dimensions that precede it.
    lhs_free = free_dims(lhs.ndim, lhs_batch, lhs_contract)
    out_bdim = len(lhs_batch) + sum(np.less(lhs_free, lbd))
    lhs_batch = shift_past(lhs_batch, lbd)
    lhs_contract = shift_past(lhs_contract, lbd)
  else:
    # adding a tensor product dimension (on the rhs): it lands after every
    # result dimension contributed by the lhs.
    rhs_free = free_dims(rhs.ndim, rhs_batch, rhs_contract)
    out_bdim = (lhs.ndim - len(lhs_contract) +
                sum(np.less(rhs_free, rbd)))
    rhs_batch = shift_past(rhs_batch, rbd)
    rhs_contract = shift_past(rhs_contract, rbd)

  batched_dimension_numbers = ((lhs_contract, rhs_contract),
                               (lhs_batch, rhs_batch))
  batched_out = dot_general(lhs, rhs, batched_dimension_numbers,
                            precision=precision)
  return batched_out, int(out_bdim)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-16 16:23:27 -04:00
|
|
|
def _dot_using_sum_of_products(lhs, rhs, *, dimension_numbers):
  """Evaluates a ``dot_general`` as an elementwise product plus a reduction.

  Both operands are transposed to (batch, non-contracting, contracting) order
  and given singleton axes where the other operand's non-contracting axes go,
  so that their elementwise product has the layout
  (batch, lhs non-contracting, rhs non-contracting, contracting). The trailing
  contracting axes are then reduced. Boolean inputs use and/or in place of
  multiply/add.
  """
  (lhs_contract_dims, rhs_contract_dims), (lhs_batch_dims, rhs_batch_dims) = (
      dimension_numbers)

  def free_dims(operand, batch_dims, contract_dims):
    taken = set(batch_dims) | set(contract_dims)
    return tuple(sorted(set(range(np.ndim(operand))) - taken))

  lhs_free = free_dims(lhs, lhs_batch_dims, lhs_contract_dims)
  rhs_free = free_dims(rhs, rhs_batch_dims, rhs_contract_dims)
  lhs = transpose(lhs, lhs_batch_dims + lhs_free + lhs_contract_dims)
  rhs = transpose(rhs, rhs_batch_dims + rhs_free + rhs_contract_dims)

  n_batch = len(lhs_batch_dims)
  n_lhs_free = len(lhs_free)
  n_rhs_free = len(rhs_free)

  # lhs gets singleton axes where the rhs non-contracting axes will sit ...
  lhs_new_axes = range(n_batch + n_lhs_free, n_batch + n_lhs_free + n_rhs_free)
  lhs = expand_dims(lhs, tuple(lhs_new_axes))

  # ... and rhs gets singleton axes where the lhs non-contracting axes sit.
  rhs_new_axes = range(n_batch, n_batch + n_lhs_free)
  rhs = expand_dims(rhs, tuple(rhs_new_axes))

  out_ndim = n_batch + n_lhs_free + n_rhs_free
  if lhs.dtype == np.bool_:
    op_product, op_sum = bitwise_and, bitwise_or
  else:
    op_product, op_sum = mul, add
  reduced_axes = tuple(range(out_ndim, out_ndim + len(lhs_contract_dims)))
  return reduce(op_product(lhs, rhs), _zero(lhs), op_sum, reduced_axes)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dot_general_translation_rule(c, lhs, rhs, *, dimension_numbers, precision):
  """XLA translation rule for ``dot_general_p``.

  Inexact (floating/complex) operands are lowered to ``xops.DotGeneral``;
  every other dtype is lowered through ``_dot_using_sum_of_products``.
  """
  operand_dtype = c.get_shape(lhs).numpy_dtype()
  if not dtypes.issubdtype(operand_dtype, np.inexact):
    # TODO(b/134526360): XLA doesn't support bool or integer dots, so we emit a
    # sum of products instead.
    fallback = xla.lower_fun(_dot_using_sum_of_products,
                             multiple_results=False)
    return fallback(c, lhs, rhs, dimension_numbers=dimension_numbers)

  return xops.DotGeneral(
      lhs, rhs,
      xc.make_dot_dimension_numbers(dimension_numbers),
      precision_config=_precision_config(precision))
|
2019-03-22 14:10:38 -07:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
                              precision):
  """Masking rule for ``dot_general_p``.

  Zeroes the padding of the lhs along its contracting dimensions and then
  performs the ordinary ``dot_general`` on the padded operands.
  """
  lhs, rhs = padded_vals
  # Only need to mask off contraction dims of one side - we mask the lhs here
  # but this is arbitrary. Could check the sizes of lhs and rhs and mask
  # whichever is smallest.
  logical_lhs_shape, _ = logical_shapes
  (lhs_contract, _), _ = dimension_numbers
  masked_lhs = _masked(lhs, logical_lhs_shape, lhs_contract)
  return dot_general(masked_lhs, rhs, dimension_numbers, precision=precision)
|
2019-10-08 14:22:51 -07:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# The dot_general primitive together with its differentiation, batching and
# masking rules.
dot_general_p = standard_primitive(
    _dot_general_shape_rule,
    _dot_general_dtype_rule,
    'dot_general',
    _dot_general_translation_rule)
ad.defbilinear(
    dot_general_p, _dot_general_transpose_lhs, _dot_general_transpose_rhs)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
masking.masking_rules[dot_general_p] = _dot_general_masking_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _broadcast_shape_rule(operand, sizes):
  """Shape rule for ``broadcast_p``: `sizes` prepended to the operand shape."""
  _check_shapelike('broadcast', 'sizes', sizes)
  leading_dims = tuple(sizes)
  return leading_dims + operand.shape
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _broadcast_batch_rule(batched_args, batch_dims, *, sizes):
  """Batching rule for ``broadcast_p``.

  The broadcast is applied unchanged; because it adds ``len(sizes)`` leading
  axes, a mapped axis (if any) moves up by that amount.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  if bdim is None:
    shifted_bdim = None
  else:
    shifted_bdim = bdim + len(sizes)
  return broadcast(operand, sizes), shifted_bdim
|
|
|
|
|
|
|
|
# The broadcast primitive. Its transpose sums over the leading axes that the
# broadcast added (one per entry of `sizes`).
broadcast_p = standard_primitive(
    _broadcast_shape_rule, _input_dtype, 'broadcast')
ad.deflinear(
    broadcast_p,
    lambda t, sizes: [_reduce_sum(t, range(len(sizes)))])
batching.primitive_batchers[broadcast_p] = _broadcast_batch_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _broadcast_in_dim_impl(operand, *, shape, broadcast_dimensions):
  """Implementation of ``broadcast_in_dim_p`` with a lazy fast path.

  When `operand` is exactly an ``xla.DeviceArray`` and every operand dimension
  already equals the target dimension it maps to, the result is a new
  ``DeviceArray`` that reuses the operand's device buffer and only records the
  broadcast in its lazy expression. Otherwise the primitive is dispatched
  through ``xla.apply_primitive``.
  """
  takes_lazy_path = type(operand) is xla.DeviceArray and np.all(
      np.equal(operand.shape, np.take(shape, broadcast_dimensions)))
  if not takes_lazy_path:
    return xla.apply_primitive(broadcast_in_dim_p, operand, shape=shape,
                               broadcast_dimensions=broadcast_dimensions)

  # The shape rule validates the arguments and yields the output shape.
  out_shape = _broadcast_in_dim_shape_rule(
      operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
  aval = ShapedArray(out_shape, _dtype(operand))
  lazy_expr = lazy.broadcast(
      operand._lazy_expr, out_shape, broadcast_dimensions)
  return xla.DeviceArray(
      aval, operand._device, lazy_expr, operand.device_buffer)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
  """Validates the arguments of ``broadcast_in_dim`` and returns `shape`.

  Raises:
    TypeError: if `broadcast_dimensions` does not have one entry per operand
      axis, if `shape` has lower rank than the operand, if an entry of
      `broadcast_dimensions` is not an axis of `shape`, if an operand axis is
      neither of size 1 nor of the size of the axis it maps to, or if
      `broadcast_dimensions` is not strictly increasing.
  """
  _check_shapelike('broadcast_in_dim', 'shape', shape)
  _check_shapelike('broadcast_in_dim', 'broadcast_dimensions',
                   broadcast_dimensions)
  operand_ndim = np.ndim(operand)
  out_ndim = len(shape)

  if operand_ndim != len(broadcast_dimensions):
    msg = ('broadcast_in_dim broadcast_dimensions must have length equal to '
           'operand ndim; got broadcast_dimensions {} for operand ndim {}.')
    raise TypeError(msg.format(broadcast_dimensions, operand_ndim))

  if out_ndim < operand_ndim:
    msg = ('broadcast_in_dim target broadcast shape must have equal or higher rank '
           'to the operand shape; got operand ndim {} and target broadcast ndim {}.')
    raise TypeError(msg.format(operand_ndim, out_ndim))

  if not set(broadcast_dimensions) <= set(range(out_ndim)):
    msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
           'dimensions, got {} for operand ndim {} and shape {}.')
    raise TypeError(msg.format(broadcast_dimensions, operand_ndim, shape))

  # Each operand axis must be 1 or already match the axis it is mapped to.
  for operand_size, target_axis in zip(operand.shape, broadcast_dimensions):
    if operand_size != 1 and operand_size != shape[target_axis]:
      msg = ('broadcast_in_dim operand dimension sizes must either be 1, or be '
             'equal to their corresponding dimensions in the target broadcast shape; '
             'got operand of shape {}, target broadcast shape {}, '
             'broadcast_dimensions {} ')
      raise TypeError(msg.format(operand.shape, shape, broadcast_dimensions))

  has_repeats = len(set(broadcast_dimensions)) != len(broadcast_dimensions)
  if has_repeats or (tuple(sorted(broadcast_dimensions))
                     != tuple(broadcast_dimensions)):
    msg = ('broadcast_in_dim broadcast_dimensions must be strictly increasing; '
           'got broadcast_dimensions {}')
    raise TypeError(msg.format(broadcast_dimensions))

  return shape
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _broadcast_in_dim_transpose_rule(t, *, shape, broadcast_dimensions):
  """Transpose rule for ``broadcast_in_dim_p``.

  Sums `t` over every axis of `shape` that is not listed in
  `broadcast_dimensions` and returns the result in a one-element list.
  """
  # NOTE(review): axes listed in broadcast_dimensions are never reduced, yet
  # the shape rule accepts size-1 operand axes mapped onto larger target axes;
  # confirm that case is handled elsewhere (the operand is not available here).
  output_axes = range(len(shape))
  summed_axes = tuple(np.delete(output_axes, broadcast_dimensions))
  return [_reduce_sum(t, summed_axes)]
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
                                 broadcast_dimensions):
  """Batching rule for ``broadcast_in_dim_p``.

  The mapped axis is moved to the front of the operand and kept as axis 0 of
  the output: its size is prepended to `shape`, and every entry of
  `broadcast_dimensions` is shifted up by one behind a leading 0.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  front_batched = batching.moveaxis(operand, bdim, 0)
  batched_shape = (operand.shape[bdim],) + shape
  batched_dims = (0,) + tuple(np.add(1, broadcast_dimensions))
  out = broadcast_in_dim(front_batched, batched_shape, batched_dims)
  return out, 0
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
# The broadcast_in_dim primitive; its impl is overridden with the version that
# has a lazy fast path for DeviceArray operands.
broadcast_in_dim_p = standard_primitive(
    _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_impl(_broadcast_in_dim_impl)
ad.deflinear(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = (
    _broadcast_in_dim_batch_rule)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _clamp_shape_rule(min, operand, max):
|
2018-11-17 18:03:33 -08:00
|
|
|
if min.shape and min.shape != operand.shape:
|
|
|
|
m = "clamp requires min.shape == operand.shape or min.shape == (), got {}."
|
|
|
|
raise TypeError(m.format(min.shape))
|
|
|
|
if max.shape and max.shape != operand.shape:
|
|
|
|
m = "clamp requires max.shape == operand.shape or max.shape == (), got {}."
|
|
|
|
raise TypeError(m.format(max.shape))
|
|
|
|
return operand.shape
|
|
|
|
|
2020-01-15 13:13:11 -08:00
|
|
|
_clamp_dtype_rule = partial(
    naryop_dtype_rule, _input_dtype, [_any, _any, _any], 'clamp')

clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
# One JVP rule per argument, in the order (min, operand, max).
ad.defjvp(
    clamp_p,
    # min: its tangent passes where min exceeds the operand and min < max.
    lambda g, min, operand, max: select(
        bitwise_and(gt(min, operand), lt(min, max)),
        _brcast(g, operand), _zeros(operand)),
    # operand: its tangent passes where the operand lies strictly between the
    # two bounds.
    lambda g, min, operand, max: select(
        bitwise_and(gt(operand, min), lt(operand, max)),
        g, _zeros(operand)),
    # max: its tangent passes where max is below the operand.
    lambda g, min, operand, max: select(
        lt(max, operand), _brcast(g, operand), _zeros(operand)))
batching.defbroadcasting(clamp_p)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _concatenate_shape_rule(*operands, **kwargs):
  """Shape rule for concatenate.

  Expects the keyword argument `dimension`. All operands must be
  UnshapedArray instances of equal rank whose shapes agree on every axis
  except `dimension`.

  Returns:
    The first operand's shape with axis `dimension` replaced by the sum of
    the operands' sizes along that axis.

  Raises:
    TypeError: on any violated precondition (no operands, non-array operand,
      rank mismatch, out-of-bounds dimension, or mismatched other axes).
  """
  dimension = kwargs.pop('dimension')
  if not operands:
    raise TypeError("concatenate expects at least one operand, got 0.")

  for candidate in operands:
    if not isinstance(candidate, UnshapedArray):
      # Report the type of the first offending operand.
      template = "All objects to concatenate must be arrays, got {}."
      raise TypeError(template.format(type(candidate)))

  if len({candidate.ndim for candidate in operands}) != 1:
    template = "Cannot concatenate arrays with different ranks, got {}."
    ranks = ", ".join(str(o.ndim) for o in operands)
    raise TypeError(template.format(ranks))

  # One row per operand, one column per axis.
  shapes = np.array([candidate.shape for candidate in operands])
  if not 0 <= dimension < shapes.shape[1]:
    template = "concatenate dimension out of bounds: dimension {} for shapes {}."
    raise TypeError(template.format(dimension, ", ".join(map(str, shapes))))

  # Compare every shape to the first, ignoring the concatenated axis.
  agrees_elsewhere = np.delete(shapes[0] == shapes, dimension, axis=1)
  if not np.all(agrees_elsewhere):
    template = ("Cannot concatenate arrays with shapes that differ in dimensions "
                "other than the one being concatenated: dimension {} for shapes {}.")
    raise TypeError(template.format(dimension, ", ".join(map(str, shapes))))

  ex_shape = operands[0].shape
  concat_size = sum(o.shape[dimension] for o in operands)
  return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:]
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _concatenate_dtype_rule(*operands, **kwargs):
  """Dtype rule for concatenate.

  Passes every operand dtype to `_check_same_dtypes` and returns the first
  operand's dtype.
  """
  operand_dtypes = [o.dtype for o in operands]
  _check_same_dtypes('concatenate', False, *operand_dtypes)
  return operand_dtypes[0]
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _concatenate_translation_rule(c, *operands, **kwargs):
  """XLA translation rule for concatenate: emit ConcatInDim along `dimension`."""
  return xops.ConcatInDim(c, operands, kwargs.pop('dimension'))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def _concatenate_transpose_rule(t, *operands, dimension):
  """Transpose rule for concatenate.

  Cuts the cotangent `t` back into one piece per operand along `dimension`.

  Returns:
    `ad_util.Zero` when `t` is a Zero; otherwise a list with one entry per
    operand: a slice of `t` for operands that are undefined primals, and
    None for the others.
  """
  # Undefined primals only carry an aval, so read their shape from there.
  operand_shapes = [o.aval.shape if ad.is_undefined_primal(o) else o.shape
                    for o in operands]
  if type(t) is ad_util.Zero:
    return ad_util.Zero

  num_operands = len(operands)
  # Running end offsets of each operand along the concatenated axis.
  ends = np.cumsum([shape[dimension] for shape in operand_shapes])

  starts = np.zeros((num_operands, t.ndim), dtype=int)
  starts[1:, dimension] = ends[:-1]

  limits = np.tile(t.shape, (num_operands, 1))
  limits[:, dimension] = ends

  cotangents = []
  for o, start, limit in zip(operands, starts, limits):
    if ad.is_undefined_primal(o):
      cotangents.append(slice(t, start, limit))
    else:
      cotangents.append(None)
  return cotangents
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
  """Batching rule for concatenate.

  Moves every batched operand's batch axis to the front, broadcasts unbatched
  operands to the batch size, and concatenates along `dimension + 1`.

  Returns:
    A pair (result, 0): the batched result and its batch axis.
  """
  pairs = list(zip(batched_args, batch_dims))
  # Batch size comes from the first operand that is actually batched.
  size = next(op.shape[bdim] for op, bdim in pairs if bdim is not None)

  operands = []
  for op, bdim in pairs:
    if bdim is None:
      operands.append(broadcast(op, (size,)))
    else:
      operands.append(batching.moveaxis(op, bdim, 0))
  return concatenate(operands, dimension + 1), 0
|
|
|
|
|
2019-09-03 17:09:27 -07:00
|
|
|
# The concatenate_p masking rule requires use of a while-loop construct and so
# is defined in lax_control_flow.py

concatenate_p = standard_primitive(
    _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate',
    _concatenate_translation_rule)
ad.deflinear(concatenate_p, _concatenate_transpose_rule)
# The transpose table entry is then assigned directly to the raw rule.
# NOTE(review): presumably this replaces an entry installed by ad.deflinear
# above -- confirm against ad.deflinear.
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _pad_dtype_rule(operand, padding_value, *, padding_config):
  """Dtype rule for pad.

  Returns:
    `_input_dtype(operand, padding_value)`.

  Raises:
    TypeError: if operand and padding_value have different dtypes.
  """
  if operand.dtype != padding_value.dtype:
    raise TypeError(
        "pad operand and padding_value must be same dtype: got {} and {}."
        .format(operand.dtype, padding_value.dtype))
  return _input_dtype(operand, padding_value)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _pad_shape_rule(operand, padding_value, *, padding_config):
  """Shape rule for pad.

  `padding_config` is a sequence of (lo, hi, interior) triples, one per axis.
  Each output size is lo + hi + n + max(0, interior * (n - 1)), where n is
  the operand's size along that axis.

  Returns:
    The output shape as a tuple.
  """
  lo, hi, interior = zip(*padding_config)
  edge_total = np.add(lo, hi)
  # Clamped at zero so a size-0 axis does not yield a negative contribution.
  interior_total = np.maximum(
      0, np.multiply(interior, np.subtract(operand.shape, 1)))
  out_shape = np.add(np.add(edge_total, operand.shape), interior_total)
  return tuple(out_shape)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _pad_transpose(t, operand, padding_value, *, padding_config):
  """Transpose rule for pad.

  Args:
    t: cotangent of the padded output.
    operand: the padded array, or an undefined primal if pad is being
      transposed with respect to it.
    padding_value: the padding scalar, or an undefined primal likewise.
    padding_config: sequence of (lo, hi, interior) triples, one per axis.

  Returns:
    `ad_util.Zero` when `t` is a Zero; otherwise `[t_operand, t_padv]`, where
    each entry is None unless the corresponding argument is an undefined
    primal. `t_operand` is `t` restricted to the positions that came from the
    operand; `t_padv` is the sum of `t` minus the sum of that restriction.
  """
  if type(t) is ad_util.Zero:
    return ad_util.Zero

  lo, hi, interior = zip(*padding_config)
  total = lambda x: _reduce_sum(x, list(range(t.ndim)))

  def t_op():
    # Undo the padding: pad `t` with the negated lo/hi amounts, then take a
    # strided slice with stride interior + 1 along every axis.
    unpad_config = safe_zip(np.negative(lo), np.negative(hi),
                            np.zeros_like(interior))
    unpadded = pad(t, np.array(0., t.dtype), unpad_config)
    return slice(unpadded, np.zeros_like(lo), unpadded.shape, np.add(interior, 1))

  operand_is_linear = ad.is_undefined_primal(operand)
  padv_is_linear = ad.is_undefined_primal(padding_value)

  # The padding value's cotangent needs the unpadded cotangent even when the
  # operand itself is not being transposed; previously that case passed None
  # to `total`.
  t_unpadded = t_op() if (operand_is_linear or padv_is_linear) else None

  t_operand = t_unpadded if operand_is_linear else None
  t_padv = sub(total(t), total(t_unpadded)) if padv_is_linear else None
  return [t_operand, t_padv]
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _pad_batch_rule(batched_args, batch_dims, *, padding_config):
  """Batching rule for pad.

  Only a batched operand with an unbatched padding value is supported: a
  no-op (0, 0, 0) entry is inserted into the padding config at the operand's
  batch axis.

  Returns:
    A pair (padded result, operand batch axis).

  Raises:
    NotImplementedError: if the padding value is batched.
  """
  operand, padding_value = batched_args
  operand_bdim, padding_value_bdim = batch_dims
  if padding_value_bdim is not None:
    raise NotImplementedError  # loop and stack

  assert operand_bdim is not None
  batched_config = list(padding_config)
  batched_config.insert(operand_bdim, (0, 0, 0))
  return pad(operand, padding_value, batched_config), operand_bdim
|
2018-12-11 14:00:58 -08:00
|
|
|
|
2020-04-23 18:30:47 -04:00
|
|
|
def _pad_translation_rule(c, operand, padding_value, *, padding_config):
  """XLA translation rule for pad: convert the config and emit xops.Pad."""
  xla_padding_config = xc.make_padding_config(padding_config)
  return xops.Pad(operand, padding_value, xla_padding_config)
|
|
|
|
|
2020-06-03 22:40:48 +02:00
|
|
|
def _pad_masking_rule(padded_vals, logical_shapes, padding_config):
  """Masking rule for pad.

  Pads the operand, then hands the result to `_masked` together with a
  per-axis size of lo + n * (interior + 1) (n taken from the operand's
  logical shape), the axes whose config differs from (0, 0, 0), and the
  padding value.
  """
  operand, padding_value = padded_vals
  shape, _ = logical_shapes

  out = pad(operand, padding_value, padding_config)

  out_shape = []
  padded_dims = []
  for axis, config in enumerate(padding_config):
    lo, _hi, interior = config
    out_shape.append(lo + shape[axis] * (interior + 1))
    if config != (0, 0, 0):
      padded_dims.append(axis)
  return _masked(out, out_shape, padded_dims, padding_value)
|
|
|
|
|
2020-04-23 18:30:47 -04:00
|
|
|
# The pad primitive and its rule registrations.
pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad',
                           translation_rule=_pad_translation_rule)
ad.deflinear(pad_p, _pad_transpose)
# The transpose table entry is then assigned directly to the raw rule.
# NOTE(review): presumably this replaces an entry installed by ad.deflinear
# above -- confirm against ad.deflinear.
ad.primitive_transposes[pad_p] = _pad_transpose
batching.primitive_batchers[pad_p] = _pad_batch_rule
masking.masking_rules[pad_p] = _pad_masking_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
# The squeeze primitive exists for the benefit of masking and other
|
|
|
|
# transformations that need to keep track of axis identity.
|
|
|
|
# For example, consider reshaping a 2D array with shape (1, N) into a 1D array
|
|
|
|
# with shape (N,). This results in the following JAXpr:
|
|
|
|
# reshape[ dimension=None new_sizes=(N,) ]
|
|
|
|
# For N > 1, we can match up the output array axis with the second axis of the
|
|
|
|
# input. But for N = 1, it is not clear how axes match up: all we know from the
|
|
|
|
# JAXpr is that we are reshaping from (1, 1) to (1,).
|
|
|
|
# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.
|
|
|
|
|
|
|
|
def squeeze(array: Array, dimensions: Tuple[int, ...]) -> Array:
  """Squeeze any number of size 1 dimensions from an array.

  `dimensions` is canonicalized against the array's rank and sorted before
  being bound to the squeeze primitive. When `dimensions` is empty the input
  is returned unchanged.
  """
  ndim = np.ndim(array)
  canonical = sorted(_canonicalize_axis(i, ndim) for i in dimensions)
  if canonical:
    return squeeze_p.bind(array, dimensions=tuple(canonical))
  return array
|
|
|
|
|
|
|
|
def _squeeze_dtype_rule(operand, *, dimensions):
  """Dtype rule for squeeze: the operand's dtype, unchanged."""
  del dimensions  # Unused.
  return operand.dtype
|
|
|
|
|
|
|
|
def _squeeze_shape_rule(operand, *, dimensions):
  """Shape rule for squeeze: delegate to `_compute_squeeze_shape`."""
  operand_shape = np.shape(operand)
  return _compute_squeeze_shape(operand_shape, dimensions)
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
|
|
|
|
def _compute_squeeze_shape(shape, dimensions):
  """Return `shape` with the axes listed in `dimensions` removed.

  Raises:
    ValueError: if `dimensions` has duplicates, holds an index outside
      [0, len(shape)), or names an axis whose size is not 1.
  """
  dims_set = set(dimensions)
  if len(dims_set) != len(dimensions):
    raise ValueError(f"dimensions are not unique: {dimensions}")

  ndim = len(shape)
  if any(not 0 <= d < ndim for d in dims_set):
    raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")

  for d in dimensions:
    if shape[d] != 1:
      raise ValueError(
          "cannot select an axis to squeeze out which has size not equal to "
          f"one, got shape={shape} and dimensions={dimensions}")

  return tuple(size for axis, size in enumerate(shape) if axis not in dims_set)
|
|
|
|
|
|
|
|
def _squeeze_translation_rule(c, arg, *, dimensions):
  """XLA translation rule for squeeze: a Reshape to the squeezed shape."""
  operand_dims = c.get_shape(arg).dimensions()
  return xops.Reshape(arg, _compute_squeeze_shape(operand_dims, dimensions))
|
|
|
|
|
|
|
|
def _squeeze_transpose_rule(t, operand, *, dimensions):
  """Transpose rule for squeeze: re-insert the squeezed axes into `t`.

  Returns a one-element list holding the operand's cotangent.
  """
  assert ad.is_undefined_primal(operand)
  cotangent = expand_dims(t, dimensions)
  return [cotangent]
|
|
|
|
|
|
|
|
def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for squeeze.

  Moves the batch axis to the front, shifts every squeezed axis index up by
  one, and squeezes.

  Returns:
    A pair (result, 0): the batched result and its batch axis.
  """
  (operand,) = batched_args
  (bdim,) = batch_dims
  moved = batching.moveaxis(operand, bdim, 0)
  shifted = tuple(np.add(1, dimensions))
  return squeeze(moved, dimensions=shifted), 0
|
|
|
|
|
|
|
|
# The squeeze primitive and its rule registrations. The transpose rule is
# registered through ad.deflinear2 (pad and concatenate use ad.deflinear).
squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
                               'squeeze', _squeeze_translation_rule)
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
|
|
|
|
|
|
|
|
|
|
|
|
def expand_dims(array: Array, dimensions: Tuple[int, ...]) -> Array:
  """Insert any number of size 1 dimensions into an array.

  Args:
    array: the input array.
    dimensions: positions, in the output array, of the new size 1 axes. Each
      index is canonicalized against the output rank,
      `np.ndim(array) + len(dimensions)`.

  Returns:
    The result of `broadcast_in_dim` on `array` with the size 1 axes inserted
    at the requested output positions.

  Raises:
    ValueError: if two entries of `dimensions` name the same output axis.
  """
  ndim_out = np.ndim(array) + len(dimensions)
  dims = [_canonicalize_axis(i, ndim_out) for i in dimensions]
  # Repeated axes would make the shape built below shorter than ndim_out, so
  # reject them here rather than letting broadcast_in_dim fail later.
  if len(set(dims)) != len(dims):
    raise ValueError(f"repeated axis in lax.expand_dims: {dimensions}")
  dims_set = frozenset(dims)

  result_shape = list(np.shape(array))
  # Insert in ascending order so earlier insertions do not shift later ones.
  for i in sorted(dims_set):
    result_shape.insert(i, 1)
  broadcast_dims = [i for i in range(ndim_out) if i not in dims_set]
  return broadcast_in_dim(array, result_shape, broadcast_dims)
|
|
|
|
|
|
|
|
|
implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
2020-01-03 15:46:19 -08:00
|
|
|
# We have a nonstandard reshape impl so that we can be lazy about data movement.
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reshape_impl(operand, *, new_sizes, dimensions):
  """Impl rule for reshape with a lazy fast path.

  When `operand` is exactly an xla.DeviceArray, `dimensions` is None, and the
  reshape only adds singleton dimensions, a new DeviceArray is built that
  reuses the operand's device buffer under a lazy broadcast expression.
  Every other case goes through xla.apply_primitive.
  """
  old_sizes = np.shape(operand)
  bcast_dims = None
  if type(operand) is xla.DeviceArray and dimensions is None:
    bcast_dims = _is_singleton_reshape(old_sizes, new_sizes)

  if bcast_dims is None:
    return xla.apply_primitive(reshape_p, operand, new_sizes=new_sizes,
                               dimensions=dimensions)

  aval = ShapedArray(new_sizes, operand.dtype)
  lazy_expr = lazy.broadcast(operand._lazy_expr, new_sizes, bcast_dims)
  return xla.DeviceArray(aval, operand._device, lazy_expr, operand.device_buffer)
|
2019-07-06 10:00:08 -07:00
|
|
|
|
implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:
@jit
def f(x):
m, n = x.shape
return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
@jit
def f(x):
m, n = x.shape
return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
2020-01-03 15:46:19 -08:00
|
|
|
def _is_singleton_reshape(old, new):
  """Detect reshapes that only add singleton dimensions.

  Such reshapes can be expressed as (lazy) broadcasts.

  Returns:
    A list giving, for each axis of `old`, the axis of `new` it maps to, or
    None when `new` is not `old` with extra size 1 axes inserted.
  """
  old = tuple(old)
  bcast_dims = []
  matched = 0  # Number of leading axes of `old` already matched.
  for out_axis, size in enumerate(new):
    if matched < len(old) and old[matched] == size:
      bcast_dims.append(out_axis)
      matched += 1
      continue
    if size == 1:
      # An added singleton axis; it consumes nothing from `old`.
      continue
    return None
  # Every axis of `old` must have been matched.
  return bcast_dims if matched == len(old) else None
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reshape_shape_rule(operand, *, new_sizes, dimensions):
  """Shape rule for reshape.

  Returns:
    `new_sizes` as a tuple.

  Raises:
    TypeError: if any new size is negative, if the total element count would
      change, or if `dimensions` is given and is not a permutation of the
      operand's axes.
  """
  # NOTE(review): the check accepts zero although the message says "positive".
  if not np.all(np.greater_equal(new_sizes, 0)):
    raise TypeError(
        'reshape new_sizes must all be positive, got {}.'.format(new_sizes))

  if prod(np.shape(operand)) != prod(new_sizes):
    raise TypeError(
        'reshape total size must be unchanged, got new_sizes {} for shape {}.'
        .format(new_sizes, np.shape(operand)))

  if dimensions is not None and set(dimensions) != set(range(np.ndim(operand))):
    template = ('reshape dimensions must be a permutation of operand dimensions, '
                'got dimensions {} for shape {}.')
    raise TypeError(template.format(dimensions, np.shape(operand)))

  return tuple(new_sizes)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reshape_dtype_rule(operand, *, new_sizes, dimensions):
  """Dtype rule for reshape: the operand's dtype, unchanged."""
  del new_sizes, dimensions  # Unused.
  return operand.dtype
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reshape_translation_rule(c, operand, *, new_sizes, dimensions):
  """XLA translation rule for reshape.

  Calls xops.Reshape with (operand, new_sizes) when `dimensions` is None and
  with (operand, dimensions, new_sizes) otherwise.
  """
  reshape_args = (new_sizes,) if dimensions is None else (dimensions, new_sizes)
  return xops.Reshape(operand, *reshape_args)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions):
  """Transpose rule for reshape: bring cotangent ``t`` back to the operand shape.

  Returns a one-element list holding the cotangent for ``operand``.
  """
  assert ad.is_undefined_primal(operand)
  primal_shape = operand.aval.shape
  if dimensions is not None:
    # Reshape to the operand shape taken in ``dimensions`` order, then apply
    # the inverse permutation (argsort) to restore the original axis order.
    permuted_shape = np.take(primal_shape, dimensions)
    reshaped = reshape(t, permuted_shape)
    return [transpose(reshaped, np.argsort(dimensions))]
  return [reshape(t, primal_shape)]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
  """Batching rule for reshape; the result carries its batch axis at position 0."""
  (operand,), (bdim,) = batched_args, batch_dims
  # Move the batch axis to the front so it can be kept as the leading dimension.
  operand = batching.moveaxis(operand, bdim, 0)
  if dimensions is None:
    batched_dimensions = None
  else:
    # Axis 0 is now the batch axis; every operand dimension shifts up by one.
    batched_dimensions = (0,) + tuple(np.add(1, dimensions))
  batched_sizes = operand.shape[:1] + new_sizes
  return reshape(operand, batched_sizes, batched_dimensions), 0
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# Define the reshape primitive from its shape, dtype and translation rules.
reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule,
                               'reshape', _reshape_translation_rule)
# Install _reshape_impl as the primitive's impl rule.
reshape_p.def_impl(_reshape_impl)
# Register the transpose rule via ad.deflinear2.
ad.deflinear2(reshape_p, _reshape_transpose_rule)
# Register the batching rule.
batching.primitive_batchers[reshape_p] = _reshape_batch_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _rev_shape_rule(operand, *, dimensions):
  """Validate ``dimensions`` for rev and return the operand's shape unchanged.

  Raises:
    TypeError: if ``dimensions`` has duplicate entries, or if its largest
      entry is not less than ``operand.ndim``. ``_check_shapelike`` is called
      first and may raise on its own.
  """
  _check_shapelike('rev', 'dimensions', dimensions)
  has_duplicates = len(set(dimensions)) != len(dimensions)
  if has_duplicates:
    raise TypeError('rev dimensions must be unique, got {}.'.format(dimensions))
  # An empty ``dimensions`` has no maximum, so the bound is only checked when
  # at least one dimension is given.
  if dimensions and not (_max(dimensions) < operand.ndim):
    template = ('rev dimensions must all be less than operand ndim, got dimensions '
                '{} for operand ndim {}.')
    raise TypeError(template.format(dimensions, operand.ndim))
  return operand.shape
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
  """Batching rule for rev; the batch axis keeps its position ``bdim``."""
  (operand,), (bdim,) = batched_args, batch_dims

  def shift(dim):
    # Dimensions at or after the batch axis are numbered one higher in the
    # batched operand.
    return dim + 1 if dim >= bdim else dim

  return rev(operand, [shift(dim) for dim in dimensions]), bdim
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# Define the rev primitive; its dtype rule is _input_dtype.
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
# The transpose of rev reverses the cotangent along the same dimensions.
ad.deflinear(rev_p, lambda t, dimensions: [rev(t, dimensions)])
# Register the batching rule.
batching.primitive_batchers[rev_p] = _rev_batch_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _transpose_impl(operand, *, permutation):
  """Impl rule for transpose.

  An operand whose type is exactly ``xla.DeviceArray`` is handled by
  transposing its lazy expression and building a new ``DeviceArray`` around
  the same device buffer. Any other operand goes through
  ``xla.apply_primitive``.
  """
  if type(operand) is not xla.DeviceArray:
    return xla.apply_primitive(transpose_p, operand, permutation=permutation)
  permuted_expr = lazy.transpose(operand._lazy_expr, permutation)
  out_aval = ShapedArray(permuted_expr.shape, operand.dtype)
  return xla.DeviceArray(out_aval, operand._device, permuted_expr,
                         operand.device_buffer)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _transpose_shape_rule(operand, *, permutation):
  """Validate ``permutation`` and return the operand shape reordered by it.

  Raises:
    TypeError: if ``permutation`` is not a tuple, list or ndarray, or if its
      sorted entries differ from ``range(operand.ndim)``.
  """
  accepted_types = (tuple, list, np.ndarray)
  if not isinstance(permutation, accepted_types):
    raise TypeError(
        "transpose permutation must be a tuple/list/ndarray, got {}.".format(
            type(permutation)))
  sorted_axes = tuple(sorted(permutation))
  if sorted_axes != tuple(range(operand.ndim)):
    template = ("transpose permutation isn't a permutation of operand dimensions, "
                "got permutation {} for operand shape {}.")
    raise TypeError(template.format(permutation, operand.shape))
  permuted_shape = np.take(operand.shape, permutation)
  return tuple(permuted_shape)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
  """Batching rule for transpose; the batch axis is placed at position 0."""
  (operand,), (bdim,) = batched_args, batch_dims
  # Axes at or after the batch axis are numbered one higher in the batched
  # operand; the batch axis itself leads the new permutation.
  shifted_axes = tuple(axis + 1 if axis >= bdim else axis
                       for axis in permutation)
  batched_permutation = (bdim,) + shifted_axes
  return transpose(operand, batched_permutation), 0
|
|
|
|
|
2020-06-03 22:40:48 +02:00
|
|
|
def _transpose_masking_rule(padded_vals, logical_shapes, permutation):
  """Masking rule for transpose: transpose the padded values directly."""
  del logical_shapes  # Not consulted by this rule.
  transposed = transpose(*padded_vals, permutation=permutation)
  return transposed
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# Define the transpose primitive; its dtype rule is _input_dtype.
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
                                 'transpose')
# Install _transpose_impl as the primitive's impl rule.
transpose_p.def_impl(_transpose_impl)
# The transpose rule applies the inverse permutation (argsort of the forward
# permutation) to the cotangent.
ad.deflinear(transpose_p,
             lambda t, permutation: [transpose(t, np.argsort(permutation))])
# Register the batching and masking rules.
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_shape_rule(pred, on_true, on_false):
  """Validate select operand shapes and return the shape of ``on_true``.

  Raises:
    TypeError: if ``on_true`` and ``on_false`` differ in shape, or if ``pred``
      has a non-empty shape that differs from the shape of ``on_true``.
  """
  if on_true.shape != on_false.shape:
    raise TypeError(
        "select on_true and on_false must have the same shape, got {} and {}."
        .format(on_true.shape, on_false.shape))
  # An empty (scalar) pred shape is accepted with branches of any shape.
  pred_is_scalar = not pred.shape
  if not pred_is_scalar and pred.shape != on_true.shape:
    template = ("select pred must be scalar or have the same shape as on_true and "
                "on_false, got pred shape {} for on_true and on_false of shape {}.")
    raise TypeError(template.format(pred.shape, on_true.shape))
  return on_true.shape
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_dtype_rule(pred, on_true, on_false):
  """Validate select operand dtypes and return the dtype of ``on_true``.

  Raises:
    TypeError: if ``pred`` does not have a boolean dtype.
      ``_check_same_dtypes`` is called first on the two branch dtypes and may
      raise on its own.
  """
  _check_same_dtypes("select", False, on_true.dtype, on_false.dtype)
  pred_is_boolean = dtypes.issubdtype(pred.dtype, np.bool_)
  if pred_is_boolean:
    return on_true.dtype
  raise TypeError("select pred must be boolean type, got {}.".format(pred.dtype))
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_transpose_rule(t, pred, on_true, on_false):
  """Transpose rule for select.

  Returns ``ad_util.Zero`` when the cotangent ``t`` is a ``Zero``. Otherwise
  returns a three-element list: None for ``pred``, then for each branch that
  is an undefined primal the cotangent masked by ``pred`` (None for a branch
  that is not).
  """
  assert not ad.is_undefined_primal(pred)
  if type(t) is ad_util.Zero:
    return ad_util.Zero
  zeros = full_like(t, 0)
  # on_true receives ``t`` where pred holds; on_false receives it elsewhere.
  true_ct = select(pred, t, zeros) if ad.is_undefined_primal(on_true) else None
  false_ct = select(pred, zeros, t) if ad.is_undefined_primal(on_false) else None
  return [None, true_ct, false_ct]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
  """Batching rule for select.

  Returns the batched select result together with the position of its batch
  axis. Two fast paths handle cases where the operands' batch axes already
  line up; the general path moves every batch axis to the front.
  """
  pred, on_true, on_false = batched_args
  pred_bdim, ot_bdim, of_bdim = batch_dims
  # Batch size, read from the first argument whose batch dim is not None.
  size = next(arg.shape[bd] for arg, bd in zip(batched_args, batch_dims)
              if bd is not None)

  # avoid transposes and some broadcasts in special cases
  if pred_bdim == ot_bdim == of_bdim:
    if np.shape(pred) != np.shape(on_true):
      # vmapped function had a scalar pred with nonscalar args
      assert np.ndim(pred) == 1
      pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim])
    return select(pred, on_true, on_false), pred_bdim

  if np.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None:
    if ot_bdim == of_bdim:
      return select(pred, on_true, on_false), ot_bdim
    if np.shape(on_true) == np.shape(on_false):
      # Line on_false's batch axis up with on_true's.
      on_false = batching.moveaxis(on_false, of_bdim, ot_bdim)
      return select(pred, on_true, on_false), ot_bdim

  # General case: put every batch axis at the front.
  if np.shape(pred):
    pred = batching.bdim_at_front(pred, pred_bdim, size)
  if not np.shape(on_true) == np.shape(on_false) == ():
    on_true = batching.bdim_at_front(on_true, ot_bdim, size)
    on_false = batching.bdim_at_front(on_false, of_bdim, size)
  assert np.shape(on_true) == np.shape(on_false)

  pred_rank, branch_rank = np.ndim(pred), np.ndim(on_true)
  if 0 < pred_rank < branch_rank:
    # vmapped function had a scalar pred with nonscalar args
    assert pred_rank == 1
    pred = broadcast_in_dim(pred, on_true.shape, [0])
  if np.ndim(pred) > np.ndim(on_true):
    # Scalar branches with a batched pred: broadcast the branches to pred.
    assert np.ndim(on_true) == 0
    on_true = broadcast(on_true, pred.shape)
    on_false = broadcast(on_false, pred.shape)
  return select(pred, on_true, on_false), 0
|
2018-11-26 12:30:01 -08:00
|
|
|
|
2020-06-03 22:40:48 +02:00
|
|
|
def _select_masking_rule(padded_vals, logical_shapes):
  """Masking rule for select: check the padded shapes agree, then select."""
  del logical_shapes  # Not consulted by this rule.
  shapes = [masking.padded_shape_as_value(val.shape) for val in padded_vals]
  pred_shape, true_shape, false_shape = shapes
  # Both branches must have the same padded shape as the predicate.
  for branch_shape in (true_shape, false_shape):
    assert np.array_equal(pred_shape, branch_shape)
  return select(*padded_vals)
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# Define the select primitive from its shape and dtype rules.
select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
# JVP rules, presumably one per operand (pred, on_true, on_false): None for the
# first; each of the others selects the tangent against zeros using the same
# predicate ``b``.
ad.defjvp(select_p,
          None,
          lambda g, b, x, y: select(b, g, _zeros(g)),
          lambda g, b, x, y: select(b, _zeros(g), g))
# Register the transpose, batching and masking rules.
ad.primitive_transposes[select_p] = _select_transpose_rule
batching.primitive_batchers[select_p] = _select_batch_rule
masking.masking_rules[select_p] = _select_masking_rule
|
2019-04-03 15:13:04 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
  """Validate slice parameters and compute the shape of the sliced result.

  Args:
    operand: value being sliced; only its ``ndim`` and ``shape`` are read.
    start_indices: start index for each operand dimension.
    limit_indices: limit index for each operand dimension.
    strides: stride for each operand dimension, or None for all ones.

  Returns:
    A tuple whose entries are
    ``(limit - start + stride - 1) // stride`` for each dimension.

  Raises:
    TypeError: if the index or stride sequences have the wrong length, if a
      limit exceeds the operand shape, if a start index is negative, if a
      limit is below its start index, or if a stride is not positive.
      ``_check_shapelike`` is called on each sequence and may raise on its own.
  """
  _check_shapelike("slice", "start_indices", start_indices)
  _check_shapelike("slice", "limit_indices", limit_indices)
  if operand.ndim != len(start_indices):
    msg = ("slice start_indices must have length equal to the number of "
           "dimensions of the operand, got indices {} for operand shape {}.")
    raise TypeError(msg.format(start_indices, operand.shape))
  if len(start_indices) != len(limit_indices):
    # The message previously misspelled the parameter name as "start_inidices".
    msg = ("slice limit_indices must have the same length as start_indices, "
           "got start_indices {} and limit_indices {}.")
    raise TypeError(msg.format(start_indices, limit_indices))
  # The upper-bound check is skipped when either side is polymorphic.
  if (not masking.is_polymorphic(limit_indices) and
      not masking.is_polymorphic(operand.shape) and
      not np.all(np.less_equal(limit_indices, operand.shape))):
    msg = ("slice limit_indices must be less than or equal to operand shape, "
           "got limit_indices {} for operand shape {}.")
    raise TypeError(msg.format(limit_indices, operand.shape))
  if not np.all(np.greater_equal(start_indices, 0)):
    msg = ("slice start_indices must be greater than or equal to zero, "
           "got start_indices of {}.")
    raise TypeError(msg.format(start_indices))
  if (not masking.is_polymorphic(limit_indices) and
      not np.all(np.greater_equal(limit_indices, start_indices))):
    msg = ("slice limit_indices must be greater than or equal to start_indices,"
           " got start_indices {} and limit_indices {}.")
    raise TypeError(msg.format(start_indices, limit_indices))
  if strides is None:
    strides = np.ones(operand.ndim, np.int32)
  else:
    _check_shapelike("slice", "strides", strides)
    if len(strides) != operand.ndim:
      msg = ("slice strides must have length equal to the number of dimensions "
             "of the operand, got strides {} for operand shape {}.")
      raise TypeError(msg.format(strides, operand.shape))
    if not np.all(np.greater(strides, 0)):
      msg = "slice strides must be positive, got {}"
      raise TypeError(msg.format(strides))

  # Ceiling division of (limit - start) by stride, written with floor_divide.
  result_shape = np.floor_divide(
      np.add(np.subtract(limit_indices, start_indices), strides) - 1, strides)
  return tuple(result_shape)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _slice_translation_rule(c, operand, *, start_indices, limit_indices,
                            strides):
  """Build the ``xops.Slice`` call, using unit strides when ``strides`` is falsy.

  ``c`` is not used.
  """
  if not strides:
    strides = [1] * len(start_indices)
  return xops.Slice(operand, start_indices, limit_indices, strides)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
  """Transpose rule for slice: zero-pad cotangent ``t`` out to the operand shape.

  Returns a one-element list holding the padded cotangent.
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  unit_strides = strides is None or np.all(np.equal(strides, 1))
  if unit_strides:
    # Only edge padding is needed: low = start, high = shape - limit.
    high_pad = np.subtract(operand_shape, limit_indices)
    interior_pad = (0,) * len(start_indices)
    pads = zip(start_indices, high_pad, interior_pad)
  else:
    # One past the last element actually read in each dimension:
    # start + 1 + (t.shape - 1) * stride.
    real_limits = np.add(np.add(start_indices, 1),
                         np.multiply(np.subtract(t.shape, 1), strides))
    high_pad = np.subtract(operand_shape, real_limits)
    # stride - 1 zeros go between neighbouring cotangent elements.
    interior_pad = np.subtract(strides, 1)
    pads = safe_zip(start_indices, high_pad, interior_pad)
  result = pad(t, _const(t, 0), pads)
  assert result.shape == operand_shape
  return [result]
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
                         limit_indices, strides):
  """Batching rule for slice; the batch axis is kept whole at position ``bdim``."""
  (operand,), (bdim,) = batched_args, batch_dims

  def with_batch_entry(values, batch_value):
    # Copy ``values`` into a list with ``batch_value`` inserted at ``bdim``.
    extended = list(values)
    extended.insert(bdim, batch_value)
    return extended

  # The batch axis is sliced over its full extent with stride 1.
  new_start_indices = with_batch_entry(start_indices, 0)
  new_limit_indices = with_batch_entry(limit_indices, operand.shape[bdim])
  new_strides = None if strides is None else with_batch_entry(strides, 1)

  out = slice(operand, new_start_indices, new_limit_indices, new_strides)
  return out, bdim
|
|
|
|
|
2020-06-03 22:40:48 +02:00
|
|
|
def _slice_masking_rule(
    padded_vals, logical_shapes, start_indices, limit_indices, strides):
  """Masking rule for slice: slice the padded operand with padded index values."""
  del logical_shapes  # Not consulted by this rule.
  operand, = padded_vals
  padded_starts = masking.padded_shape_as_value(start_indices)
  padded_limits = masking.padded_shape_as_value(limit_indices)
  return slice(operand,
               start_indices=padded_starts,
               limit_indices=padded_limits,
               strides=strides)
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# Define the slice primitive; its dtype rule is _input_dtype.
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
                             _slice_translation_rule)
# Register the transpose rule via ad.deflinear2.
ad.deflinear2(slice_p, _slice_transpose_rule)
# Register the batching and masking rules.
batching.primitive_batchers[slice_p] = _slice_batching_rule
masking.masking_rules[slice_p] = _slice_masking_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
|
2018-11-17 18:03:33 -08:00
|
|
|
if operand.ndim != len(start_indices):
|
|
|
|
msg = ("dynamic_slice start_indices must have length equal to the number "
|
|
|
|
"of dimensions of the operand, got indices {} for operand shape {}.")
|
|
|
|
raise TypeError(msg.format(start_indices, operand.shape))
|
|
|
|
if len(start_indices) != len(slice_sizes):
|
|
|
|
msg = ("dynamic_slice slice_sizes must have the same length as "
|
|
|
|
"start_indices, got start_inidices length {} and slice_sizes {}.")
|
|
|
|
raise TypeError(msg.format(len(start_indices), slice_sizes))
|
2020-07-14 13:05:31 -07:00
|
|
|
if not np.all(np.less_equal(slice_sizes, operand.shape)):
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = ("slice slice_sizes must be less than or equal to operand shape, "
|
|
|
|
"got slice_sizes {} for operand shape {}.")
|
|
|
|
raise TypeError(msg.format(slice_sizes, operand.shape))
|
2020-07-14 13:05:31 -07:00
|
|
|
if not np.all(np.greater_equal(slice_sizes, 0)):
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = ("slice slice_sizes must be greater than or equal to zero, "
|
|
|
|
"got slice_sizes of {}.")
|
|
|
|
raise TypeError(msg.format(slice_sizes))
|
|
|
|
return tuple(slice_sizes)
|
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes):
|
2019-11-14 15:51:27 -05:00
|
|
|
if any(i.dtype != start_indices[0].dtype or
|
2020-07-14 13:05:31 -07:00
|
|
|
not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices):
|
2019-11-14 15:51:27 -05:00
|
|
|
msg = ("index arguments to dynamic_slice must be integers of the same "
|
|
|
|
"type, got: {}")
|
|
|
|
raise TypeError(msg.format(", ".join(i.dtype.name for i in start_indices)))
|
|
|
|
return operand.dtype
|
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def _dynamic_slice_translation_rule(c, operand, *start_indices, slice_sizes):
  """Translation rule for dynamic_slice: emit `xops.DynamicSlice`.

  Args:
    c: accepted for the translation-rule signature; not read here.
    operand: operand passed to `xops.DynamicSlice`.
    *start_indices: start-index operands, forwarded together as one tuple.
    slice_sizes: slice extents, forwarded unchanged.

  Returns:
    Whatever `xops.DynamicSlice` returns.
  """
  sliced_op = xops.DynamicSlice(operand, start_indices, slice_sizes)
  return sliced_op
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dynamic_slice_jvp(primals, tangents, *, slice_sizes):
  """JVP rule for dynamic_slice.

  Only the operand tangent (`tangents[0]`) contributes to the output
  tangent; tangents of the start indices are ignored.

  Args:
    primals: sequence whose first element is the operand and whose remaining
      elements are the start indices.
    tangents: sequence parallel to `primals`; only the first entry is read.
    slice_sizes: slice extents forwarded to `dynamic_slice`.

  Returns:
    A pair `(primal_out, tangent_out)`. When the operand tangent is a
    symbolic `ad_util.Zero`, `tangent_out` is a symbolic zero built from
    `primal_out`; otherwise it is the same slice taken from the tangent.
  """
  operand = primals[0]
  start_indices = primals[1:]
  primal_out = dynamic_slice(operand, start_indices, slice_sizes)

  operand_tangent = tangents[0]
  if type(operand_tangent) is ad_util.Zero:
    # Build the zero from the sliced output instead of forwarding the
    # operand's zero: the incoming zero belongs to the full, unsliced operand
    # and would not describe a value of the output's (smaller) shape.
    tangent_out = ad_util.Zero.from_value(primal_out)
  else:
    tangent_out = dynamic_slice(operand_tangent, start_indices, slice_sizes)
  return primal_out, tangent_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
  """Transpose rule for dynamic_slice.

  The cotangent `t` of the slice is written into a zero-filled array of the
  operand's shape at the same start indices.

  Args:
    t: cotangent of the sliced output.
    operand: undefined primal for the operand; its `aval.shape` gives the
      shape of the cotangent to build.
    *start_indices: start indices; must all be defined values.
    slice_sizes: accepted for the rule signature; not read here.

  Returns:
    A list with the operand cotangent first, followed by one `None` per
    start index.
  """
  assert ad.is_undefined_primal(operand)
  assert not any(ad.is_undefined_primal(s) for s in start_indices)
  operand_shape = operand.aval.shape

  fill_value = _zero(t)
  if not config.omnistaging_enabled:
    # NOTE(review): without omnistaging the zero is passed through tie_in
    # with `t`; presumably this keeps it associated with t's trace -- confirm.
    fill_value = tie_in(t, fill_value)
  canvas = full(operand_shape, fill_value)

  operand_ct = dynamic_update_slice(canvas, t, start_indices)
  no_index_cts = [None] * len(start_indices)
  return [operand_ct] + no_index_cts
|
2019-08-15 11:26:30 -04:00
|
|
|
|
|
|
|
def _batch_dynamic_slice_indices(indices, bdims):
|
2020-07-29 03:39:32 +02:00
|
|
|
if len(indices) == 0:
|
|
|
|
return np.array([], 'int32'), None
|
2019-08-15 11:42:08 -04:00
|
|
|
size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), -1)
|
|
|
|
if size < 0:
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
return concatenate([broadcast(i, (1,)) for i in indices], 0), None
|
2019-08-15 11:26:30 -04:00
|
|
|
indices = concatenate(
|
2019-08-15 12:24:38 -04:00
|
|
|
[broadcast_in_dim(x, (size, 1),
|
|
|
|
broadcast_dimensions=((0,) if i is not None else ()))
|
|
|
|
for x, i in zip(indices, bdims)],
|
|
|
|
dimension=1)
|
2019-08-15 11:26:30 -04:00
|
|
|
return indices, 0
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
  """Batching rule for dynamic_slice, expressed through the gather rule.

  A dynamic slice is a special case of gather, so the work is delegated to
  the gather batching rule.

  Args:
    batched_args: the operand followed by its start indices.
    batch_dims: batch axes parallel to `batched_args`.
    slice_sizes: slice extents, forwarded to the gather batching rule.

  Returns:
    Whatever `_gather_batching_rule` returns for the equivalent gather.
  """
  # TODO(phawkins): consider removing dynamic_slice entirely and using gather
  # always.
  operand, *start_indices = batched_args
  operand_bd, *start_idx_bds = batch_dims

  # Shape of the operand with its batch axis (if any) removed; only its
  # length is needed below.
  if operand_bd is batching.not_mapped:
    unbatched_shape = operand.shape
  else:
    unbatched_shape = tuple(np.delete(operand.shape, operand_bd))
  all_dims = tuple(range(len(unbatched_shape)))

  # Gather spec equivalent to a dynamic slice: every dimension is an offset
  # dimension, none is collapsed, and index i addresses dimension i.
  dnums = GatherDimensionNumbers(offset_dims=all_dims,
                                 collapsed_slice_dims=(),
                                 start_index_map=all_dims)
  index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds)
  gather_args = [operand, index]
  gather_bdims = [operand_bd, index_bdim]
  return _gather_batching_rule(
      gather_args, gather_bdims, dimension_numbers=dnums,
      slice_sizes=slice_sizes)
|
2019-03-13 11:10:50 -04:00
|
|
|
|
2018-12-23 09:28:23 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# The dynamic_slice primitive, built from its shape rule, dtype rule, name and
# translation rule.
dynamic_slice_p = standard_primitive(
    _dynamic_slice_shape_rule,
    _dynamic_slice_dtype_rule,
    'dynamic_slice',
    _dynamic_slice_translation_rule,
)
# Register the JVP, transpose and batching rules for dynamic_slice.
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp  # TODO
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
|
2018-11-17 18:03:33 -08:00
|
|
|
if operand.ndim != update.ndim:
|
|
|
|
msg = ("dynamic_update_slice update must have the same rank as operand, "
|
|
|
|
"got update shape {} for operand shape {}.")
|
|
|
|
raise TypeError(msg.format(update.shape, operand.shape))
|
|
|
|
if operand.ndim != len(start_indices):
|
|
|
|
msg = ("dynamic_update_slice start_indices must have length equal to the "
|
|
|
|
"rank of operand, got indices {} for operand shape {}.")
|
|
|
|
raise TypeError(msg.format(start_indices, operand.shape))
|
2020-07-14 13:05:31 -07:00
|
|
|
if not np.all(np.less_equal(update.shape, operand.shape)):
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = ("dynamic_update_slice update shape must be smaller than operand "
|
|
|
|
"shape, got update shape {} for operand shape {}.")
|
|
|
|
raise TypeError(msg.format(update.shape, operand.shape))
|
|
|
|
return operand.shape
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
  """Check dtypes for dynamic_update_slice and return the result dtype.

  Args:
    operand: object exposing `dtype` for the array being updated.
    update: object exposing `dtype` for the values written.
    *start_indices: index operands; each must expose `dtype`.

  Returns:
    `operand.dtype`.

  Raises:
    TypeError: if any index dtype differs from the first index's dtype or is
      not a sub-dtype of `np.integer`. `_check_same_dtypes` is also invoked
      on the operand and update dtypes and may raise on its own terms.
  """
  _check_same_dtypes("dynamic_update_slice", False, operand.dtype, update.dtype)
  for idx in start_indices:
    differs_from_first = idx.dtype != start_indices[0].dtype
    if differs_from_first or not dtypes.issubdtype(idx.dtype, np.integer):
      # Report every index dtype, not only the offending one.
      dtype_names = ", ".join(i.dtype.name for i in start_indices)
      template = ("index arguments to dynamic_update_slice must be integers of the "
                  "same type, got {}")
      raise TypeError(template.format(dtype_names))
  return operand.dtype
|
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def _dynamic_update_slice_jvp(primals, tangents):
  """JVP rule for dynamic_update_slice.

  Args:
    primals: sequence `(operand, update, *start_indices)`.
    tangents: sequence parallel to `primals`; only the operand and update
      tangents are read.

  Returns:
    A pair `(val_out, tangent_out)`. When both tangents are symbolic
    `ad_util.Zero`, the output tangent is a symbolic zero built from
    `val_out`; otherwise both tangents are instantiated and the same update
    is applied to them.
  """
  operand, update = primals[:2]
  start_indices = primals[2:]
  operand_dot, update_dot = tangents[:2]

  val_out = dynamic_update_slice(operand, update, start_indices)

  both_zero = (type(operand_dot) is ad_util.Zero and
               type(update_dot) is ad_util.Zero)
  if both_zero:
    return val_out, ad_util.Zero.from_value(val_out)

  # At least one tangent is non-zero, so both need concrete values before
  # they can be combined.
  operand_dot = ad.instantiate_zeros(operand_dot)
  update_dot = ad.instantiate_zeros(update_dot)
  tangent_out = dynamic_update_slice(operand_dot, update_dot, start_indices)
  return val_out, tangent_out
|
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices):
  """Transpose rule for dynamic_update_slice.

  Args:
    t: cotangent of the updated output.
    operand: the operand, either a defined value or an undefined primal.
    update: the update, either a defined value or an undefined primal.
    *start_indices: start indices; must all be defined values.

  Returns:
    A list `[operand_ct, update_ct, None, ...]` with one trailing `None` per
    start index. `operand_ct` is `t` with the updated window overwritten by
    zeros, or None when the operand is a defined value. `update_ct` is the
    window of `t` at the start indices, or None when the update is a defined
    value.
  """
  assert not any(ad.is_undefined_primal(x) for x in start_indices)

  update_is_undefined = ad.is_undefined_primal(update)
  # An undefined primal carries its shape on `.aval`; a concrete value
  # exposes `.shape` directly.
  update_shape = update.aval.shape if update_is_undefined else update.shape

  # The zero block is always built, matching the original evaluation order.
  zero_block = _zeros(t, shape=update_shape)

  if ad.is_undefined_primal(operand):
    operand_ct = dynamic_update_slice(t, zero_block, start_indices)
  else:
    operand_ct = None

  if update_is_undefined:
    update_ct = dynamic_slice(t, start_indices, update_shape)
  else:
    update_ct = None

  return [operand_ct, update_ct] + [None] * len(start_indices)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _dynamic_update_slice_translation_rule(c, operand, update, *start_indices):
  """Translation rule for dynamic_update_slice: emit `xops.DynamicUpdateSlice`.

  Args:
    c: accepted for the translation-rule signature; not read here.
    operand: operand passed to `xops.DynamicUpdateSlice`.
    update: update passed to `xops.DynamicUpdateSlice`.
    *start_indices: start-index operands, forwarded together as one tuple.

  Returns:
    Whatever `xops.DynamicUpdateSlice` returns.
  """
  updated_op = xops.DynamicUpdateSlice(operand, update, start_indices)
  return updated_op
|
2018-11-17 18:03:33 -08:00
|
|
|
|
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
let b = reduce_sum[ axes=(0,) ] a
in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
let b = reduce_sum[ axes=(0,)
input_shape=(3,) ] a
in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
|
|
|
def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
  """Batching rule for dynamic_update_slice, expressed through scatter.

  A dynamic update slice is a special case of scatter, so the work is
  delegated to the scatter batching rule.

  Args:
    batched_args: `(operand, update, *start_indices)`.
    batch_dims: batch axes parallel to `batched_args`.

  Returns:
    Whatever `_scatter_batching_rule` returns for the equivalent scatter.
  """
  # TODO(phawkins): consider removing dynamic_update_slice entirely and using
  # scatter always.
  operand, update, *start_idx = batched_args
  operand_bd, update_bd, *start_idx_bd = batch_dims

  # Shape of the update with its batch axis (if any) removed; only its length
  # is needed below.
  if update_bd is batching.not_mapped:
    unbatched_update_shape = np.shape(update)
  else:
    unbatched_update_shape = tuple(np.delete(np.shape(update), update_bd))
  all_dims = tuple(range(len(unbatched_update_shape)))

  # Scatter spec equivalent to a dynamic update slice: every update dimension
  # is a window dimension, none is inserted, and index i addresses operand
  # dimension i.
  dnums = ScatterDimensionNumbers(update_window_dims=all_dims,
                                  inserted_window_dims=(),
                                  scatter_dims_to_operand_dims=all_dims)
  index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd)
  scatter_args = (operand, index, update)
  scatter_bdims = (operand_bd, index_bdim, update_bd)
  return _scatter_batching_rule(
      scatter, scatter_args, scatter_bdims,
      update_jaxpr=None, update_consts=None, dimension_numbers=dnums,
      indices_are_sorted=True, unique_indices=True)
|
2019-04-30 11:48:53 -04:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# The dynamic_update_slice primitive, built from its shape rule, dtype rule,
# name and translation rule.
dynamic_update_slice_p = standard_primitive(
    _dynamic_update_slice_shape_rule,
    _dynamic_update_slice_dtype_rule,
    'dynamic_update_slice',
    _dynamic_update_slice_translation_rule,
)
# Register the JVP, transpose and batching rules for dynamic_update_slice.
ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
ad.primitive_transposes[dynamic_update_slice_p] = (
    _dynamic_update_slice_transpose_rule)
batching.primitive_batchers[dynamic_update_slice_p] = (
    _dynamic_update_slice_batching_rule)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-03-01 10:34:46 -05:00
|
|
|
def _gather_dimensions_proto(indices_shape, dimension_numbers):
  """Convert `GatherDimensionNumbers` into an `xla_client` proto.

  Args:
    indices_shape: shape object of the start indices; only its `rank()` is
      read here.
    dimension_numbers: a `GatherDimensionNumbers` instance.

  Returns:
    An `xla_client.GatherDimensionNumbers` whose `index_vector_dim` is the
    last dimension of the indices.
  """
  assert type(dimension_numbers) is GatherDimensionNumbers
  proto = xla_client.GatherDimensionNumbers()
  # Copy the three dimension lists over field by field.
  for field in ("offset_dims", "collapsed_slice_dims", "start_index_map"):
    getattr(proto, field).extend(getattr(dimension_numbers, field))
  indices_rank = indices_shape.rank()
  assert indices_rank > 0
  # The index vector always lives in the trailing dimension of the indices.
  proto.index_vector_dim = indices_rank - 1
  return proto
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _gather_dtype_rule(operand, start_indices, **kwargs):
  """Dtype rule for gather: the result has the (canonicalized) operand dtype.

  Raises:
    ValueError: if `start_indices` does not have an integer dtype.
  """
  indices_are_integral = dtypes.issubdtype(start_indices.dtype, np.integer)
  if indices_are_integral:
    return dtypes.canonicalize_dtype(operand.dtype)
  raise ValueError("start_indices must have an integer type")
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _gather_shape_rule(operand, start_indices, *, dimension_numbers,
                       slice_sizes):
  """Compute the output shape of a gather.

  Output dimensions listed in `dimension_numbers.offset_dims` take their
  extent from `slice_sizes` (with the collapsed slice dims removed, in
  order); every other output dimension takes its extent from the leading
  dimensions of `start_indices` (all but the last), in order.

  Raises:
    ValueError: if `slice_sizes` does not have one entry per operand
      dimension.
  """
  operand_shape = operand.shape
  if len(operand_shape) != len(slice_sizes):
    raise ValueError(
        ("slice_sizes must have rank equal to the gather operand; "
         "operand.shape={}, slice_sizes={}").format(operand_shape, slice_sizes))

  offset_dims = dimension_numbers.offset_dims
  out_rank = len(offset_dims) + start_indices.ndim - 1
  # Two independent streams of extents that get interleaved below.
  batch_extents = iter(start_indices.shape[:-1])
  window_extents = iter(
      np.delete(slice_sizes, dimension_numbers.collapsed_slice_dims))
  return tuple(
      next(window_extents) if dim in offset_dims else next(batch_extents)
      for dim in range(out_rank))
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _gather_translation_rule(c, operand, start_indices, *, dimension_numbers,
                             slice_sizes):
  """Emit an `xops.Gather` for the gather primitive.

  The dimension-numbers proto is built from the shape of `start_indices`
  as reported by the builder `c`; `indices_are_sorted` is always passed as
  False.
  """
  dnums_proto = _gather_dimensions_proto(
      c.get_shape(start_indices), dimension_numbers)
  return xops.Gather(operand, start_indices, dnums_proto, slice_sizes,
                     indices_are_sorted=False)
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _gather_jvp_rule(g, operand, start_indices, *, dimension_numbers,
                     slice_sizes):
  """JVP of gather w.r.t. the operand: gather the tangent `g` with the same
  indices and parameters."""
  return gather(g, start_indices, dimension_numbers=dimension_numbers,
                slice_sizes=slice_sizes)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
                           slice_sizes):
  """Transpose of gather: scatter-add the cotangent `t` into zeros.

  `operand` must be the undefined (linear) input. When `t` is a symbolic
  zero the `ad_util.Zero` class itself is returned (presumably a sentinel
  interpreted by the caller -- confirm against the transpose driver).

  Returns:
    `[operand_cotangent, zero_cotangent_for_start_indices]`.
  """
  assert ad.is_undefined_primal(operand)
  operand_shape = operand.aval.shape
  if type(t) is ad_util.Zero:
    return ad_util.Zero

  # Outside omnistaging the fill value is tied to `t` via `tie_in`.
  fill_value = _zero(t)
  if not config.omnistaging_enabled:
    fill_value = tie_in(t, fill_value)
  zeros = full(operand_shape, fill_value)

  # The scatter that inverts this gather reuses the same dimension lists.
  scatter_dnums = ScatterDimensionNumbers(
      update_window_dims=dimension_numbers.offset_dims,
      inserted_window_dims=dimension_numbers.collapsed_slice_dims,
      scatter_dims_to_operand_dims=dimension_numbers.start_index_map)
  operand_ct = scatter_add(zeros, start_indices, t, scatter_dnums,
                           indices_are_sorted=False, unique_indices=False)
  indices_ct = ad_util.Zero.from_value(start_indices)
  return [operand_ct, indices_ct]
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
                          slice_sizes):
  """Batching rule for gather.

  Three cases are handled, depending on which of `operand` and
  `start_indices` carries a batch dimension. In every case the batch
  dimension is first moved to axis 0, the dimension numbers are shifted to
  make room for it, and the result is returned with batch dimension 0.
  """
  operand, start_indices = batched_args
  operand_bdim, start_indices_bdim = batch_dims

  def shifted(dims):
    # Existing dimension numbers move right by one for the new leading axis.
    return tuple(np.add(1, dims))

  operand_is_batched = operand_bdim is not None
  indices_are_batched = start_indices_bdim is not None

  if operand_is_batched and not indices_are_batched:
    # Only the operand is batched: treat the batch axis as one more offset
    # (window) dimension that is sliced in full.
    operand = batching.moveaxis(operand, operand_bdim, 0)
    slice_sizes = (operand.shape[0],) + slice_sizes
    dnums = GatherDimensionNumbers(
        offset_dims=(0,) + shifted(dimension_numbers.offset_dims),
        collapsed_slice_dims=shifted(dimension_numbers.collapsed_slice_dims),
        start_index_map=shifted(dimension_numbers.start_index_map))
    batched_out = gather(operand, start_indices, dimension_numbers=dnums,
                         slice_sizes=slice_sizes)
    return batched_out, 0

  if indices_are_batched and not operand_is_batched:
    # Only the indices are batched: the batch axis is just an extra leading
    # dimension of the indices, so only the offset dims shift.
    start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0)
    dnums = GatherDimensionNumbers(
        offset_dims=shifted(dimension_numbers.offset_dims),
        collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
        start_index_map=dimension_numbers.start_index_map)
    batched_out = gather(operand, start_indices, dimension_numbers=dnums,
                         slice_sizes=slice_sizes)
    return batched_out, 0

  # move batch dimensions to the front to simplify logic
  operand = batching.moveaxis(operand, operand_bdim, 0)
  start_indices = batching.moveaxis(start_indices, start_indices_bdim, 0)

  # Example: user code had start_indices shape (3, 4, 5), and we have to deal
  # with start_indices shape (7, 3, 4, 5). We transform that to a
  # start_indices of shape (7, 3, 4, 6) where we concatenated an iota that
  # counts along our batch dimension to the front of the ndindex.
  iota_shape = list(start_indices.shape)
  iota_shape[-1] = 1
  batch_counter = broadcasted_iota(start_indices.dtype, tuple(iota_shape), 0)
  start_indices = concatenate([batch_counter, start_indices],
                              len(iota_shape) - 1)

  # The batch axis of the operand becomes a collapsed, size-1 slice that is
  # addressed by the iota component prepended to each index vector.
  slice_sizes = (1,) + slice_sizes
  dnums = GatherDimensionNumbers(
      offset_dims=shifted(dimension_numbers.offset_dims),
      collapsed_slice_dims=(
          (0,) + shifted(dimension_numbers.collapsed_slice_dims)),
      start_index_map=(0,) + shifted(dimension_numbers.start_index_map))
  batched_out = gather(operand, start_indices, dimension_numbers=dnums,
                       slice_sizes=slice_sizes)
  return batched_out, 0
|
2019-01-08 21:34:48 -05:00
|
|
|
|
|
|
|
# Build the `gather` primitive and register its AD and batching rules. Only
# the operand (first argument) gets a JVP rule; the indices get None.
gather_p = standard_primitive(
    _gather_shape_rule,
    _gather_dtype_rule,
    'gather',
    _gather_translation_rule)
ad.defjvp(gather_p, _gather_jvp_rule, None)

ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2019-06-23 20:01:53 -07:00
|
|
|
|
2019-03-01 10:34:46 -05:00
|
|
|
def _scatter_dimensions_proto(indices_shape, dimension_numbers):
  """Convert `ScatterDimensionNumbers` into an `xla_client` proto.

  Args:
    indices_shape: shape object of the scatter indices; only its `rank()`
      is read here.
    dimension_numbers: a `ScatterDimensionNumbers` instance.

  Returns:
    An `xla_client.ScatterDimensionNumbers` whose `index_vector_dim` is the
    last dimension of the indices.
  """
  assert type(dimension_numbers) is ScatterDimensionNumbers
  proto = xla_client.ScatterDimensionNumbers()
  # Copy the three dimension lists over field by field.
  for field in ("update_window_dims", "inserted_window_dims",
                "scatter_dims_to_operand_dims"):
    getattr(proto, field).extend(getattr(dimension_numbers, field))
  indices_rank = indices_shape.rank()
  assert indices_rank > 0
  # The index vector always lives in the trailing dimension of the indices.
  proto.index_vector_dim = indices_rank - 1
  return proto
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _scatter_dtype_rule(operand, scatter_indices, updates, **kwargs):
  """Dtype rule shared by the scatter primitives.

  Requires integer `scatter_indices`, checks operand and updates dtypes
  with `_check_same_dtypes`, and returns the canonicalized operand dtype.

  Raises:
    ValueError: if `scatter_indices` does not have an integer dtype.
  """
  indices_are_integral = dtypes.issubdtype(scatter_indices.dtype, np.integer)
  if not indices_are_integral:
    raise ValueError("scatter_indices must have an integer type")
  operand_dtype = operand.dtype
  _check_same_dtypes("scatter", False, operand_dtype, updates.dtype)
  return dtypes.canonicalize_dtype(operand.dtype)
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _scatter_shape_rule(operand, scatter_indices, updates, **kwargs):
  """Shape rule shared by the scatter primitives: the output has exactly
  the operand's shape; the other arguments are not inspected."""
  result_shape = operand.shape
  return result_shape
|
|
|
|
|
2020-07-21 23:16:27 -07:00
|
|
|
def _scatter_translation_rule(c, operand, scatter_indices, updates, *,
                              update_jaxpr, update_consts, dimension_numbers,
                              indices_are_sorted, unique_indices):
  """Emit an `xops.Scatter` for a scatter primitive.

  The combiner computation is built from `update_jaxpr`/`update_consts` by
  `_reduction_computation`, seeded with a zero constant of the operand's
  dtype.
  """
  operand_dtype = c.get_shape(operand).numpy_dtype()
  zero = xb.constant(c, np.array(0, operand_dtype))
  combiner = _reduction_computation(c, update_jaxpr, update_consts, zero)
  dnums_proto = _scatter_dimensions_proto(
      c.get_shape(scatter_indices), dimension_numbers)
  return xops.Scatter(operand, scatter_indices, updates, combiner,
                      dnums_proto, indices_are_sorted, unique_indices)
|
2019-01-08 21:34:48 -05:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts,
                     dimension_numbers, indices_are_sorted, unique_indices):
  """JVP of scatter-add.

  The tangent output is the same scatter-add applied to the operand and
  updates tangents, using the primal indices. If both tangents are
  symbolic zeros, the tangent output is a symbolic zero as well.

  Returns:
    `(primal_out, tangent_out)`.
  """
  operand, scatter_indices, updates = primals
  g_operand, _, g_updates = tangents

  # Parameters shared by the primal and the tangent bind.
  params = dict(
      update_jaxpr=update_jaxpr, update_consts=update_consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)

  val_out = scatter_add_p.bind(operand, scatter_indices, updates, **params)

  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    return val_out, ad_util.Zero.from_value(val_out)

  g_operand = ad.instantiate_zeros(g_operand)
  g_updates = ad.instantiate_zeros(g_updates)
  tangent_out = scatter_add_p.bind(
      g_operand, scatter_indices, g_updates, **params)
  return val_out, tangent_out
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _scatter_add_transpose_rule(t, operand, scatter_indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices):
  """Transpose of scatter-add.

  The operand cotangent is `t` itself; the updates cotangent is `t`
  gathered at `scatter_indices` with the gather that mirrors this scatter.
  Each is produced only when the corresponding input is an undefined
  (linear) primal. When `t` is a symbolic zero the `ad_util.Zero` class
  itself is returned (presumably a sentinel interpreted by the caller --
  confirm against the transpose driver).

  Returns:
    `[operand_cotangent, None, updates_cotangent]`.
  """
  assert not ad.is_undefined_primal(scatter_indices)
  updates_is_linear = ad.is_undefined_primal(updates)
  updates_shape = updates.aval.shape if updates_is_linear else updates.shape
  if type(t) is ad_util.Zero:
    return ad_util.Zero

  operand_t = t if ad.is_undefined_primal(operand) else None

  update_t = None
  if updates_is_linear:
    gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
    # Inserted window dims are sliced with size 1; the remaining operand
    # dims take the extent of the matching update window dim, in order.
    inserted_dims = dimension_numbers.inserted_window_dims
    window_dims = dimension_numbers.update_window_dims
    slice_sizes = []
    window_pos = 0
    for axis in range(len(t.shape)):
      if axis not in inserted_dims:
        slice_sizes.append(updates_shape[window_dims[window_pos]])
        window_pos += 1
        continue
      slice_sizes.append(1)
    update_t = gather(t, scatter_indices, dimension_numbers=gather_dnums,
                      slice_sizes=slice_sizes)
  return [operand_t, None, update_t]
|
2019-01-09 10:58:44 -05:00
|
|
|
|
2020-04-13 16:16:34 -04:00
|
|
|
def _scatter_mul_transpose_rule(t, operand, scatter_indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices):
  """Transpose of scatter-mul.

  If `operand` is the linear input, its cotangent is `scatter_mul` applied
  to `t` with the same indices and updates. If `updates` is the linear
  input, its cotangent is `t * operand` gathered at `scatter_indices` with
  the gather that mirrors this scatter. When `t` is a symbolic zero the
  `ad_util.Zero` class itself is returned (presumably a sentinel
  interpreted by the caller -- confirm against the transpose driver).

  Returns:
    `[operand_cotangent, None, updates_cotangent]`.
  """
  assert not ad.is_undefined_primal(scatter_indices)
  updates_is_linear = ad.is_undefined_primal(updates)
  updates_shape = updates.aval.shape if updates_is_linear else updates.shape
  if type(t) is ad_util.Zero:
    return ad_util.Zero

  operand_t = None
  update_t = None
  if ad.is_undefined_primal(operand):
    operand_t = scatter_mul(
        t, scatter_indices, updates, dimension_numbers=dimension_numbers,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)

  if updates_is_linear:
    gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
    # Inserted window dims are sliced with size 1; the remaining operand
    # dims take the extent of the matching update window dim, in order.
    inserted_dims = dimension_numbers.inserted_window_dims
    window_dims = dimension_numbers.update_window_dims
    slice_sizes = []
    window_pos = 0
    for axis in range(len(t.shape)):
      if axis not in inserted_dims:
        slice_sizes.append(updates_shape[window_dims[window_pos]])
        window_pos += 1
        continue
      slice_sizes.append(1)
    scaled_ct = mul(t, operand)
    update_t = gather(scaled_ct, scatter_indices,
                      dimension_numbers=gather_dnums, slice_sizes=slice_sizes)
  return [operand_t, None, update_t]
|
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
                           update_jaxpr, update_consts, dimension_numbers,
                           indices_are_sorted, unique_indices):
  """Batching rule shared by the scatter primitives.

  `scatter_op` is the user-level scatter function to re-invoke on the
  batched arguments. The operand and updates always end up with their batch
  dimension at axis 0 (created there if absent); the result is returned
  with batch dimension 0.
  """
  del update_jaxpr, update_consts  # Unused.
  operand, scatter_indices, updates = batched_args
  operand_bdim, scatter_indices_bdim, updates_bdim = batch_dims

  def shifted(dims):
    # Existing dimension numbers move right by one for the new leading axis.
    return tuple(np.add(1, dims))

  # Batch size, read from the first argument that is actually batched.
  size = next(arg.shape[bdim] for arg, bdim in zip(batched_args, batch_dims)
              if bdim is not None)

  # move the operand batch dim to the front if it is not None, otherwise create
  # it at the front (so that we can scatter into it)
  operand = batching.bdim_at_front(operand, operand_bdim, size)
  updates = batching.bdim_at_front(updates, updates_bdim, size)

  if scatter_indices_bdim is None:
    # Unbatched indices: the batch axis is one more update window dimension.
    dnums = ScatterDimensionNumbers(
        update_window_dims=(0,) + shifted(dimension_numbers.update_window_dims),
        inserted_window_dims=shifted(dimension_numbers.inserted_window_dims),
        scatter_dims_to_operand_dims=shifted(
            dimension_numbers.scatter_dims_to_operand_dims))
    batched_out = scatter_op(
        operand, scatter_indices, updates, dnums,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
    return batched_out, 0

  # see the third case in _gather_batching_rule for comparison and comments
  scatter_indices = batching.bdim_at_front(
      scatter_indices, scatter_indices_bdim, size)

  iota_shape = list(scatter_indices.shape)
  iota_shape[-1] = 1
  batch_counter = broadcasted_iota(
      scatter_indices.dtype, tuple(iota_shape), 0)
  scatter_indices = concatenate([batch_counter, scatter_indices],
                                len(iota_shape) - 1)

  dnums = ScatterDimensionNumbers(
      update_window_dims=shifted(dimension_numbers.update_window_dims),
      inserted_window_dims=(
          (0,) + shifted(dimension_numbers.inserted_window_dims)),
      scatter_dims_to_operand_dims=(
          (0,) + shifted(dimension_numbers.scatter_dims_to_operand_dims)))
  batched_out = scatter_op(
      operand, scatter_indices, updates, dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
  return batched_out, 0
|
2019-01-09 10:58:44 -05:00
|
|
|
|
2019-12-18 11:18:33 -08:00
|
|
|
# Build the `scatter-add` primitive and register its JVP, transpose and
# batching rules. The batching rule is the shared scatter rule specialised
# to re-invoke `scatter_add`.
scatter_add_p = standard_primitive(
    _scatter_shape_rule,
    _scatter_dtype_rule,
    'scatter-add',
    _scatter_translation_rule)
ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
batching.primitive_batchers[scatter_add_p] = partial(
    _scatter_batching_rule, scatter_add)
|
|
|
|
|
2020-04-13 16:16:34 -04:00
|
|
|
|
|
|
|
# Build the `scatter-mul` primitive; it shares its shape, dtype and
# translation rules with the other scatter primitives.
scatter_mul_p = standard_primitive(
    _scatter_shape_rule,
    _scatter_dtype_rule,
    'scatter-mul',
    _scatter_translation_rule)
|
|
|
|
|
2020-07-21 23:16:27 -07:00
|
|
|
def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers,
                         indices_are_sorted, unique_indices, **kw):
  """JVP of scatter-mul w.r.t. the updates.

  Scatter-adds the updates tangent `g` into zeros shaped like the operand
  `x` at indices `i`, then multiplies the result by `x`. The primal updates
  `y` and any extra params in `kw` are not used.
  """
  scattered_tangent = scatter_add(
      zeros_like_array(x), i, g, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
  return mul(x, scattered_tangent)
|
2020-04-13 16:16:34 -04:00
|
|
|
|
|
|
|
# Register the AD and batching rules for `scatter-mul`:
#  * operand JVP: re-bind scatter-mul with the operand tangent in place of
#    the operand;
#  * indices: no JVP rule (None);
#  * updates JVP: `_scatter_mul_jvp_rhs`.
ad.defjvp(
    scatter_mul_p,
    lambda g, x, i, y, **kw: scatter_mul_p.bind(g, i, y, **kw),
    None,
    _scatter_mul_jvp_rhs)
ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
batching.primitive_batchers[scatter_mul_p] = partial(
    _scatter_batching_rule, scatter_mul)
|
|
|
|
|
2020-05-19 07:06:32 +01:00
|
|
|
def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
                          update_consts, dimension_numbers,
                          indices_are_sorted, unique_indices):
  """JVP rule registered (via `partial`) for scatter-min and scatter-max.

  Args:
    scatter_op: the primitive to bind for the primal computation.
    primals: `(operand, scatter_indices, updates)`.
    tangents: tangents matching `primals`; the index tangent is ignored.
    update_jaxpr, update_consts: forwarded unchanged to `scatter_op.bind`.
    dimension_numbers: scatter dimension numbers of the primal op.
    indices_are_sorted, unique_indices: forwarded to the primal bind and to
      the final `scatter_add`.

  Returns:
    `(val_out, tangent_out)`; `tangent_out` is a symbolic zero when both the
    operand and the updates tangents are symbolic zeros.
  """
  operand, scatter_indices, updates = primals
  g_operand, _, g_updates = tangents

  scatter_dnums = dimension_numbers
  updates_shape = updates.shape

  val_out = scatter_op.bind(
      operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=scatter_dnums,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices)

  # Nothing to differentiate: propagate a symbolic zero.
  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    return val_out, ad_util.Zero.from_value(val_out)

  g_operand = ad.instantiate_zeros(g_operand)
  g_updates = ad.instantiate_zeros(g_updates)

  # Together, gather_dnums and slice_sizes describe the gather that inverts
  # the scatter described by scatter_dnums.
  gather_dnums = GatherDimensionNumbers(
      offset_dims=scatter_dnums.update_window_dims,
      collapsed_slice_dims=scatter_dnums.inserted_window_dims,
      start_index_map=scatter_dnums.scatter_dims_to_operand_dims)

  slice_sizes = []
  window_pos = 0  # next entry of update_window_dims to consume
  for operand_dim in range(len(operand.shape)):
    if operand_dim in scatter_dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      window_dim = scatter_dnums.update_window_dims[window_pos]
      slice_sizes.append(updates_shape[window_dim])
      window_pos += 1

  def gather_back(values):
    # Read `values` at the scattered locations, laid out like `updates`.
    return gather(values, scatter_indices, gather_dnums, np.array(slice_sizes))

  # For consistency with other max operations, when two or more values in
  # `updates` contend for the same location, the tangent at that location is
  # the average of the tangents of the winning values.
  initial_vals = gather_back(operand)
  target_vals = gather_back(val_out)

  successful_updates = (updates == target_vals)
  retained_values = (initial_vals == target_vals)

  # Per update: how many updates "won" at the location it was scattered to.
  num_updates = gather_back(
      scatter_add(_zeros(operand),
                  scatter_indices,
                  select(successful_updates, _ones(updates), _zeros(updates)),
                  scatter_dnums))

  # Per update: how many updates target the location it was scattered to.
  num_refs = gather_back(
      scatter_add(_zeros(operand),
                  scatter_indices,
                  _ones(updates),
                  scatter_dnums))

  updates_normalizer = select(retained_values,
                              1.0 / (num_updates + 1),
                              1.0 / num_updates)
  updates_coef = select(successful_updates,
                        updates_normalizer,
                        _zeros(updates))

  operand_normalizer = select(retained_values,
                              1.0 / (num_updates + 1),
                              _zeros(num_updates))
  operand_coef = (-1.0 + operand_normalizer) / num_refs

  # This can be simplified once scatter has a transpose rule.
  target_tangents = gather_back(g_operand)
  tangent_updates = (target_tangents * operand_coef +
                     g_updates * updates_coef)

  tangent_out = scatter_add(g_operand,
                            scatter_indices,
                            tangent_updates,
                            scatter_dnums,
                            indices_are_sorted=indices_are_sorted,
                            unique_indices=unique_indices)
  return val_out, tangent_out
|
|
|
|
|
2019-12-18 11:18:33 -08:00
|
|
|
# 'scatter-min' and 'scatter-max' share the scatter shape, dtype and
# translation rules, and both use _scatter_extremal_jvp, specialised on the
# primitive it has to bind.
scatter_min_p = standard_primitive(
    _scatter_shape_rule,
    _scatter_dtype_rule,
    'scatter-min',
    _scatter_translation_rule)
batching.primitive_batchers[scatter_min_p] = partial(
    _scatter_batching_rule, scatter_min)
ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p)

scatter_max_p = standard_primitive(
    _scatter_shape_rule,
    _scatter_dtype_rule,
    'scatter-max',
    _scatter_translation_rule)
batching.primitive_batchers[scatter_max_p] = partial(
    _scatter_batching_rule, scatter_max)
ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p)
|
2019-03-01 15:41:49 -05:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
                 dimension_numbers, indices_are_sorted, unique_indices):
  """JVP rule for the plain `scatter` primitive.

  Args:
    primals: `(operand, scatter_indices, updates)`.
    tangents: tangents matching `primals`; the index tangent is ignored.
    update_jaxpr, update_consts: forwarded unchanged to `scatter_p.bind`.
    dimension_numbers: scatter dimension numbers of the primal op.
    indices_are_sorted, unique_indices: forwarded to every scatter issued.

  Returns:
    `(val_out, tangent_out)`; `tangent_out` is a symbolic zero when both the
    operand and the updates tangents are symbolic zeros.
  """
  operand, scatter_indices, updates = primals
  g_operand, _, g_updates = tangents
  dnums = dimension_numbers

  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    val_out = scatter_p.bind(
        operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
        update_consts=update_consts, dimension_numbers=dnums,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
    return val_out, ad_util.Zero.from_value(val_out)

  g_operand = ad.instantiate_zeros(g_operand)
  g_updates = ad.instantiate_zeros(g_updates)

  # With overlapping indices it is unspecified which update "wins", so the
  # winner is recovered from the scatter itself:
  #   1. tag every update with a positive ID, stacking (value, id) along a new
  #      leading dimension because scatter has no notion of pairs;
  #   2. run the scatter on the stacked data and split values from IDs;
  #   3. gather the scattered IDs back (similar to _scatter_add_transpose) and
  #      compare with the original IDs to mask the tangent of `updates`;
  #   4. scatter-add the masked tangents, which avoids needing a transpose
  #      rule for `scatter`.

  # Step 1: unique positive IDs (iota + 1) for the updates, zeros for the
  # operand.
  operand_shape = operand.shape
  updates_shape = updates.shape
  updates_dtype = _dtype(updates)

  stacked_operand = reshape(operand, (1,) + operand_shape)
  stacked_operand = pad(
      stacked_operand, _zero(operand),
      ((0, 1, 0),) + tuple((0, 0, 0) for _ in operand_shape))

  # The dtype is given explicitly in case `updates_shape` is an empty tuple,
  # for which numpy would default to float64.
  ids_shape = np.array(updates_shape, dtype=np.int32)
  ids_shape[dnums.update_window_dims,] = 1
  num_ids = np.prod(ids_shape)
  update_ids = add(reshape(iota(updates_dtype, num_ids), ids_shape),
                   _ones(updates))

  # TODO(phawkins): there is a potential bug here if the number of updates
  # is large enough to overflow the number of mantissa bits in a float so IDs
  # end up colliding. We could also utilize the exponent and sign bits, with a
  # little more work.
  assert num_ids < (2 ** dtypes.finfo(updates_dtype).nmant)

  updates = reshape(updates, (1,) + updates_shape)
  reshaped_update_ids = reshape(update_ids, (1,) + updates_shape)
  updates_and_ids = concatenate((updates, reshaped_update_ids), 0)

  # Step 2: scatter values and IDs together; every dimension number moves up
  # by one to make room for the new leading (value, id) dimension.
  stacked_dnums = ScatterDimensionNumbers(
      update_window_dims=(0,) + tuple(d + 1 for d in dnums.update_window_dims),
      inserted_window_dims=tuple(d + 1 for d in dnums.inserted_window_dims),
      scatter_dims_to_operand_dims=tuple(
          d + 1 for d in dnums.scatter_dims_to_operand_dims))
  outputs = scatter_p.bind(
      stacked_operand, scatter_indices, updates_and_ids,
      update_jaxpr=update_jaxpr, update_consts=update_consts,
      dimension_numbers=stacked_dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
  val_out = index_in_dim(outputs, 0, keepdims=False)
  scattered_ids = index_in_dim(outputs, 1, keepdims=False)

  # Step 3a: the gather that "undoes" the scatter, applied to the IDs.
  gather_dnums = GatherDimensionNumbers(
      offset_dims=dnums.update_window_dims,
      collapsed_slice_dims=dnums.inserted_window_dims,
      start_index_map=dnums.scatter_dims_to_operand_dims)
  slice_sizes = []
  window_pos = 0  # next entry of update_window_dims to consume
  for id_dim in range(len(scattered_ids.shape)):
    if id_dim in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates_shape[dnums.update_window_dims[window_pos]])
      window_pos += 1
  gathered_update_ids = gather(scattered_ids, scatter_indices,
                               dimension_numbers=gather_dnums,
                               slice_sizes=slice_sizes)

  # Step 3b: drop tangent entries that do not correspond to a primal output.
  # An ID of zero means the operand value survived; a matching ID means that
  # particular update won.
  masked_g_operand = select(eq(scattered_ids, _zeros(scattered_ids)),
                            g_operand, _zeros(g_operand))
  masked_g_updates = select(eq(update_ids, gathered_update_ids),
                            g_updates, _zeros(g_updates))

  # Step 4: combine the surviving tangents with a scatter-add.
  tangent_out = scatter_add(masked_g_operand, scatter_indices, masked_g_updates,
                            dimension_numbers=dnums,
                            indices_are_sorted=indices_are_sorted,
                            unique_indices=unique_indices)
  return val_out, tangent_out
|
|
|
|
|
|
|
|
|
2019-12-18 11:18:33 -08:00
|
|
|
# The plain 'scatter' primitive, with its JVP and batching rules.
scatter_p = standard_primitive(
    _scatter_shape_rule,
    _scatter_dtype_rule,
    'scatter',
    _scatter_translation_rule)
ad.primitive_jvps[scatter_p] = _scatter_jvp
batching.primitive_batchers[scatter_p] = partial(
    _scatter_batching_rule, scatter)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-06-09 13:09:50 -07:00
|
|
|
def _reduce_shape_rule(operand, init_value, *, computation, jaxpr, consts,
|
|
|
|
dimensions):
|
2020-07-14 13:05:31 -07:00
|
|
|
return tuple(np.delete(operand.shape, dimensions))
|
2020-06-09 13:09:50 -07:00
|
|
|
|
|
|
|
def _reduce_translation_rule(c, operand, init_value, *, computation, jaxpr,
                             consts, dimensions):
  """Lower `reduce` to an XLA Reduce over `dimensions`.

  The reducer is the XLA computation built from `jaxpr`/`consts` by
  `_reduction_computation`; `computation` itself is not used here.
  """
  reducer = _reduction_computation(c, jaxpr, consts, init_value)
  return xops.Reduce(c, [operand], [init_value], reducer, dimensions)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr, consts,
                       dimensions):
  """Batching rule for `reduce` when only the operand is batched.

  Args:
    batched_args: `(operand, init_value)`.
    batch_dims: batch dimension of each argument, or None if unbatched.
    computation: reduction function forwarded to `reduce`.
    jaxpr, consts: unused here.
    dimensions: dimensions being reduced, numbered without the batch axis.

  Returns:
    `(result, result_batch_dim)`.

  Raises:
    NotImplementedError: if `init_value` is batched.
  """
  operand, init_value = batched_args
  operand_bdim, init_value_bdim = batch_dims
  if init_value_bdim is not None:
    raise NotImplementedError  # loop and stack

  assert operand_bdim is not None
  # Reduced dimensions at or after the batch axis shift up by one in the
  # batched operand.
  shifted_dims = [d + bool(d >= operand_bdim) for d in dimensions]
  # Each reduced dimension in front of the batch axis disappears from the
  # output, moving the batch axis one position towards the front.
  num_removed_before = int(np.sum(np.less(dimensions, operand_bdim)))
  result_bdim = operand_bdim - num_removed_before
  return reduce(operand, init_value, computation, shifted_dims), result_bdim
|
2018-12-14 08:42:02 -08:00
|
|
|
|
2019-12-18 11:18:33 -08:00
|
|
|
def _reduction_computation(c, jaxpr, consts, init_value):
  """Build the two-parameter XLA computation used as a reducer.

  Both parameters take the XLA shape of `init_value`, and the body is the
  translation of `jaxpr`.

  Raises:
    AssertionError: if `consts` is non-empty.
  """
  param_shape = c.get_shape(init_value)
  axis_env = xla.AxisEnv(1, (), (), None)  # no parallel primitives inside reductions
  builder = xla_bridge.make_computation_builder("reduction_computation")
  assert len(consts) == 0, "Reduction computations cannot have constants"
  lhs = xb.parameter(builder, 0, param_shape)
  rhs = xb.parameter(builder, 1, param_shape)
  result, = xla.jaxpr_subcomp(builder, jaxpr, None, axis_env, consts, '',
                              lhs, rhs)
  return builder.build(result)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-09-16 15:47:43 -07:00
|
|
|
def _masking_defreducer(prim, identity):
  """Register `_reducer_masking_rule`, bound to `prim` and `identity`, as the masking rule for `prim`."""
  rule = partial(_reducer_masking_rule, prim, identity)
  masking.masking_rules[prim] = rule
|
2019-09-03 17:09:27 -07:00
|
|
|
|
2019-09-16 15:47:43 -07:00
|
|
|
def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
                          axes, input_shape=None):
  """Masking rule for single-operand reductions.

  Elements beyond the logical extent along any reduced axis are replaced by
  the reduction's identity before `prim` is bound on the padded value.

  Args:
    prim: the reduction primitive to bind.
    identity: callable `(shape, dtype) -> array` producing the fill values.
    padded_vals: one-element sequence holding the padded operand.
    logical_shapes: one-element sequence holding the operand's logical shape.
    axes: the axes being reduced.
    input_shape: when not None, `prim` is bound with
      `input_shape=padded_shape`.

  Returns:
    The result of binding `prim` on the masked operand.
  """
  (padded_val,), (logical_shape,) = padded_vals, logical_shapes
  padded_shape = masking.padded_shape_as_value(padded_val.shape)
  masks = [broadcasted_iota(np.int32, padded_shape, i) < d
           for i, d in enumerate(logical_shape) if i in axes]
  if masks:
    mask = _reduce(operator.and_, masks)
    masked_val = select(mask, padded_val,
                        identity(padded_shape, padded_val.dtype))
  else:
    # No axis is reduced, so there is no padding to neutralise. Folding an
    # empty list of masks without an initial value would raise.
    masked_val = padded_val
  bind = prim.bind if input_shape is None else partial(prim.bind, input_shape=padded_shape)
  return bind(masked_val, axes=axes)
|
2019-09-03 17:09:27 -07:00
|
|
|
|
2020-06-09 13:09:50 -07:00
|
|
|
# The generic 'reduce' primitive; its result dtype is given by _input_dtype.
reduce_p = standard_primitive(
    _reduce_shape_rule,
    _input_dtype,
    'reduce',
    _reduce_translation_rule)
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-12-16 20:48:19 -05:00
|
|
|
def _reduce_number_dtype_rule(name, operand, *args, **kw):
  """Dtype rule for reductions that only accept numeric operands.

  Args:
    name: primitive name, used only in the error message.
    operand: the reduction operand; only its `dtype` is read.
    *args, **kw: ignored.

  Returns:
    The canonicalized operand dtype.

  Raises:
    TypeError: if the operand dtype is not a subtype of `np.number`.
  """
  if dtypes.issubdtype(operand.dtype, np.number):
    return dtypes.canonicalize_dtype(operand.dtype)
  template = ("{} does not accept dtype {}. Accepted dtypes are subtypes "
              "of number.")
  raise TypeError(template.format(name, np.dtype(operand.dtype).name))
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_sum_shape_rule(operand, *, axes):
  """Shape rule for reduce_sum; delegates to `_reduce_op_shape_rule`."""
  result_shape = _reduce_op_shape_rule(operand, axes=axes)
  return result_shape
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_sum_translation_rule(c, operand, *, axes):
  """Lower reduce_sum to an XLA Reduce with `add_p` as reducer and 0 as init value."""
  operand_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), operand_dtype)
  zero = xb.constant(c, np.array(0, operand_dtype))
  adder = xla.primitive_subcomputation(add_p, scalar_aval, scalar_aval)
  return xops.Reduce(c, [operand], [zero], adder, axes)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
  """Transpose of reduce_sum: broadcast the cotangent back to the input shape.

  The input shape is read from the aval of the undefined primal `operand`;
  the cotangent's dimensions map to the input dimensions that were not
  reduced.

  Returns:
    A one-element list holding the broadcast cotangent.
  """
  assert ad.is_undefined_primal(operand)
  input_shape = operand.aval.shape
  all_dims = np.arange(len(input_shape))
  surviving_dims = tuple(np.delete(all_dims, axes))
  result = broadcast_in_dim(cotangent, input_shape, surviving_dims)
  assert result.shape == input_shape
  return [result]
|
|
|
|
|
2019-12-16 20:48:19 -05:00
|
|
|
# 'reduce_sum': linear in its operand, so AD only needs the transpose rule.
reduce_sum_p = standard_primitive(
    _reduce_sum_shape_rule,
    partial(_reduce_number_dtype_rule, 'reduce_sum'),
    'reduce_sum',
    _reduce_sum_translation_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
# Masked-out elements are filled with 0 before the reduction.
_masking_defreducer(
    reduce_sum_p,
    lambda shape, dtype: np.broadcast_to(np.array(0, dtype), shape))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-06-03 22:40:48 +02:00
|
|
|
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
|
2020-06-30 21:18:46 -07:00
|
|
|
del input_shape # Unused.
|
|
|
|
if len(axes) != len(set(axes)):
|
|
|
|
raise ValueError(f"duplicate value in 'axes' of reduction: {axes}")
|
2020-07-14 13:05:31 -07:00
|
|
|
return tuple(np.delete(operand.shape, axes))
|
2019-05-05 14:31:46 -04:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_prod_translation_rule(c, operand, *, axes):
  """Lower reduce_prod to an XLA Reduce with `mul_p` as reducer and 1 as init value."""
  operand_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), operand_dtype)
  one = xb.constant(c, np.array(1, operand_dtype))
  multiplier = xla.primitive_subcomputation(mul_p, scalar_aval, scalar_aval)
  return xops.Reduce(c, [operand], [one], multiplier, axes)
|
2019-05-05 14:31:46 -04:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
  """JVP rule for reduce_prod.

  The reduced axes are moved to the front and flattened into one leading
  axis; the product along that axis is then rebuilt as a tree of pairwise
  multiplications, and `api.jvp` differentiates that tree.

  Args:
    primals: one-element sequence holding the operand.
    tangents: one-element sequence holding the operand tangent.
    axes: axes being reduced (any sequence of ints).

  Returns:
    The `(primal_out, tangent_out)` pair produced by `api.jvp`.
  """
  operand, = primals
  tangent, = tangents
  input_shape = np.array(operand.shape)

  n = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  # `tuple(axes)` keeps this working when `axes` is given as a list.
  permutation = tuple(axes) + tuple(non_axes)
  new_shape = (n,) + tuple(input_shape[non_axes])
  operand = reshape(operand, new_shape, permutation)
  tangent = reshape(tangent, new_shape, permutation)

  def _reduce_prod_tree(x, axis=0):
    """Reduce by repeatedly splitting the array and multiplying."""
    while x.shape[axis] > 1:
      n = x.shape[axis]
      n1 = (n + 1) // 2
      n2 = n - n1
      # Slice along `axis`, matching the padding and squeeze below.
      x1 = slice_in_dim(x, 0, n1, axis=axis)
      x2 = slice_in_dim(x, n1, None, axis=axis)
      if n2 != n1:
        # Odd length: pad the shorter half with a 1 so the halves line up.
        paddings = [(0, 0, 0)] * len(x.shape)
        paddings[axis] = (0, 1, 0)
        x2 = pad(x2, _const(x, 1), paddings)
      x = x1 * x2
    if x.shape[axis] == 0:
      # Empty product: every output element is 1.
      return full(input_shape[non_axes], _one(x))
    return squeeze(x, (axis,))

  return api.jvp(_reduce_prod_tree, (operand,), (tangent,))
|
2019-05-05 14:31:46 -04:00
|
|
|
|
|
|
|
|
2019-12-16 20:48:19 -05:00
|
|
|
# 'reduce_prod': numeric operands only, with a custom JVP rule.
reduce_prod_p = standard_primitive(
    _reduce_op_shape_rule,
    partial(_reduce_number_dtype_rule, 'reduce_prod'),
    'reduce_prod',
    _reduce_prod_translation_rule)
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
batching.defreducer(reduce_prod_p)
# Masked-out elements are filled with 1 before the reduction.
_masking_defreducer(
    reduce_prod_p,
    lambda shape, dtype: np.broadcast_to(np.array(1, dtype), shape))
|
2019-05-05 14:31:46 -04:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_chooser_shape_rule(operand, *, axes):
|
2020-07-14 13:05:31 -07:00
|
|
|
return tuple(np.delete(operand.shape, axes))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_chooser_translation_rule(prim, identity, c, operand, *, axes):
  """Lower a chooser reduction to an XLA Reduce.

  Args:
    prim: binary primitive used as the reducer.
    identity: callable mapping the operand dtype to the initial value.
    c: XLA computation builder.
    operand: XLA operand being reduced.
    axes: axes to reduce over.
  """
  operand_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), operand_dtype)
  init = xb.constant(c, identity(operand_dtype))
  reducer = xla.primitive_subcomputation(prim, scalar_aval, scalar_aval)
  return xops.Reduce(c, [operand], [init], reducer, axes)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
  """JVP rule registered for reduce_max and reduce_min.

  The output tangent is the mean of the input tangents over the positions
  where the operand equals the reduction result.

  Args:
    g: tangent of the operand.
    ans: primal result of the reduction.
    operand: primal operand.
    axes: axes that were reduced.
  """
  # TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
  # locations in a single pass (rather than comparing equality) and use a
  # gather, and/or even push along the chosen elements of g (b/112040122)
  keepdims_shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
  ans_keepdims = reshape(ans, keepdims_shape)
  # 1 where the operand attains the chosen value, 0 elsewhere, in g's dtype.
  location_indicators = convert_element_type(
      _eq_meet(operand, ans_keepdims), g.dtype)
  counts = _reduce_sum(location_indicators, axes)
  chosen_tangent_sum = _reduce_sum(mul(g, location_indicators), axes)
  return div(chosen_tangent_sum, counts)
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# 'reduce_max': a chooser reduction with max_p as reducer.
_reduce_max_translation_rule = partial(_reduce_chooser_translation_rule, max_p,
                                       _get_max_identity)
reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
                                  'reduce_max', _reduce_max_translation_rule)
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p)
# Fill masked-out elements with the same identity the translation rule uses as
# its initial value. A literal -inf cannot be converted to an integer dtype
# (numpy raises OverflowError) and becomes True for bool.
_masking_defreducer(
    reduce_max_p,
    lambda shape, dtype: np.broadcast_to(_get_max_identity(dtype), shape))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# 'reduce_min': a chooser reduction with min_p as reducer.
_reduce_min_translation_rule = partial(
    _reduce_chooser_translation_rule, min_p, _get_min_identity)
reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
                                  'reduce_min', _reduce_min_translation_rule)
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p)
# Fill masked-out elements with the same identity the translation rule uses as
# its initial value. A literal +inf cannot be converted to an integer dtype
# (numpy raises OverflowError).
_masking_defreducer(
    reduce_min_p,
    lambda shape, dtype: np.broadcast_to(_get_min_identity(dtype), shape))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-07-01 11:01:22 -04:00
|
|
|
|
|
|
|
def _argminmax_shape_rule(operand, *, axes, index_dtype):
|
|
|
|
axis, = axes
|
2020-07-14 13:05:31 -07:00
|
|
|
return tuple(np.delete(operand.shape, axis))
|
2020-07-01 11:01:22 -04:00
|
|
|
|
|
|
|
def _argminmax_dtype_rule(operand, *, axes, index_dtype):
|
|
|
|
return index_dtype
|
|
|
|
|
|
|
|
def _argminmax_translation_rule(value_comparator, identity,
                                c, operand, *, axes, index_dtype):
  """Lower argmin/argmax to a variadic XLA Reduce over (value, index) pairs.

  Args:
    value_comparator: XLA op builder returning true where the first value
      should be preferred over the second.
    identity: callable mapping the operand dtype to the initial value.
    c: XLA computation builder.
    operand: XLA operand.
    axes: one-element sequence with the axis to reduce.
    index_dtype: dtype of the returned indices.

  Returns:
    The index component of the reduction result.
  """
  (axis,) = axes
  operand_shape = c.get_shape(operand)
  operand_dtype = operand_shape.numpy_dtype()

  # Reducer over two (value, index) pairs. The Tuple op must be the last op
  # added, because `build()` is called without an explicit root.
  subc = xb.make_computation_builder("argminmax_comparator")
  scalar_value = xc.Shape.array_shape(operand_shape.xla_element_type(), ())
  scalar_index = xc.Shape.array_shape(index_dtype, ())
  x_value = xb.parameter(subc, 0, scalar_value)
  x_index = xb.parameter(subc, 1, scalar_index)
  y_value = xb.parameter(subc, 2, scalar_value)
  y_index = xb.parameter(subc, 3, scalar_index)
  prefer_x_value = value_comparator(x_value, y_value)
  # On equal values, prefer the smaller index.
  values_tie = xops.Eq(x_value, y_value)
  prefer_x_index = xops.Or(prefer_x_value,
                           xops.And(values_tie, xops.Lt(x_index, y_index)))
  xops.Tuple(subc, [xops.Select(prefer_x_value, x_value, y_value),
                    xops.Select(prefer_x_index, x_index, y_index)])
  comparator = subc.build()

  # Pair every operand element with its position along `axis`.
  iota_shape = xc.Shape.array_shape(index_dtype, operand_shape.dimensions())
  iota = xc.ops.Iota(c, iota_shape, axis)
  init_values = [xb.constant(c, identity(operand_dtype)),
                 xb.constant(c, np.array(0, index_dtype))]
  out = xops.Reduce(c, [operand, iota], init_values, comparator, [axis])
  return xops.GetTupleElement(out, 1)
|
|
|
|
|
|
|
|
def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
  """argmin/argmax built from ordinary reductions, registered for the 'gpu' backend.

  Computes the extremum with `op`, then takes the smallest index at which `a`
  equals it. Positions that do not match are replaced by the largest value of
  `index_dtype`, so they never win the final min-reduction.

  Args:
    op: reduction taking `(array, axes)`.
    a: the operand.
    axes: one-element sequence with the axis to reduce.
    index_dtype: dtype of the returned indices.
  """
  (axis,) = axes
  positions = tie_in(a, broadcasted_iota(index_dtype, a.shape, axis))
  sentinel = np.array(dtypes.iinfo(index_dtype).max, dtype=index_dtype)
  sentinel = broadcast(tie_in(a, sentinel), a.shape)
  extremum = expand_dims(op(a, (axis,)), (axis,))
  candidates = select(eq(a, extremum), positions, sentinel)
  return _reduce_min(candidates, (axis,))
|
|
|
|
|
|
|
|
# Default translations: argmin compares with Lt, argmax with Gt.
_argmin_translation_rule = partial(
    _argminmax_translation_rule, xops.Lt, _get_min_identity)
_argmax_translation_rule = partial(
    _argminmax_translation_rule, xops.Gt, _get_max_identity)

argmin_p = standard_primitive(
    _argminmax_shape_rule,
    _argminmax_dtype_rule,
    'argmin',
    _argmin_translation_rule)
batching.defreducer(argmin_p)
ad.defjvp_zero(argmin_p)
# The 'gpu' backend gets a lowering built from ordinary reductions.
xla.backend_specific_translations['gpu'][argmin_p] = xla.lower_fun(
    partial(_argminmax_gpu_translation_rule, _reduce_min),
    multiple_results=False)

argmax_p = standard_primitive(
    _argminmax_shape_rule,
    _argminmax_dtype_rule,
    'argmax',
    _argmax_translation_rule)
batching.defreducer(argmax_p)
ad.defjvp_zero(argmax_p)
xla.backend_specific_translations['gpu'][argmax_p] = xla.lower_fun(
    partial(_argminmax_gpu_translation_rule, _reduce_max),
    multiple_results=False)
|
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_logical_shape_rule(operand, *, axes):
|
2020-07-14 13:05:31 -07:00
|
|
|
if operand.dtype != np.bool_:
|
2018-12-14 08:42:02 -08:00
|
|
|
msg = "logical reduction requires operand dtype bool, got {}."
|
|
|
|
raise TypeError(msg.format(operand.dtype))
|
2020-07-14 13:05:31 -07:00
|
|
|
return tuple(np.delete(operand.shape, axes))
|
2018-12-14 08:42:02 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_logical_translation_rule(prim, identity, c, operand, *, axes):
  """Lower a boolean reduction to an XLA Reduce.

  `prim` is the binary reducer and `identity(np.bool_)` supplies the initial
  value.
  """
  bool_scalar = ShapedArray((), np.bool_)
  init = xb.constant(c, identity(np.bool_))
  reducer = xla.primitive_subcomputation(prim, bool_scalar, bool_scalar)
  return xops.Reduce(c, [operand], [init], reducer, axes)
|
2018-12-14 08:42:02 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# 'reduce_or': reducer or_p, initial value taken from _get_max_identity.
_reduce_or_translation_rule = partial(
    _reduce_logical_translation_rule, or_p, _get_max_identity)
reduce_or_p = standard_primitive(
    _reduce_logical_shape_rule,
    _fixed_dtype(np.bool_),
    'reduce_or',
    _reduce_or_translation_rule)
batching.defreducer(reduce_or_p)

# 'reduce_and': reducer and_p, initial value taken from _get_min_identity.
_reduce_and_translation_rule = partial(
    _reduce_logical_translation_rule, and_p, _get_min_identity)
reduce_and_p = standard_primitive(
    _reduce_logical_shape_rule,
    _fixed_dtype(np.bool_),
    'reduce_and',
    _reduce_and_translation_rule)
batching.defreducer(reduce_and_p)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts,
                              window_dimensions, window_strides, padding,
                              base_dilation, window_dilation):
  """Shape rule for the generic reduce_window primitive.

  Checks that `operand` and `init_value` share a dtype, then defers to the
  common reduce-window shape rule. `jaxpr` and `consts` are accepted but not
  used here.

  Raises:
    TypeError: if the two dtypes differ.
  """
  operand_dtype, init_dtype = operand.dtype, init_value.dtype
  if operand_dtype != init_dtype:
    msg = ("reduce_window got inconsistent dtypes for operand and init_value: "
           " got operand dtype {} and init_value dtype {}.")
    raise TypeError(msg.format(operand_dtype, init_dtype))
  return _common_reduce_window_shape_rule(
      operand, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts,
                                    window_dimensions, window_strides, padding,
                                    base_dilation, window_dilation):
  """Lower generic reduce_window to XLA ReduceWindowWithGeneralPadding.

  The reducer computation is built from `jaxpr`/`consts` via
  `_reduction_computation`.
  """
  reducer = _reduction_computation(c, jaxpr, consts, init_value)
  return xops.ReduceWindowWithGeneralPadding(
      operand, init_value, reducer, window_dimensions, window_strides,
      base_dilation, window_dilation, padding)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-13 10:22:26 -04:00
|
|
|
def _generic_reduce_window_batch_rule(
    batched_args, batch_dims, *, jaxpr, consts, window_dimensions,
    window_strides, padding, base_dilation, window_dilation):
  """Batching rule for the generic reduce_window primitive.

  Only the operand may be batched. The work is delegated to
  `_reduce_window_batch_rule` with a closure that re-binds `reduce_window_p`
  using the unbatched init value.

  Raises:
    NotImplementedError: if the init value carries a batch dimension.
  """
  operand, init = batched_args
  operand_bdim, init_bdim = batch_dims
  if init_bdim is not None:
    raise NotImplementedError("reduce_window batching is not implemented for "
                              "initial values")

  def bind_with_init(x, window_dimensions, window_strides, padding,
                     base_dilation, window_dilation):
    # Positional window parameters are forwarded as keyword params of the bind.
    params = dict(
        jaxpr=jaxpr, consts=consts, window_dimensions=window_dimensions,
        window_strides=window_strides, padding=padding,
        base_dilation=base_dilation, window_dilation=window_dilation)
    return reduce_window_p.bind(x, init, **params)

  return _reduce_window_batch_rule(
      bind_with_init, (operand,), (operand_bdim,),
      window_dimensions=window_dimensions, window_strides=window_strides,
      padding=padding, base_dilation=base_dilation,
      window_dilation=window_dilation)
|
2019-07-13 10:22:26 -04:00
|
|
|
|
|
|
|
|
2019-12-18 11:18:33 -08:00
|
|
|
# Generic reduce_window primitive and its batching rule.
reduce_window_p = standard_primitive(
    _reduce_window_shape_rule, _input_dtype, 'reduce_window',
    _reduce_window_translation_rule)
batching.primitive_batchers[reduce_window_p] = _generic_reduce_window_batch_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_window_sum_shape_rule(operand, *, window_dimensions, window_strides,
                                  padding, base_dilation, window_dilation):
  """Shape rule for reduce_window_sum.

  Raises:
    TypeError: if the operand dtype is not a subdtype of `np.number`.
  """
  is_numeric = dtypes.issubdtype(operand.dtype, np.number)
  if not is_numeric:
    msg = "operand to reduce_window_sum must have a number dtype, got {}"
    raise TypeError(msg.format(np.dtype(operand.dtype).name))
  return _common_reduce_window_shape_rule(
      operand, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_window_sum_translation_rule(c, operand, *, window_dimensions,
                                        window_strides, padding, base_dilation,
                                        window_dilation):
  """Lower reduce_window_sum to an XLA windowed reduction.

  The reducer is `add_p` on scalars of the operand dtype and the initial value
  is a zero of that dtype.
  """
  operand_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), operand_dtype)
  zero = xb.constant(c, np.array(0, operand_dtype))
  adder = xla.primitive_subcomputation(add_p, scalar_aval, scalar_aval)
  return xops.ReduceWindowWithGeneralPadding(
      operand, zero, adder, window_dimensions, window_strides, base_dilation,
      window_dilation, padding)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_window_sum_transpose_rule(cotangent, operand, *, window_dimensions,
                                      window_strides, padding, base_dilation,
                                      window_dilation):
  """Transpose rule for reduce_window_sum.

  Pads the cotangent (edge padding from `_conv_general_vjp_lhs_padding`,
  interior padding of `stride - 1`) and runs a windowed sum over it. Returns a
  one-element list holding a result with the operand's shape.
  """
  assert ad.is_undefined_primal(operand)
  input_shape = operand.aval.shape
  rank = len(input_shape)
  vjp_pads = _conv_general_vjp_lhs_padding(
      input_shape, window_dimensions, window_strides, cotangent.shape, padding,
      base_dilation, window_dilation)
  unit_dilation = [1] * rank
  padding_config = []
  for (lo, hi), stride in zip(vjp_pads, window_strides):
    padding_config.append((lo, hi, stride - 1))
  padded_cotangent = pad(cotangent, _zero(cotangent), padding_config)
  no_padding = [(0, 0)] * rank
  # NOTE(review): the forward `base_dilation` is passed in the strides
  # position of `_reduce_window_sum` and the base dilation is reset to ones;
  # this looks deliberate for the transpose -- confirm.
  result = _reduce_window_sum(
      padded_cotangent, window_dimensions, base_dilation, no_padding,
      base_dilation=unit_dilation, window_dilation=window_dilation)
  assert result.shape == input_shape, (result.shape, input_shape)
  return [result]
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *,
                              window_dimensions, window_strides, padding,
                              base_dilation, window_dilation):
  """Batching rule shared by the single-operand reduce_window primitives.

  When the operand is batched, a trivial entry (window 1, stride 1, padding
  (0, 0), dilations 1) is inserted at the batch position of every window
  parameter. `reduce_window` is then called positionally with the adjusted
  parameters. Returns `(result, bdim)`.
  """
  operand, = batched_args
  bdim, = bdims

  if bdim is not None:
    def with_batch_entry(values, trivial):
      # Splice a no-op entry in at the batch dimension.
      return values[:bdim] + (trivial,) + values[bdim:]

    window_dimensions = with_batch_entry(window_dimensions, 1)
    window_strides = with_batch_entry(window_strides, 1)
    padding = with_batch_entry(padding, (0, 0))
    base_dilation = with_batch_entry(base_dilation, 1)
    window_dilation = with_batch_entry(window_dilation, 1)

  result = reduce_window(operand, window_dimensions, window_strides, padding,
                         base_dilation, window_dilation)
  return result, bdim
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# reduce_window_sum primitive: registered as linear with a transpose rule, and
# batched through the shared single-operand rule.
reduce_window_sum_p = standard_primitive(
    _reduce_window_sum_shape_rule, _input_dtype, 'reduce_window_sum',
    _reduce_window_sum_translation_rule)
ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule)
batching.primitive_batchers[reduce_window_sum_p] = partial(
    _reduce_window_batch_rule, _reduce_window_sum)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _reduce_window_chooser_translation_rule(
    prim, identity, c, operand, *, window_dimensions, window_strides, padding,
    base_dilation, window_dilation):
  """Lower a windowed reduction whose reducer is the binary primitive `prim`.

  `identity` is called with the operand dtype to produce the initial value.
  """
  operand_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), operand_dtype)
  init_value = xb.constant(c, identity(operand_dtype))
  reducer = xla.primitive_subcomputation(prim, scalar_aval, scalar_aval)
  return xops.ReduceWindowWithGeneralPadding(
      operand, init_value, reducer, window_dimensions, window_strides,
      base_dilation, window_dilation, padding)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _reduce_window_chooser_jvp_rule(prim, g, operand, *, window_dimensions,
                                    window_strides, padding, base_dilation,
                                    window_dilation):
  """JVP for windowed max/min, expressed via `_select_and_gather_add`.

  `prim` must be `max_p` (selects with `ge_p`) or `min_p` (selects with
  `le_p`).
  """
  assert prim is max_p or prim is min_p
  if prim is max_p:
    select_prim = ge_p
  else:
    select_prim = le_p
  return _select_and_gather_add(
      g, operand, select_prim, window_dimensions, window_strides, padding,
      base_dilation, window_dilation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _common_reduce_window_shape_rule(operand, window_dimensions,
                                     window_strides, padding, base_dilation,
                                     window_dilation):
  """Validate reduce-window parameters and compute the output shape.

  Each window parameter is checked with `_check_shapelike`, then its length is
  compared against the operand rank / `window_dimensions`.

  Raises:
    TypeError: if any of the lengths disagree.
  """
  shapelike_params = (
      ("window_dimensions", window_dimensions),
      ("window_strides", window_strides),
      ("base_dilation", base_dilation),
      ("window_dilation", window_dilation),
  )
  for param_name, param_value in shapelike_params:
    _check_shapelike("reduce_window", param_name, param_value)

  if operand.ndim != len(window_dimensions):
    msg = ("reduce_window got the wrong number of window_dimensions for "
           "operand: got operand shape {} with window_dimensions {}.")
    raise TypeError(msg.format(operand.shape, window_dimensions))

  # Every remaining parameter needs one entry per window dimension; the checks
  # run in the listed order.
  length_checks = (
      (window_strides,
       "reduce_window got inconsistent window_strides and "
       "window_dimensions: got window_strides {} and window_dimensions {}."),
      (base_dilation,
       "reduce_window got inconsistent base_dilation and "
       "window_dimensions: got base_dilation {} and window_dimensions {}."),
      (window_dilation,
       "reduce_window got inconsistent window_dilation and "
       "window_dimensions: got window_dilation {} and window_dimensions "
       "{}."),
  )
  for param_value, msg in length_checks:
    if len(param_value) != len(window_dimensions):
      raise TypeError(msg.format(param_value, window_dimensions))

  return reduce_window_shape_tuple(operand.shape, window_dimensions,
                                   window_strides, padding, base_dilation,
                                   window_dilation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
  """Compute the output shape of a windowed reduction.

  Args:
    operand_shape: shape of the operand.
    window_dimensions: window size per dimension.
    window_strides: window stride per dimension.
    padding: iterable of `(low, high)` padding pairs, one per dimension.
    base_dilation: optional; when given, `operand_shape` is first passed
      through `_dilate_shape`.
    window_dilation: optional; when given, `window_dimensions` is first passed
      through `_dilate_shape`.

  Returns:
    A tuple with `(padded - window) // stride + 1` per dimension.

  Raises:
    TypeError: if `padding` is empty but `operand_shape` is not.
  """
  if base_dilation is not None:
    operand_shape = _dilate_shape(operand_shape, base_dilation)
  if window_dilation is not None:
    window_dimensions = _dilate_shape(window_dimensions, window_dilation)
  padding = tuple(padding)
  if padding:
    pad_totals = np.add(*zip(*padding))
  elif len(operand_shape) == 0:
    # Rank-0 operand: `zip(*())` yields nothing, so `np.add` would be called
    # with no arguments; there is simply no padding to add.
    pad_totals = np.zeros((0,), dtype=int)
  else:
    raise TypeError(
        "reduce_window got empty padding for a non-scalar operand: "
        "got operand shape {}.".format(tuple(operand_shape)))
  operand_padded = np.add(operand_shape, pad_totals)
  t = np.floor_divide(
      np.subtract(operand_padded, window_dimensions), window_strides) + 1
  return tuple(t)
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# reduce_window_max: windowed reduction with `max_p`; the initial value comes
# from `_get_max_identity`.
_reduce_window_max_translation_rule = partial(
    _reduce_window_chooser_translation_rule, max_p, _get_max_identity)
reduce_window_max_p = standard_primitive(
    _common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max',
    _reduce_window_max_translation_rule)
ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, max_p))
batching.primitive_batchers[reduce_window_max_p] = partial(
    _reduce_window_batch_rule, _reduce_window_max)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
# reduce_window_min: windowed reduction with `min_p`; the initial value comes
# from `_get_min_identity`.
_reduce_window_min_translation_rule = partial(
    _reduce_window_chooser_translation_rule, min_p, _get_min_identity)
reduce_window_min_p = standard_primitive(
    _common_reduce_window_shape_rule, _input_dtype, 'reduce_window_min',
    _reduce_window_min_translation_rule)
ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, min_p))

# The named batch rule is kept as a module-level name and is the object that
# gets registered.
_reduce_window_min_batch_rule = partial(
    _reduce_window_batch_rule, _reduce_window_min)
batching.primitive_batchers[reduce_window_min_p] = _reduce_window_min_batch_rule
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_and_scatter_shape_rule(
    operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
  """Shape rule for select_and_scatter: the output has the operand's shape.

  Raises:
    TypeError: if `window_dimensions` and `window_strides` differ in length.
  """
  for param_name, param_value in (("window_dimensions", window_dimensions),
                                  ("window_strides", window_strides)):
    _check_shapelike("select_and_scatter", param_name, param_value)
  if len(window_dimensions) != len(window_strides):
    msg = ("select_and_scatter got inconsistent window_strides and "
           "window_dimensions: got window_strides {} and window_dimensions {}.")
    raise TypeError(msg.format(window_strides, window_dimensions))
  return operand.shape
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_and_scatter_translation(
    c, operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
  """Lower select_and_scatter to XLA SelectAndScatterWithGeneralPadding.

  Both the select and scatter computations are built from their jaxprs with
  `_reduction_computation`.
  """
  select_computation = _reduction_computation(
      c, select_jaxpr, select_consts, init_value)
  scatter_computation = _reduction_computation(
      c, scatter_jaxpr, scatter_consts, init_value)
  return xops.SelectAndScatterWithGeneralPadding(
      operand, select_computation, window_dimensions, window_strides, padding,
      source, init_value, scatter_computation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-12-18 11:18:33 -08:00
|
|
|
# Generic select_and_scatter primitive.
select_and_scatter_p = standard_primitive(
    _select_and_scatter_shape_rule, _input_dtype, 'select_and_scatter',
    _select_and_scatter_translation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_and_scatter_add_shape_rule(
    source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """Shape rule for select_and_scatter_add: the output has the operand's shape."""
  output_shape = operand.shape
  return output_shape
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_and_scatter_add_translation(
    c, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """Lower select_and_scatter_add to XLA SelectAndScatterWithGeneralPadding.

  Selection uses `select_prim`, scattering uses `add_p`, and the initial value
  is a zero of the operand dtype.
  """
  operand_dtype = c.get_shape(operand).numpy_dtype()
  scalar_aval = ShapedArray((), operand_dtype)
  select_computation = xla.primitive_subcomputation(
      select_prim, scalar_aval, scalar_aval)
  scatter_computation = xla.primitive_subcomputation(
      add_p, scalar_aval, scalar_aval)
  zero = xb.constant(c, np.array(0, operand_dtype))
  return xops.SelectAndScatterWithGeneralPadding(
      operand, select_computation, window_dimensions, window_strides, padding,
      source, zero, scatter_computation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-04-20 17:06:35 -07:00
|
|
|
def _select_and_scatter_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding):
  """JVP rule for select_and_scatter_add.

  Only the tangent of `source` contributes; the operand tangent is ignored.
  A symbolic-zero source tangent yields a symbolic-zero output tangent.
  """
  source, operand = primals
  g_source, _ = tangents

  def scatter_add(src):
    # Same primitive application for the primal and the tangent.
    return _select_and_scatter_add(
        src, operand, select_prim, window_dimensions, window_strides,
        padding)

  val_out = scatter_add(source)
  if type(g_source) is ad_util.Zero:
    return val_out, ad_util.Zero.from_value(val_out)
  return val_out, scatter_add(g_source)
|
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_and_scatter_add_transpose(
    t, source, operand, *, select_prim, window_dimensions, window_strides,
    padding):
  """Transpose rule for select_and_scatter_add with respect to `source`.

  Returns `[source_cotangent, None]`; the operand receives no cotangent.
  """
  assert ad.is_undefined_primal(source) and not ad.is_undefined_primal(operand)
  # No dilation: all-ones for both base and window dilation.
  unit_dilation = (1,) * len(window_dimensions)
  source_t = _select_and_gather_add(
      t, operand, select_prim, window_dimensions, window_strides, padding,
      unit_dilation, unit_dilation)
  return [source_t, None]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
  """Batching rule for select_and_scatter_add, implemented by unrolling.

  Each batched argument has its batch axis moved to the front. The primitive
  is then applied once per batch element and the per-element results are
  stacked along a new leading axis. Returns `(outputs, 0)`.
  """
  source, operand = batched_args
  s_bdims, o_bdims = batch_dims
  source_batched = s_bdims is not None
  operand_batched = o_bdims is not None
  if not (source_batched or operand_batched):
    # Nothing is batched: mirrors the implicit fall-through (returns None).
    return None

  #TODO(#212): use a map construct instead of unrolling.
  if source_batched:
    source = batching.moveaxis(source, s_bdims, 0)
  if operand_batched:
    operand = batching.moveaxis(operand, o_bdims, 0)

  # Pair up per-example slices; an unbatched argument is reused for every
  # example.
  if source_batched and operand_batched:
    pairs = zip(source, operand)
  elif source_batched:
    pairs = ((s, operand) for s in source)
  else:
    pairs = ((source, o) for o in operand)

  outputs = [_select_and_scatter_add(s, o, **kwargs) for s, o in pairs]
  outputs = [broadcast(out, (1,)) for out in outputs]
  outputs = concatenate(outputs, 0)
  return outputs, 0
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# select_and_scatter_add primitive with its transpose, JVP and batching rules.
select_and_scatter_add_p = standard_primitive(
    _select_and_scatter_add_shape_rule, _input_dtype, 'select_and_scatter_add',
    _select_and_scatter_add_translation)
ad.primitive_transposes[select_and_scatter_add_p] = (
    _select_and_scatter_add_transpose)
ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp
batching.primitive_batchers[select_and_scatter_add_p] = (
    _select_and_scatter_add_batch_rule)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-01-28 14:29:17 -05:00
|
|
|
def _select_and_gather_add_shape_rule(
    tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Shape rule for select_and_gather_add.

  Requires `tangents` and `operand` to have equal shapes, then defers to the
  common reduce-window shape rule.

  Raises:
    TypeError: if the two shapes differ.
  """
  tangents_shape, operand_shape = tangents.shape, operand.shape
  if tangents_shape != operand_shape:
    msg = ("select_and_gather_add tangents and operand shapes must match, "
           "got {} and {}.")
    raise TypeError(msg.format(tangents_shape, operand_shape))
  return _common_reduce_window_shape_rule(
      operand, window_dimensions, window_strides, padding, base_dilation,
      window_dilation)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-01-28 14:29:17 -05:00
|
|
|
# Bit width -> NumPy unsigned integer type of that width.
_UINT_DTYPES = {16: np.uint16, 32: np.uint32, 64: np.uint64}

# Bit width -> NumPy signed integer type of that width.
_INT_DTYPES = {16: np.int16, 32: np.int32, 64: np.int64}
|
2019-01-28 14:29:17 -05:00
|
|
|
|
|
|
|
def _select_and_gather_add_translation(
|
2020-04-07 09:38:10 -04:00
|
|
|
c, tangents, operand, *, select_prim, window_dimensions, window_strides,
|
2020-07-20 17:27:24 -04:00
|
|
|
padding, base_dilation, window_dilation, max_bits=64):
|
2020-05-11 17:43:55 -04:00
|
|
|
shape = c.get_shape(operand)
|
2019-07-01 22:26:36 -04:00
|
|
|
dtype = shape.numpy_dtype()
|
2019-06-28 20:27:10 -04:00
|
|
|
etype = shape.xla_element_type()
|
2019-11-15 10:02:51 -05:00
|
|
|
nbits = dtypes.finfo(dtype).bits
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
assert nbits <= max_bits
|
|
|
|
double_word_reduction = nbits * 2 <= max_bits
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
const = lambda c, dtype, x: xb.constant(c, np.array(x, dtype=dtype),
|
2020-04-23 18:30:47 -04:00
|
|
|
canonicalize_types=False)
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
if double_word_reduction:
|
2019-11-21 11:52:58 -05:00
|
|
|
# TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
|
2019-10-09 15:05:54 -04:00
|
|
|
# we implement a pair-wise ReduceWindow by packing two k-bit values into
|
|
|
|
# 2k-bit unsigned integer using bit tricks.
|
2019-07-01 22:26:36 -04:00
|
|
|
word_dtype = _UINT_DTYPES[nbits]
|
|
|
|
double_word_dtype = _UINT_DTYPES[nbits * 2]
|
2019-08-04 12:34:03 -04:00
|
|
|
word_type = xla_client.dtype_to_etype(word_dtype)
|
|
|
|
double_word_type = xla_client.dtype_to_etype(double_word_dtype)
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
# Packs two values into a tuple.
|
|
|
|
def pack(a, b):
|
2020-04-23 18:30:47 -04:00
|
|
|
a = xops.BitcastConvertType(a, word_type)
|
|
|
|
b = xops.BitcastConvertType(b, word_type)
|
|
|
|
a = xops.ConvertElementType(a, double_word_type)
|
|
|
|
b = xops.ConvertElementType(b, double_word_type)
|
|
|
|
a = xops.ShiftLeft(a, const(c, double_word_dtype, nbits))
|
|
|
|
return xops.Or(a, b)
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
# Unpacks the first element of a tuple.
|
|
|
|
def fst(c, t):
|
2020-04-23 18:30:47 -04:00
|
|
|
st = xops.ShiftRightLogical(t, const(c, double_word_dtype, nbits))
|
|
|
|
return xops.BitcastConvertType(xops.ConvertElementType(st, word_type), etype)
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
# Unpacks the second element of a tuple.
|
|
|
|
def snd(t):
|
2020-04-23 18:30:47 -04:00
|
|
|
return xops.BitcastConvertType(xops.ConvertElementType(t, word_type), etype)
|
2019-06-28 20:27:10 -04:00
|
|
|
|
2019-07-01 22:26:36 -04:00
|
|
|
else:
|
2019-07-02 13:23:05 -04:00
|
|
|
# The double-word trick above only works if we have a sufficiently large
|
|
|
|
# type. As an alternative, we can pack two half words into a single word,
|
|
|
|
# at the cost of precision.
|
|
|
|
# TODO(b/73062247): add support for tuple reductions and remove this case.
|
2019-06-28 20:27:10 -04:00
|
|
|
warnings.warn("Using reduced precision for gradient of reduce-window "
|
2019-07-02 13:23:05 -04:00
|
|
|
"min/max operator to work around missing XLA support for "
|
|
|
|
"pair-reductions. This is likely from a second or "
|
|
|
|
"higher derivative of a max-pooling operation.")
|
2019-07-02 11:34:49 -04:00
|
|
|
r_nbits = nbits // 2
|
|
|
|
# Drop/round the bottom mantissa bits.
|
2019-11-15 10:02:51 -05:00
|
|
|
nexp = dtypes.finfo(dtype).nexp
|
2019-07-02 11:34:49 -04:00
|
|
|
nmant = r_nbits - nexp - 1
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
double_word_dtype = word_dtype = _UINT_DTYPES[nbits]
|
2019-08-04 12:34:03 -04:00
|
|
|
word_type = xla_client.dtype_to_etype(word_dtype)
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
# Packs two values into a tuple.
|
|
|
|
def pack(a, b):
|
2020-04-23 18:30:47 -04:00
|
|
|
a = xops.ReducePrecision(a, exponent_bits=nexp, mantissa_bits=nmant)
|
|
|
|
b = xops.ReducePrecision(b, exponent_bits=nexp, mantissa_bits=nmant)
|
|
|
|
a = xops.BitcastConvertType(a, word_type)
|
|
|
|
b = xops.BitcastConvertType(b, word_type)
|
|
|
|
b = xops.ShiftRightLogical(b, const(c, word_dtype, r_nbits))
|
|
|
|
return xops.Or(a, b)
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
# Unpacks the first element of a tuple.
|
|
|
|
def fst(c, t):
|
2020-04-23 18:30:47 -04:00
|
|
|
st = xops.And(t, const(c, word_dtype, ((1 << r_nbits) - 1) << r_nbits))
|
|
|
|
return xops.BitcastConvertType(st, etype)
|
2019-07-01 22:26:36 -04:00
|
|
|
|
|
|
|
# Unpacks the second element of a tuple.
|
|
|
|
def snd(t):
|
2020-04-23 18:30:47 -04:00
|
|
|
return xops.BitcastConvertType(xops.ShiftLeft(t, const(c, word_dtype, r_nbits)),
|
2019-07-01 22:26:36 -04:00
|
|
|
etype)
|
|
|
|
|
|
|
|
def reducer():
|
|
|
|
c = xla_bridge.make_computation_builder("select_and_gather_pair_reducer")
|
2020-04-23 18:30:47 -04:00
|
|
|
x = xb.parameter(c, 0,
|
2020-07-14 13:05:31 -07:00
|
|
|
xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
|
2020-04-23 18:30:47 -04:00
|
|
|
y = xb.parameter(c, 1,
|
2020-07-14 13:05:31 -07:00
|
|
|
xla_client.Shape.array_shape(np.dtype(double_word_dtype), ()))
|
2019-07-01 22:26:36 -04:00
|
|
|
assert select_prim is ge_p or select_prim is le_p
|
2020-04-23 18:30:47 -04:00
|
|
|
which = xops.Ge if select_prim is ge_p else xops.Le
|
|
|
|
xops.Select(which(fst(c, x), fst(c, y)), x, y)
|
2020-05-11 17:43:55 -04:00
|
|
|
return c.build()
|
2019-07-01 22:26:36 -04:00
|
|
|
|
2019-01-28 14:29:17 -05:00
|
|
|
|
2019-11-21 11:52:58 -05:00
|
|
|
assert select_prim is ge_p or select_prim is le_p, select_prim
|
2020-07-14 13:05:31 -07:00
|
|
|
init = -np.inf if select_prim is ge_p else np.inf
|
2020-04-23 18:30:47 -04:00
|
|
|
out = xops.ReduceWindowWithGeneralPadding(
|
|
|
|
pack(operand, tangents), pack(const(c, dtype, init), const(c, dtype, 0)),
|
2020-07-20 17:27:24 -04:00
|
|
|
reducer(), window_dimensions, window_strides, base_dilation,
|
|
|
|
window_dilation, padding)
|
2019-07-01 22:26:36 -04:00
|
|
|
return snd(out)
|
2019-01-28 14:29:17 -05:00
|
|
|
|
2019-04-20 17:06:56 -07:00
|
|
|
def _select_and_gather_add_jvp(
    primals, tangents, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """JVP rule for `select_and_gather_add`.

  Args:
    primals: pair `(source, operand)`.
    tangents: pair of tangents matching `primals`; only the tangent of
      `source` is used.
    select_prim, window_dimensions, window_strides, padding, base_dilation,
      window_dilation: forwarded unchanged to `_select_and_gather_add`.

  Returns:
    `(primal_out, tangent_out)`. `tangent_out` is a symbolic zero when the
    tangent of `source` is a symbolic zero; otherwise it is the same gather
    applied to that tangent with the same `operand`.
  """
  source, operand = primals
  source_dot, _ = tangents  # the tangent of `operand` is never used
  window_args = (select_prim, window_dimensions, window_strides, padding,
                 base_dilation, window_dilation)
  primal_out = _select_and_gather_add(source, operand, *window_args)
  if type(source_dot) is ad_util.Zero:
    return primal_out, ad_util.Zero.from_value(primal_out)
  return primal_out, _select_and_gather_add(source_dot, operand, *window_args)
|
|
|
|
|
2019-01-28 14:29:17 -05:00
|
|
|
def _select_and_gather_add_transpose(
    t, tangents, operand, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Transpose rule for `select_and_gather_add` w.r.t. its `tangents` input.

  Args:
    t: cotangent of the primitive's output.
    tangents: the linear input; must be an undefined primal.
    operand: the (defined) operand that drives the selection.
    select_prim: `ge_p` or `le_p`.
    window_dimensions, window_strides, padding: forwarded to
      `_select_and_scatter_add`.
    base_dilation: per-dimension base dilation; handled by interior-padding
      `operand` before scattering and striding the result afterwards.
    window_dilation: must be all ones.

  Returns:
    `[cotangent_for_tangents, None]`.

  Raises:
    NotImplementedError: if any window dilation differs from 1.
  """
  assert select_prim in (le_p, ge_p)
  assert ad.is_undefined_primal(tangents) and not ad.is_undefined_primal(operand)
  if any(d != 1 for d in window_dilation):
    msg = ("VJP not implemented for select_and_gather (MaxPool) with window "
           "dilation, got window_dilation={}.")
    raise NotImplementedError(msg.format(window_dilation))
  has_base_dilation = any(d != 1 for d in base_dilation)
  if has_base_dilation:
    # Materialize the base dilation as interior padding. The fill value comes
    # from the `_get_max_identity` / `_get_min_identity` helper matching
    # `select_prim` (presumably so padded entries are never selected).
    select_identity = (_get_max_identity if select_prim is ge_p
                       else _get_min_identity)
    operand = pad(operand, select_identity(operand.dtype),
                  tuple((0, 0, d - 1) for d in base_dilation))
  result = _select_and_scatter_add(t, operand, select_prim, window_dimensions,
                                   window_strides, padding)
  if has_base_dilation:
    # Undo the interior padding on the scattered cotangent by striding over
    # it. The previous code sliced `operand` here, which discarded `result`
    # and returned primal values instead of the cotangent.
    result = slice(result, (0,) * len(result.shape), result.shape,
                   base_dilation)
  return [result, None]
|
|
|
|
|
2019-11-21 11:52:58 -05:00
|
|
|
def _select_and_gather_add_batching_rule(
    batched_args, batch_dims, *, select_prim, window_dimensions, window_strides,
    padding, base_dilation, window_dilation):
  """Batching rule for `select_and_gather_add`.

  Both arguments are moved (or broadcast) so the batch dimension is leading,
  and every window parameter gets a trivial leading entry for that dimension.

  Returns:
    `(out, 0)`: the batched result and its batch dimension.
  """
  t, x = batched_args
  t_bdim, x_bdim = batch_dims
  # The batch size is read from the first argument that is actually batched.
  first_arg, first_dim = next(
      (arg, d) for arg, d in zip(batched_args, batch_dims) if d is not None)
  size = first_arg.shape[first_dim]
  t = batching.bdim_at_front(t, t_bdim, size)
  x = batching.bdim_at_front(x, x_bdim, size)
  out = _select_and_gather_add(
      t, x, select_prim,
      (1,) + window_dimensions,
      (1,) + window_strides,
      ((0, 0),) + padding,
      (1,) + base_dilation,
      (1,) + window_dilation)
  return out, 0
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# Primitive definition plus its JVP, transpose and batching rules.
select_and_gather_add_p = standard_primitive(
    _select_and_gather_add_shape_rule,
    _input_dtype,
    'select_and_gather_add',
    _select_and_gather_add_translation)
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = (
    _select_and_gather_add_transpose)
batching.primitive_batchers[select_and_gather_add_p] = (
    _select_and_gather_add_batching_rule)
# The TPU translation is the same rule with `max_bits` fixed to 32.
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
    _select_and_gather_add_translation, max_bits=32)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-06 11:22:01 -04:00
|
|
|
# Parallel prefix-scan. See:
|
|
|
|
# https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
|
|
|
|
# and
|
|
|
|
# Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.", Technical Report
|
|
|
|
# CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.
|
|
|
|
#
|
|
|
|
# Unlike the Blelloch algorithm, we use an out-of-place algorithm that uses 2n
|
|
|
|
# space. This is somewhat wasteful if we are interested only in the output of
|
|
|
|
# the forward pass, but more memory-efficient if we intend to differentiate
|
|
|
|
# through the implementation of the scan.
|
|
|
|
def _prescan_power_of_two(x, axis: int, op: Callable, unit):
  """Exclusive scan along `axis`, whose length must be a power of two.

  Args:
    x: array to scan.
    axis: dimension to scan over.
    op: binary combining function.
    unit: value used to seed the downsweep.

  Returns:
    `(scanned, total)`: the exclusive scan, and the reduction of the whole
    axis (length 1 along `axis`).
  """
  n = x.shape[axis]
  assert n != 0 and n & (n - 1) == 0, "n must be a power of 2"

  # Upsweep: repeatedly combine even/odd neighbours, remembering the even
  # halves for the downsweep.
  saved_evens = []
  for _ in range(n.bit_length() - 1):
    evens = slice_in_dim(x, 0, None, stride=2, axis=axis)
    saved_evens.append(evens)
    odds = slice_in_dim(x, 1, None, stride=2, axis=axis)
    x = op(evens, odds)
  total = x

  # Downsweep: double the length at each level by interleaving.
  x = full_like(total, unit)
  rank = len(x.shape)
  into_odd_slots = [(0, 0, 0)] * rank
  into_odd_slots[axis] = (1, 0, 1)
  into_even_slots = [(0, 0, 0)] * rank
  into_even_slots[axis] = (0, 1, 1)
  for evens in reversed(saved_evens):
    zero = _const(x, 0)
    spread_even = pad(x, zero, into_even_slots)
    spread_odd = pad(x, zero, into_odd_slots)
    evens_in_odd = pad(evens, zero, into_odd_slots)
    x = spread_even + op(spread_odd, evens_in_odd)

  return x, total
|
|
|
|
|
|
|
|
|
2020-06-30 11:36:27 -07:00
|
|
|
def _parallel_prefix_scan(x, axis: int, op: Callable, unit: Any):
|
2020-07-14 13:05:31 -07:00
|
|
|
if np.issubdtype(x.dtype, np.integer):
|
|
|
|
if np.isposinf(unit):
|
|
|
|
unit = np.iinfo(x.dtype).max
|
|
|
|
elif np.isneginf(unit):
|
|
|
|
unit = np.iinfo(x.dtype).min
|
2020-04-06 11:22:01 -04:00
|
|
|
n = x.shape[axis]
|
|
|
|
if n == 0:
|
|
|
|
return x
|
|
|
|
# Pads to the next largest power of two
|
|
|
|
nbits = n.bit_length()
|
|
|
|
if n == (1 << (nbits - 1)):
|
|
|
|
nbits -= 1
|
|
|
|
padding = [(0, 0, 0)] * len(x.shape)
|
|
|
|
padding[axis] = (0, (1 << nbits) - n, 0)
|
|
|
|
x = pad(x, _const(x, unit), padding)
|
|
|
|
x, total = _prescan_power_of_two(x, axis, op, unit)
|
|
|
|
return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
|
|
|
|
|
2020-06-30 11:36:27 -07:00
|
|
|
# Prefix scans specialized to each cumulative reduction: (operator, unit).
_cumsum_prefix_scan = partial(_parallel_prefix_scan, unit=0, op=add)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, unit=1, op=mul)
_cummax_prefix_scan = partial(_parallel_prefix_scan, unit=-np.inf, op=max)
_cummin_prefix_scan = partial(_parallel_prefix_scan, unit=np.inf, op=min)
|
2020-04-06 15:14:22 -04:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _cumred_shape_rule(x, *, axis: int):
|
2020-04-06 11:22:01 -04:00
|
|
|
if axis < 0 or axis >= x.ndim:
|
|
|
|
raise ValueError(
|
|
|
|
"axis {} is out of bounds for array of shape {}".format(axis, x.shape))
|
|
|
|
return x.shape
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _cumsum_transpose_rule(t, *, axis: int):
  """Transpose of cumsum: a cumsum taken in reverse order along `axis`."""
  flipped = rev(t, (axis,))
  summed = cumsum(flipped, axis=axis)
  return [rev(summed, (axis,))]
|
2020-04-06 11:22:01 -04:00
|
|
|
|
2020-06-28 20:39:20 +01:00
|
|
|
def _cumulative_jvp_rule(primals, tangents, *, axis: int,
                         prefix_scan: Callable):
  """JVP rule shared by the cumulative reductions.

  Differentiation always goes through the parallel prefix scan
  implementation, irrespective of backend, because reduce_window is not
  arbitrarily differentiable.
  """
  scan_along_axis = partial(prefix_scan, axis=axis)
  return api.jvp(scan_along_axis, primals, tangents)
|
2020-06-28 18:21:09 +01:00
|
|
|
|
|
|
|
|
2020-06-28 20:21:35 +01:00
|
|
|
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
|
2020-04-07 09:38:10 -04:00
|
|
|
axis: int):
|
2020-04-06 11:22:01 -04:00
|
|
|
# On TPU, an implementation using reduce_window is handled specially by the
|
2020-04-06 15:14:22 -04:00
|
|
|
# compiler and is efficient. On other backends, it is O(n^2).
|
2020-04-06 11:22:01 -04:00
|
|
|
n = x.shape[axis]
|
2020-04-06 12:33:55 -04:00
|
|
|
if n == 0:
|
|
|
|
return x
|
2020-07-13 09:49:52 -04:00
|
|
|
padding = [(0, 0)] * x.ndim
|
|
|
|
padding[axis] = (n - 1, 0)
|
2020-04-06 11:22:01 -04:00
|
|
|
strides = [1] * x.ndim
|
|
|
|
window_dims = [1] * x.ndim
|
|
|
|
window_dims[axis] = n
|
2020-07-13 09:49:52 -04:00
|
|
|
return window_reduce(x, window_dims, strides, padding)
|
2020-04-06 11:22:01 -04:00
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
|
2020-04-06 11:22:01 -04:00
|
|
|
operand, = batched_args
|
|
|
|
bdim, = batch_dims
|
|
|
|
axis = axis if axis < bdim else axis + 1
|
|
|
|
return prim.bind(operand, axis=axis), bdim
|
|
|
|
|
|
|
|
|
|
|
|
# cumsum is linear, so it is registered with `ad.deflinear` rather than
# through `_cumulative_reduction_primitive`.
cumsum_p = standard_primitive(
    _cumred_shape_rule,
    partial(_reduce_number_dtype_rule, "cumsum"),
    'cumsum',
    xla.lower_fun(_cumsum_prefix_scan, multiple_results=False))
ad.deflinear(cumsum_p, _cumsum_transpose_rule)
_cumsum_tpu_rule = partial(_cumred_tpu_translation_rule, _reduce_window_sum)
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
    _cumsum_tpu_rule, multiple_results=False)
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
|
|
|
|
|
|
|
|
|
2020-06-28 20:39:20 +01:00
|
|
|
def _cumulative_reduction_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn):
  """Create and register a cumulative-reduction primitive.

  Args:
    name: primitive name, also passed to the dtype rule.
    prefix_scan_fn: default lowering.
    jvp_rule: stored in `ad.primitive_jvps`.
    reduce_window_fn: windowed reduction used for the TPU lowering.

  Returns:
    The new primitive.
  """
  default_translation = xla.lower_fun(prefix_scan_fn, multiple_results=False)
  reducer_p = standard_primitive(
      _cumred_shape_rule, partial(_reduce_number_dtype_rule, name),
      name, default_translation)
  ad.primitive_jvps[reducer_p] = jvp_rule
  tpu_rule = partial(_cumred_tpu_translation_rule, reduce_window_fn)
  xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
      tpu_rule, multiple_results=False)
  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
  return reducer_p
|
2020-04-06 11:22:01 -04:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-28 20:39:20 +01:00
|
|
|
# Non-linear cumulative reductions: each differentiates through its own
# prefix scan.
cumprod_p = _cumulative_reduction_primitive(
    "cumprod",
    _cumprod_prefix_scan,
    partial(_cumulative_jvp_rule, prefix_scan=_cumprod_prefix_scan),
    _reduce_window_prod)

cummax_p = _cumulative_reduction_primitive(
    "cummax",
    _cummax_prefix_scan,
    partial(_cumulative_jvp_rule, prefix_scan=_cummax_prefix_scan),
    _reduce_window_max)

cummin_p = _cumulative_reduction_primitive(
    "cummin",
    _cummin_prefix_scan,
    partial(_cumulative_jvp_rule, prefix_scan=_cummin_prefix_scan),
    _reduce_window_min)
|
2020-06-28 18:21:09 +01:00
|
|
|
|
|
|
|
|
2020-05-14 11:13:15 -04:00
|
|
|
def _sort_abstract_eval(*args, **kwargs):
  """Abstract evaluation for sort: one output per operand.

  Raises:
    TypeError: if the operands do not all have the same shape.
  """
  args = tuple(raise_to_shaped(arg) for arg in args)
  for other in args[1:]:
    if other.shape != args[0].shape:
      shapes = " ".join(str(a.shape) for a in args)
      raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
  return args
|
|
|
|
|
2020-05-14 19:17:44 -04:00
|
|
|
|
|
|
|
def _float_to_int_for_sort(x):
  """Map floats to same-width signed ints whose order is a total order.

  If f is a float, and
    x = bit_cast<int32>(f);
    y = x < 0 ? int32_max - x : x;
  then y is ordered as an int32 such that finite values have the obvious
  order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
  and end of the ordering.

  To keep -x from overflowing, int32_max - x is computed as unsigned and
  then converted back to signed. bfloat16 inputs are first widened to
  float32.
  """
  if x.dtype == dtypes.bfloat16:
    x = convert_element_type(x, np.float32)
  width = np.finfo(x).bits
  int_dtype = _INT_DTYPES[width]
  uint_dtype = _UINT_DTYPES[width]

  as_signed = bitcast_convert_type(x, int_dtype)
  as_unsigned = bitcast_convert_type(x, uint_dtype)
  int_max = uint_dtype(np.iinfo(int_dtype).max)
  # Used for negative bit patterns only: reverses their order.
  mirrored = bitcast_convert_type(sub(int_max, as_unsigned), int_dtype)
  is_negative = lt(as_signed, _zero(as_signed))
  return select(is_negative, mirrored, as_signed)
|
|
|
|
|
2020-07-09 20:05:19 -07:00
|
|
|
# Default comparator that sorts the operands lexicographically on the
|
|
|
|
# first `num_keys` arguments.
|
2020-05-14 19:17:44 -04:00
|
|
|
# For floating point types, a total order is created where
|
|
|
|
# -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN.
|
|
|
|
# For complex types, the (real, imag) pairs are sorted lexicographically
|
|
|
|
# (following NumPy's semantics).
|
2020-07-09 20:05:19 -07:00
|
|
|
# This code adds complex-number support and lexicographic ordering to the algorithm from:
|
2020-05-14 19:17:44 -04:00
|
|
|
# https://github.com/tensorflow/tensorflow/blob/ba43780830f09da72081fe5061c436f1c6203a92/tensorflow/compiler/xla/client/lib/comparators.h#L33
|
2020-07-09 20:05:19 -07:00
|
|
|
def _sort_lt_comparator(*operands, num_keys=1):
  """Lexicographic less-than over the first `num_keys` operand pairs.

  Args:
    *operands: interleaved `(x0, y0, x1, y1, ...)` values to compare.
    num_keys: number of leading pairs taking part in the comparison.

  Returns:
    The predicate "x sorts before y", or None when `num_keys` is 0.
  """
  assert len(operands) >= 2 and len(operands) % 2 == 0, operands
  assert len(operands) // 2 >= num_keys, (operands, num_keys)

  # Normalize every key pair to directly comparable values. Complex keys
  # become two pairs (real, then imaginary); floats are mapped to ints.
  key_pairs = []
  lhs_ops = operands[:2 * num_keys:2]
  rhs_ops = operands[1:2 * num_keys:2]
  for lhs, rhs in zip(lhs_ops, rhs_ops):
    assert lhs.dtype == rhs.dtype, (lhs.dtype, rhs.dtype)
    if np.issubdtype(lhs.dtype, np.complexfloating):
      lhs_parts = [_float_to_int_for_sort(real(lhs)),
                   _float_to_int_for_sort(imag(lhs))]
      rhs_parts = [_float_to_int_for_sort(real(rhs)),
                   _float_to_int_for_sort(imag(rhs))]
      key_pairs.extend(zip(lhs_parts, rhs_parts))
    elif np.issubdtype(lhs.dtype, np.floating):
      lhs_key = _float_to_int_for_sort(lhs)
      rhs_key = _float_to_int_for_sort(rhs)
      key_pairs.append((lhs_key, rhs_key))
    else:
      key_pairs.append((lhs, rhs))

  # Fold from the least significant key outwards:
  #   less = (xk < yk) | ((xk == yk) & less_on_later_keys)
  less = None
  for xk, yk in reversed(key_pairs):
    strictly_less = lt(xk, yk)
    if less is None:
      less = strictly_less
    else:
      less = bitwise_or(strictly_less, bitwise_and(eq(xk, yk), less))
  return less
|
|
|
|
|
2020-07-09 20:05:19 -07:00
|
|
|
|
2020-07-10 09:58:35 -07:00
|
|
|
def _sort_translation_rule(c, *operands, dimension, is_stable, num_keys):
  """XLA translation for sort using `_sort_lt_comparator` as comparator.

  The comparator sub-computation takes two scalar parameters per operand.
  A single-operand sort has its result wrapped in a tuple.
  """
  subc = xla_bridge.make_computation_builder("sort_lt_comparator")
  params = []
  for i, x in enumerate(operands):
    etype = c.get_shape(x).xla_element_type()
    for j in range(2):
      scalar_shape = xc.Shape.array_shape(etype, ())
      params.append(xb.parameter(subc, 2 * i + j, scalar_shape))
  lowered_lt = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
                             multiple_results=False)
  comparator = subc.build(lowered_lt(subc, *params))
  out = xops.Sort(c, operands, dimension=dimension, is_stable=is_stable,
                  comparator=comparator)
  if len(operands) == 1:
    return xops.Tuple(c, [out])
  return out
|
|
|
|
|
2020-07-10 09:58:35 -07:00
|
|
|
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
  """JVP rule for sort.

  An iota along `dimension` is sorted together with the primals; the sorted
  iota then gives the indices used to reorder each tangent.

  Returns:
    `(sorted_primals, tangents_out)`, both tuples. Symbolic-zero tangents
    are passed through unchanged.
  """
  shape = primals[0].shape
  int32_max = np.iinfo(np.int32).max
  iotas = [
      broadcasted_iota(np.int32 if size < int32_max else np.int64, shape, dim)
      for dim, size in enumerate(shape)]
  sorted_all = sort_p.bind(*(primals + (iotas[dimension],)),
                           dimension=dimension, is_stable=is_stable,
                           num_keys=num_keys)
  permutation = sorted_all[-1]
  idx = tuple(permutation if i == dimension else iota_i
              for i, iota_i in enumerate(iotas))
  tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx]
                       for t in tangents)
  return tuple(sorted_all[:-1]), tangents_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-07-10 09:58:35 -07:00
|
|
|
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
  """Batching rule for sort.

  Every operand is given the batch dimension of the first batched operand:
  unbatched ones are broadcast to that operand's shape, batched ones have
  their batch axis moved.

  Returns:
    `(outputs, bdims)` with one batch dimension entry per operand.
  """
  prototype_arg, new_bdim = next(
      (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)

  def align(arg, bdim):
    if bdim is not None:
      return batching.moveaxis(arg, bdim, new_bdim)
    # Map the unbatched operand's dims onto all dims except the batch one.
    dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
    return broadcast_in_dim(arg, prototype_arg.shape, dims)

  new_args = [align(arg, bdim) for arg, bdim in zip(batched_args, batch_dims)]
  # The sort dimension shifts right when the batch dim lands at or before it.
  new_dimension = dimension + (new_bdim <= dimension)
  outs = sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable,
                     num_keys=num_keys)
  return outs, (new_bdim,) * len(new_args)
|
2020-05-14 11:13:15 -04:00
|
|
|
|
|
|
|
|
|
|
|
# sort returns one output per operand, so it is built directly on Primitive.
sort_p = Primitive('sort')
sort_p.multiple_results = True
sort_p.def_impl(partial(xla.apply_primitive, sort_p))
sort_p.def_abstract_eval(_sort_abstract_eval)
xla.translations[sort_p] = _sort_translation_rule
ad.primitive_jvps[sort_p] = _sort_jvp
batching.primitive_batchers[sort_p] = _sort_batch_rule
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _top_k_abstract_eval(operand, *, k):
|
2020-02-24 07:31:46 -08:00
|
|
|
if k < 0:
|
|
|
|
raise ValueError("k argument to top_k must be nonnegative, got {}".format(k))
|
2020-02-20 17:15:25 -08:00
|
|
|
if len(operand.shape) == 0:
|
|
|
|
raise TypeError("top_k operand must have >= 1 dimension, got {}"
|
|
|
|
.format(operand.shape))
|
2020-02-24 07:31:46 -08:00
|
|
|
shape = list(operand.shape)
|
|
|
|
if shape[-1] < k:
|
|
|
|
msg = "k argument to top_k must be no larger than minor dimension; {} vs {}"
|
|
|
|
raise ValueError(msg.format(k, shape))
|
|
|
|
shape[-1] = k
|
|
|
|
return (ShapedArray(shape, operand.dtype),
|
2020-07-14 13:05:31 -07:00
|
|
|
ShapedArray(shape, np.dtype(np.int32)))
|
2020-02-20 17:15:25 -08:00
|
|
|
|
2020-04-19 11:49:15 -07:00
|
|
|
def _top_k_jvp(primals, tangents, *, k):
  """JVP rule for top_k.

  The tangent of the values output is the input tangent gathered at the
  selected indices; the indices output always gets a symbolic-zero tangent.
  """
  [operand] = primals
  [tangent] = tangents
  primals_out = top_k(operand, k)
  indices_tangent = ad_util.Zero.from_value(primals_out[1])
  if type(tangent) is ad_util.Zero:
    values_tangent = ad_util.Zero.from_value(primals_out[0])
    return primals_out, (values_tangent, indices_tangent)

  _, k_idxs = primals_out
  idx_shape = k_idxs.shape
  rank = len(idx_shape)
  gather_index_shape = idx_shape + (1,)

  # Full gather coordinates: an iota for each leading dimension, followed
  # by the top-k indices for the last dimension.
  coords = []
  for i, dim_size in enumerate(idx_shape[:-1]):
    _iota = iota(k_idxs.dtype, dim_size)
    if not config.omnistaging_enabled:
      _iota = tie_in(operand, _iota)
    coords.append(broadcast_in_dim(_iota, gather_index_shape, (i,)))
  coords.append(reshape(k_idxs, gather_index_shape))
  gather_indices = concatenate(coords, dimension=rank)

  all_dims = tuple(range(rank))
  dnums = GatherDimensionNumbers(
      offset_dims=(),
      collapsed_slice_dims=all_dims,
      start_index_map=all_dims)
  values_tangent = gather(tangent, gather_indices, dnums, (1,) * rank)
  return primals_out, (values_tangent, indices_tangent)
|
2020-04-19 11:49:15 -07:00
|
|
|
|
|
|
|
def _top_k_batch_rule(batched_args, batch_dims, *, k):
  """Batching rule for top_k.

  When the batch dimension is the last one, the last two axes are swapped
  before calling `top_k` and swapped back afterwards.
  """
  [operand] = batched_args
  [bdim] = batch_dims
  out_bdims = (bdim, bdim)
  if bdim != operand.ndim - 1:
    return top_k(operand, k=k), out_bdims

  # Swap the last two axes; the same permutation also undoes the swap.
  perm = np.arange(operand.ndim)
  perm[bdim - 1], perm[bdim] = perm[bdim], perm[bdim - 1]
  values, indices = top_k(transpose(operand, perm), k=k)
  return (transpose(values, perm), transpose(indices, perm)), out_bdims
|
|
|
|
|
2020-02-20 17:15:25 -08:00
|
|
|
# top_k yields two outputs (values, indices).
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
xla.translations[top_k_p] = partial(standard_translate, 'top_k')
ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
|
2020-02-20 17:15:25 -08:00
|
|
|
|
2020-05-27 13:57:47 +00:00
|
|
|
def _tie_in_transpose_rule(t, x, y):
  """Transpose rule for tie_in: zero cotangent for `x`, `t` for `y`."""
  if ad.is_undefined_primal(x):
    x_cotangent = ad_util.Zero(x.aval)
  else:
    x_cotangent = ad_util.Zero.from_value(x)
  return [x_cotangent, t]
|
2018-12-13 07:24:14 -08:00
|
|
|
|
2019-02-01 13:42:16 -05:00
|
|
|
def _tie_in_batch_rule(batched_args, batch_dims):
  """Batching rule for tie_in: the output keeps the batch dim of `y`."""
  x, y = batched_args
  _, y_bdim = batch_dims
  return tie_in(x, y), y_bdim
|
|
|
|
|
2020-06-01 13:24:40 -07:00
|
|
|
def _tie_in_impl(x, y):
  """Eager tie_in: check both arguments are valid JAX types, return `y`."""
  for value in (x, y):
    core.check_valid_jaxtype(value)
  return y
|
|
|
|
|
2018-12-13 07:24:14 -08:00
|
|
|
# tie_in(x, y) evaluates to y; every rule below passes the second argument
# through.
tie_in_p = Primitive('tie_in')
tie_in_p.def_impl(_tie_in_impl)
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
xla.translations[tie_in_p] = lambda c, x, y: y
ad.deflinear2(tie_in_p, _tie_in_transpose_rule)
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
|
2019-02-01 13:42:16 -05:00
|
|
|
|
2019-02-23 20:34:14 -08:00
|
|
|
|
2019-02-04 16:31:24 -08:00
|
|
|
def _stop_gradient_jvp_rule(primals, tangents):
  """JVP rule for stop_gradient: the tangent is always a symbolic zero.

  stop_gradient is re-applied to the primal; without that only one autodiff
  tracer would be peeled off.
  """
  [x] = primals
  primal_out = stop_gradient(x)
  return primal_out, ad_util.Zero.from_value(x)
|
2019-01-31 22:08:51 -08:00
|
|
|
|
2019-02-04 16:31:24 -08:00
|
|
|
def _stop_gradient_batch_rule(batched_args, batch_dims):
  """Batching rule for stop_gradient: the batch dimension is unchanged."""
  [x] = batched_args
  [dim] = batch_dims
  return stop_gradient(x), dim
|
|
|
|
|
2020-04-23 13:12:24 -07:00
|
|
|
# stop_gradient_p is defined in ad_util; its rules are registered here.
_stop_gradient_p = ad_util.stop_gradient_p
ad.primitive_jvps[_stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[_stop_gradient_p] = _stop_gradient_batch_rule
del _stop_gradient_p
|
|
|
|
|
2019-01-31 22:08:51 -08:00
|
|
|
|
2019-10-09 15:05:54 -04:00
|
|
|
def create_token(x):
  """Creates an XLA token value with no preconditions for sequencing effects.

  Experimental.

  Args:
    x: a dummy argument used to tie the CreateToken operator into a trace. The
       value of `x` is ignored.
  """
  # `x` only ties the operator into a trace; its value is not used.
  token = create_token_p.bind(x)
  return token
|
|
|
|
|
|
|
|
# The abstract value is always `abstract_token`; the operand is ignored by
# both the abstract-eval and the translation rule.
create_token_p = Primitive("create_token")
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
create_token_p.def_abstract_eval(lambda _: abstract_token)
xla.translations[create_token_p] = lambda c, _: xops.CreateToken(c)
|
2019-10-09 15:05:54 -04:00
|
|
|
|
|
|
|
def after_all(*operands):
  """Merges one or more XLA token values. Experimental.

  Wraps the XLA AfterAll operator."""
  merged = after_all_p.bind(*operands)
  return merged
|
|
|
|
|
|
|
|
def _after_all_abstract_eval(*operands):
  """Abstract evaluation for after_all.

  Raises:
    TypeError: if any operand is not `abstract_token`.
  """
  for operand in operands:
    if operand is not abstract_token:
      raise TypeError("Arguments to after_all must be tokens")
  return abstract_token
|
|
|
|
|
|
|
|
|
|
|
|
def _after_all_translation_rule(c, *operands):
  """XLA translation for after_all: emits a single AfterAll op."""
  merged = xops.AfterAll(c, operands)
  return merged
|
2019-10-09 15:05:54 -04:00
|
|
|
|
|
|
|
# after_all has no autodiff or batching rules registered here.
after_all_p = Primitive("after_all")
after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
after_all_p.def_abstract_eval(_after_all_abstract_eval)
xla.translations[after_all_p] = _after_all_translation_rule
|
|
|
|
|
|
|
|
|
2020-06-01 12:35:18 -07:00
|
|
|
def infeed(token, shape=None, partitions=None):
  """Consumes an infeed value of `shape` from the host. Experimental.

  `token` is used to sequence infeed and outfeed effects.
  `partitions` may be specified inside a `sharded_jit` function.

  Returns:
    A pair `(value, token)`, where `value` has the pytree structure of
    `shape`.

  Raises:
    TypeError: if a leaf of `shape` is not a `ShapedArray`.
    ValueError: if `partitions` is given and is not exactly a `tuple`.
  """
  flat_shapes, treedef = pytree.flatten(shape)
  for leaf in flat_shapes:
    if isinstance(leaf, ShapedArray):
      continue
    raise TypeError("shape argument to infeed must be a pytree of "
                    "ShapedArray values, got {}".format(leaf))
  if partitions is not None:
    # We specifically use type() to raise an error for PartitionSpecs.
    if type(partitions) is not tuple:  # pylint: disable=unidiomatic-typecheck
      raise ValueError(f"'partitions' argument to infeed should be a tuple, "
                       f"got {partitions}")
    # Always replicate token.
    partitions = partitions + (None,)
  xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes),
                               partitions=partitions)
  values = treedef.unflatten(xs_and_token[:-1])
  out_token = xs_and_token[-1]
  return (values, out_token)
|
|
|
|
|
2020-06-01 12:35:18 -07:00
|
|
|
def _infeed_abstract_eval(token, *, shapes, partitions):
  """Abstract evaluation for infeed: `shapes` followed by a token.

  Raises:
    TypeError: if `token` is not `abstract_token`.
  """
  if token is abstract_token:
    return shapes + (abstract_token,)
  raise TypeError("First argument to infeed must be a token")
|
|
|
|
|
|
|
|
|
2020-06-01 12:35:18 -07:00
|
|
|
def _infeed_translation_rule(c, token, *, shapes, partitions):
  """XLA translation for infeed.

  Builds an InfeedWithToken op (under `xb.with_sharding` when `partitions`
  is truthy) and returns a flat tuple of the values followed by the token.
  """
  xla_shapes = tuple(
      xla.aval_to_xla_shape(aval).with_major_to_minor_layout_if_absent()
      for aval in shapes)
  build_infeed = partial(xops.InfeedWithToken, token,
                         xla_client.Shape.tuple_shape(xla_shapes))
  if not partitions:
    # Note that infeed will default to replication if inside a sharded
    # computation and no sharding is specified.
    xs_and_token = build_infeed()
  else:
    xs_and_token = xb.with_sharding(c, partitions, build_infeed)
  xs = xops.GetTupleElement(xs_and_token, 0)
  token = xops.GetTupleElement(xs_and_token, 1)
  outs = []
  for i in range(len(shapes)):
    outs.append(xops.GetTupleElement(xs, i))
  outs.append(token)
  return xops.Tuple(c, outs)
|
2019-10-09 15:05:54 -04:00
|
|
|
|
|
|
|
# infeed returns the consumed values followed by a token.
infeed_p = Primitive("infeed")
infeed_p.multiple_results = True
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
infeed_p.def_abstract_eval(_infeed_abstract_eval)
xla.translations[infeed_p] = _infeed_translation_rule
|
|
|
|
|
|
|
|
def outfeed(token, xs):
  """Outfeeds value `xs` to the host. Experimental.

  `token` is used to sequence infeed and outfeed effects.
  """
  # Only the flattened leaves are sent; the tree structure is discarded.
  leaves = pytree.flatten(xs)[0]
  return outfeed_p.bind(token, *leaves)
|
|
|
|
|
|
|
|
def _outfeed_abstract_eval(token, *xs):
  """Abstract evaluation rule for `outfeed_p`; the only output is a token.

  Raises:
    TypeError: if `token` is not the `abstract_token` object.
  """
  if token is abstract_token:
    return abstract_token
  raise TypeError("First argument to outfeed must be a token")
|
|
|
|
|
|
|
|
|
|
|
|
def _outfeed_translation_rule(c, token, *xs):
  """XLA translation rule for `outfeed_p`.

  Packs the operands `xs` into a single tuple and emits `OutfeedWithToken`
  for it, passing the tuple's own shape as the outfeed shape.
  """
  payload = xops.Tuple(c, xs)
  payload_shape = c.get_shape(payload)
  return xops.OutfeedWithToken(payload, token, payload_shape)
|
2019-10-09 15:05:54 -04:00
|
|
|
|
|
|
|
# The outfeed primitive: executed through `xla.apply_primitive` and lowered by
# `_outfeed_translation_rule`.
outfeed_p = Primitive("outfeed")
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
xla.translations[outfeed_p] = _outfeed_translation_rule
|
2019-01-31 22:08:51 -08:00
|
|
|
|
2020-01-24 16:58:00 -05:00
|
|
|
def rng_uniform(a, b, shape):
  """Stateful PRNG generator. Experimental and its use is discouraged.

  Draws uniformly distributed random numbers from the range [a, b) with the
  requested `shape` (converted to a tuple before binding the primitive).

  Prefer jax.random for most purposes; this function exists only for niche
  use cases with special performance requirements.

  This API may be removed at any time.
  """
  shape_tuple = tuple(shape)
  return rng_uniform_p.bind(a, b, shape=shape_tuple)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _rng_uniform_abstract_eval(a, b, *, shape):
|
2020-01-24 16:58:00 -05:00
|
|
|
if a.dtype != b.dtype:
|
|
|
|
raise ValueError(
|
|
|
|
"Arguments to rng_uniform must have identical dtypes, got {} "
|
|
|
|
"and {}.".format(a.dtype, b.dtype))
|
|
|
|
if a.shape != () or b.shape != ():
|
|
|
|
raise ValueError(
|
|
|
|
"Arguments to rng_uniform must be scalars; got shapes {} and {}."
|
|
|
|
.format(a.shape, b.shape))
|
|
|
|
return ShapedArray(shape, a.dtype)
|
|
|
|
|
2020-04-07 09:38:10 -04:00
|
|
|
def _rng_uniform_translation_rule(c, a, b, *, shape):
  """XLA translation rule for `rng_uniform_p`.

  The output array shape uses the XLA element type of operand `a` and the
  requested dimensions `shape`.
  """
  element_type = c.get_shape(a).xla_element_type()
  result_shape = xc.Shape.array_shape(element_type, shape)
  return xops.RngUniform(a, b, result_shape)
|
2020-01-24 16:58:00 -05:00
|
|
|
|
|
|
|
# The rng_uniform primitive: executed through `xla.apply_primitive` and
# lowered by `_rng_uniform_translation_rule`.
rng_uniform_p = Primitive("rng_uniform")
rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p))
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
xla.translations[rng_uniform_p] = _rng_uniform_translation_rule
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
### util
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
# Module-private alias for NumPy's `ndim` (number of array dimensions).
_ndim = np.ndim
|
2018-12-13 15:29:39 -05:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def _dilate_shape(shape, dilation):
|
|
|
|
"""Utility function for computing the shape resulting from a dilation."""
|
2020-07-14 13:05:31 -07:00
|
|
|
if not np.all(np.greater(dilation, 0)):
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = "All dilations must be positive, got {}."
|
|
|
|
raise TypeError(msg.format(dilation))
|
|
|
|
dilation = (1,) * (len(shape) - len(dilation)) + tuple(dilation)
|
2020-07-14 13:05:31 -07:00
|
|
|
return np.where(shape == 0, 0,
|
|
|
|
np.multiply(dilation, np.subtract(shape, 1)) + 1)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
Implement shapecheck for more primitives (#1990)
* shapecheck of jit, device_put, broadcast_in_dim, better error for unsupported ops, parse multi-digit integer literals
* WIP shapecheck np.pad
* Implement shapecheck of gather, pad
* Fix shapecheck of pad
* Implement polymorphic shape rule for (strided/dilated) convolution, refactor
* Cleanup
* Fix
* Remove all polymorphic shape rules, reuse shape rules instead.
* Register shape_rule for all standard_primitives
* Remove ShapeExpr, canonicalize_poly, renames
* Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes
* Allow Poly of form d*poly + k to be divided by d
* Fix bug, inline poly_without_zeros.
2020-01-16 00:36:00 +00:00
|
|
|
def _ceil_divide(x1, x2):
|
2020-07-14 13:05:31 -07:00
|
|
|
return -np.floor_divide(np.negative(x1), x2)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def padtype_to_pads(in_shape, window_shape, window_strides, padding):
  """Convert padding string to list of pairs of pad values.

  Args:
    in_shape: sizes of the input dimensions being windowed.
    window_shape: window size for each of those dimensions.
    window_strides: window stride for each of those dimensions.
    padding: either the string 'VALID' or 'SAME' (case-insensitive) or an
      `xla_client.PaddingType` value.

  Returns:
    A list of `(low, high)` pad pairs. 'VALID' gives `(0, 0)` for each entry
    of `in_shape`; 'SAME' gives the total padding needed so that the output
    size is `ceil(in / stride)`, with the smaller half placed first.

  Raises:
    RuntimeError: if `padding` is a string other than 'VALID' or 'SAME'.
    TypeError: if `padding` is any other unrecognized value.
  """
  padding_enum = xla_client.PaddingType

  if isinstance(padding, str):
    by_name = {'VALID': padding_enum.VALID, 'SAME': padding_enum.SAME}
    try:
      padding = by_name[padding.upper()]
    except KeyError as err:
      msg = "Unrecognized padding type: expected 'VALID' or 'SAME', got {}."
      raise RuntimeError(msg.format(padding)) from err

  if padding == padding_enum.SAME:
    out_shape = _ceil_divide(in_shape, window_strides)
    needed = (out_shape - 1) * window_strides + window_shape - in_shape
    pairs = []
    for total in np.maximum(0, needed):
      pairs.append((total // 2, total - total // 2))
    return pairs
  if padding == padding_enum.VALID:
    return [(0, 0)] * len(in_shape)

  msg = "Unknown padding type: {}."
  raise TypeError(msg.format(padding))
|
|
|
|
|
|
|
|
|
2019-11-15 10:02:51 -05:00
|
|
|
def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
  """Check that dtypes agree, possibly ignoring float precision.

  Args:
    name: operation name used in the error message.
    ignore_fp_precision: if true, all floating dtypes are treated as equal to
      each other, and likewise all complex floating dtypes.
    *ttypes: the dtypes to compare.

  Raises:
    TypeError: if the dtypes do not all canonicalize to one single dtype.
  """
  # the `ignore_fp_precision` flag exists because the XLA shape inference logic
  # allows mixed floating point precision, but the HLO verifier often rejects it
  def _coarsen(dtype):
    # Collapse every float / complex dtype onto its abstract NumPy category.
    if dtypes.issubdtype(dtype, np.floating):
      return np.floating
    if dtypes.issubdtype(dtype, np.complexfloating):
      return np.complexfloating
    return dtype

  types = [np.dtype(t) for t in ttypes]  # canonicalize
  if ignore_fp_precision:
    types = [_coarsen(dtype) for dtype in types]

  distinct = {dtypes.canonicalize_dtype(t) for t in types}
  if len(distinct) == 1:
    return

  if ignore_fp_precision:
    msg = ("{} requires arguments to have same dtypes up to floating point "
           "precision, got {}.")
  else:
    msg = "{} requires arguments to have the same dtypes, got {}."
  raise TypeError(msg.format(name, ", ".join(map(str, types))))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2018-12-05 08:22:27 -08:00
|
|
|
def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides):
|
2018-11-17 18:03:33 -08:00
|
|
|
"""Check that conv shapes are valid and are consistent with window_strides."""
|
|
|
|
if len(lhs_shape) != len(rhs_shape):
|
|
|
|
msg = "Arguments to {} must have same rank, got {} and {}."
|
|
|
|
raise TypeError(msg.format(name, len(lhs_shape), len(rhs_shape)))
|
|
|
|
if len(lhs_shape) < 2:
|
|
|
|
msg = "Arguments to {} must have rank at least 2, got {} and {}."
|
2018-12-05 08:22:27 -08:00
|
|
|
raise TypeError(msg.format(name, len(lhs_shape), len(rhs_shape)))
|
2018-11-17 18:03:33 -08:00
|
|
|
if lhs_shape[1] != rhs_shape[1]:
|
|
|
|
msg = "Arguments to {} must agree on input feature size, got {} and {}."
|
2018-12-05 08:22:27 -08:00
|
|
|
raise TypeError(msg.format(name, lhs_shape[1], rhs_shape[1]))
|
|
|
|
_check_shapelike(name, "window_strides", window_strides)
|
2020-07-14 13:05:31 -07:00
|
|
|
if not np.all(np.greater(window_strides, 0)):
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = "All elements of window_strides must be positive, got {}."
|
|
|
|
raise TypeError(msg.format(window_strides))
|
|
|
|
if len(window_strides) != len(lhs_shape) - 2:
|
|
|
|
msg = "{} window_strides has wrong length: expected {}, got {}."
|
|
|
|
expected_length = len(lhs_shape) - 2
|
2018-12-05 08:22:27 -08:00
|
|
|
raise TypeError(msg.format(name, expected_length, len(window_strides)))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2020-04-09 16:21:30 -04:00
|
|
|
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
  """Compute the shape tuple of a conv given input shapes in canonical order.

  Args:
    lhs_shape: input shape; index 0 is divided by `batch_group_count` and
      indices 2 onwards are treated as spatial sizes.
    rhs_shape: kernel shape; index 0 becomes the second output dimension and
      indices 2 onwards are treated as spatial sizes.
    strides: window stride per spatial dimension.
    pads: a padding string (resolved with `padtype_to_pads`) or a sequence of
      `(low, high)` pairs, one per spatial dimension.
    batch_group_count: divisor applied to `lhs_shape[0]`.

  Returns:
    The output shape as a tuple.

  Raises:
    TypeError: if the number of pad pairs is not `len(lhs_shape) - 2`.
  """
  lhs_spatial, rhs_spatial = lhs_shape[2:], rhs_shape[2:]
  if isinstance(pads, str):
    pads = padtype_to_pads(lhs_spatial, rhs_spatial, strides, pads)

  num_spatial = len(lhs_shape) - 2
  if len(pads) != num_spatial:
    msg = "Wrong number of explicit pads for convolution: expected {}, got {}."
    raise TypeError(msg.format(num_spatial, len(pads)))

  # Sum the (low, high) pair of every spatial dimension.
  pad_totals = np.sum(np.array(pads).reshape(-1, 2), axis=1)
  lhs_padded = np.add(lhs_spatial, pad_totals)
  window_positions = np.floor_divide(
      np.subtract(lhs_padded, rhs_spatial), strides) + 1
  out_space = np.maximum(0, window_positions)

  assert lhs_shape[0] % batch_group_count == 0
  leading = (lhs_shape[0] // batch_group_count, rhs_shape[0])
  return tuple(leading + tuple(out_space))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
                             dimension_numbers):
  """Compute a convolution's output shape for arbitrary dimension layouts.

  Both shapes are permuted into the order produced by
  `conv_general_permutations`, the output shape is computed there with
  `conv_shape_tuple`, and the result is permuted back into the output layout.
  """
  perms = conv_general_permutations(dimension_numbers)
  lhs_perm, rhs_perm, out_perm = perms
  canonical_lhs = np.take(lhs_shape, lhs_perm)
  canonical_rhs = np.take(rhs_shape, rhs_perm)
  canonical_out = conv_shape_tuple(canonical_lhs, canonical_rhs,
                                   window_strides, padding)
  # argsort of a permutation is its inverse.
  inverse_out_perm = np.argsort(out_perm)
  return tuple(np.take(canonical_out, inverse_out_perm))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2019-04-09 22:59:03 -07:00
|
|
|
def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
                               dimension_numbers):
  """Compute a transposed convolution's output shape.

  Shapes are permuted into the order given by `conv_general_permutations`.
  Each spatial output size is `(i - 1) * s - k + 2` plus the summed padding
  for that dimension, where `i`, `k`, `s` are the input size, kernel size and
  stride. A string `padding` is first expanded per dimension with
  `_conv_transpose_padding`. The result is permuted back to the output layout.
  """
  lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
  lhs_trans = np.take(lhs_shape, lhs_perm)
  rhs_trans = np.take(rhs_shape, rhs_perm)
  lhs_spatial, rhs_spatial = lhs_trans[2:], rhs_trans[2:]

  if isinstance(padding, str):
    padding = [_conv_transpose_padding(k, s, padding)
               for k, s in zip(rhs_spatial, window_strides)]
  pad_totals = [np.sum(pair) for pair in padding]

  unpad_out_space = []
  for i, k, s in zip(lhs_spatial, rhs_spatial, window_strides):
    unpad_out_space.append((i - 1) * s - k + 2)

  out_space = np.sum([unpad_out_space, pad_totals], axis=0).tolist()
  out_trans = tuple((lhs_trans[0], rhs_trans[0]) + tuple(out_space))
  # argsort of a permutation is its inverse.
  return tuple(np.take(out_trans, np.argsort(out_perm)))
|
2019-04-09 22:59:03 -07:00
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def _check_shapelike(fun_name, arg_name, obj):
|
|
|
|
"""Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints)."""
|
2020-07-14 13:05:31 -07:00
|
|
|
if not isinstance(obj, (tuple, list, np.ndarray)):
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = "{} {} must be of type tuple/list/ndarray, got {}."
|
|
|
|
raise TypeError(msg.format(fun_name, arg_name, type(obj)))
|
|
|
|
# bool(obj) for an ndarray raises an error, so we check len
|
|
|
|
if not len(obj): # pylint: disable=g-explicit-length-test
|
|
|
|
return
|
2020-07-14 13:05:31 -07:00
|
|
|
obj_arr = np.array(obj)
|
2018-11-17 18:03:33 -08:00
|
|
|
if obj_arr.ndim != 1:
|
|
|
|
msg = "{} {} must be rank 1, got {}."
|
|
|
|
raise TypeError(msg.format(obj_arr.ndim))
|
2020-05-01 21:34:29 +02:00
|
|
|
try:
|
|
|
|
canonicalize_shape(obj_arr)
|
|
|
|
except TypeError:
|
2018-11-17 18:03:33 -08:00
|
|
|
msg = "{} {} must have every element be an integer type, got {}."
|
|
|
|
raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj))))
|
|
|
|
if not (obj_arr >= 0).all():
|
|
|
|
msg = "{} {} must have every element be nonnegative, got {}."
|
|
|
|
raise TypeError(msg.format(fun_name, arg_name, obj))
|
|
|
|
|
|
|
|
|
|
|
|
def _dynamic_slice_indices(operand, start_indices):
  """Normalize dynamic-slice start indices, wrapping negative ones.

  A negative index has the matching dimension size of `operand` added to it.
  `start_indices` may be a tuple/list with one entry per operand dimension
  (a list is returned) or a single 1D array-like (one array is returned).

  Raises:
    ValueError: if the number of indices differs from `operand.ndim`, or if a
      non-sequence `start_indices` is not one-dimensional.
  """
  if len(start_indices) != operand.ndim:
    msg = ("Length of slice indices must match number of operand dimensions ({} "
           "vs {})")
    raise ValueError(msg.format(len(start_indices), operand.shape))
  # map int over operand.shape to raise any dynamic-shape errors
  safe_map(int, operand.shape)

  def _wrap(index, dim_size):
    # Static integers are wrapped eagerly with NumPy; anything else is
    # wrapped with lax ops.
    if isinstance(index, (int, np.integer)):
      wrapped = index + dim_size if index < 0 else index
      return np.asarray(wrapped, getattr(index, 'dtype', dtypes.int_))
    return select(lt(index, _const(index, 0)),
                  add(index, _const(index, dim_size)),
                  index)

  if isinstance(start_indices, (tuple, list)):
    return [_wrap(i, d) for i, d in zip(start_indices, operand.shape)]

  if start_indices.ndim != 1:
    raise ValueError("Slice indices must be a 1D sequence, got {}"
                     .format(start_indices.shape))
  return select(lt(start_indices, _zeros(start_indices)),
                add(start_indices, _const(start_indices, operand.shape)),
                start_indices)
|
2019-08-15 11:26:30 -04:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2018-12-13 07:24:14 -08:00
|
|
|
def _const(example, val):
  """Return `val` as a constant whose type follows that of `example`.

  For a Python scalar `example`, `val` is converted with the matching Python
  scalar type; otherwise a NumPy array with `example`'s dtype is returned.
  """
  if not dtypes.is_python_scalar(example):
    return np.array(val, _dtype(example))
  scalar_type = dtypes.scalar_type_of(example)
  return scalar_type(val)
|
2018-12-13 07:24:14 -08:00
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
# Helpers built on `full_like`: the plural forms keep the shape of the
# example argument, the singular forms always produce a scalar (shape `()`).
_zeros: Callable = partial(full_like, fill_value=0)
_zero: Callable = partial(full_like, shape=(), fill_value=0)
_ones: Callable = partial(full_like, fill_value=1)
_one: Callable = partial(full_like, shape=(), fill_value=1)
_twos: Callable = partial(full_like, fill_value=2)
_two: Callable = partial(full_like, shape=(), fill_value=2)

# Public and module-private aliases for `dtypes.result_type`.
dtype: Callable = dtypes.result_type
_dtype: Callable = dtypes.result_type
|
|
|
|
|
|
|
|
def _iscomplex(x) -> bool:
  """Return whether the dtype of `x` is a complex floating-point type."""
  x_dtype = _dtype(x)
  return dtypes.issubdtype(x_dtype, np.complexfloating)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def ranges_like(*xs):
  """Yield consecutive, non-overlapping ranges, one per argument.

  The i-th range has `len(xs[i])` elements and starts where the previous one
  stopped; the first starts at 0.
  """
  offset = 0
  for length in map(len, xs):
    stop = offset + length
    yield range(offset, stop)
    offset = stop
|
|
|
|
|
|
|
|
|
|
|
|
def remaining(original, *removed_lists):
  """Return the items of `original`, in order, absent from every removed list."""
  excluded = set()
  for removed in removed_lists:
    excluded.update(removed)
  return list(itertools.filterfalse(excluded.__contains__, original))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-06-28 09:00:32 -04:00
|
|
|
|
|
|
|
def _canonicalize_precision(precision):
|
|
|
|
if precision is None:
|
|
|
|
return None
|
|
|
|
if isinstance(precision, Precision):
|
|
|
|
return precision
|
|
|
|
else:
|
2019-06-28 12:48:44 -04:00
|
|
|
msg = "Precision argument must be None or a lax.Precision value; got {}"
|
2019-06-28 09:00:32 -04:00
|
|
|
raise ValueError(msg.format(precision))
|
|
|
|
|
|
|
|
|
2018-12-10 17:18:56 -08:00
|
|
|
def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers):
  """Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`.

  Args:
    lhs_shape: tuple of nonnegative integers, shape of the convolution input.
    rhs_shape: tuple of nonnegative integers, shape of the convolution kernel.
    dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers
      object following the convolution dimension number specification format in
      xla_client.py.

  Returns:
    A `ConvDimensionNumbers` object that represents `dimension_numbers` in the
    canonical form used by lax functions. An existing `ConvDimensionNumbers`
    is returned as is; None yields the identity layout for all three specs.

  Raises:
    TypeError: if the two shapes differ in rank, or if `dimension_numbers` is
      not None, a `ConvDimensionNumbers`, or a list/tuple of three strings
      whose lengths equal the rank of `lhs_shape`.
  """
  if isinstance(dimension_numbers, ConvDimensionNumbers):
    return dimension_numbers

  ndim = len(lhs_shape)
  if ndim != len(rhs_shape):
    msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}."
    raise TypeError(msg.format(ndim, len(rhs_shape)))

  if dimension_numbers is None:
    identity = tuple(range(ndim))
    return ConvDimensionNumbers(identity, identity, identity)

  if not isinstance(dimension_numbers, (list, tuple)):
    msg = "convolution dimension_numbers must be tuple/list or None, got {}."
    raise TypeError(msg.format(type(dimension_numbers)))

  if len(dimension_numbers) != 3:
    msg = "convolution dimension_numbers list/tuple must be length 3, got {}."
    raise TypeError(msg.format(len(dimension_numbers)))
  if not all(isinstance(elt, str) for elt in dimension_numbers):
    msg = "convolution dimension_numbers elements must be strings, got {}."
    raise TypeError(msg.format(tuple(map(type, dimension_numbers))))

  for position, spec in enumerate(dimension_numbers):
    if len(spec) != ndim:
      msg = ("convolution dimension_numbers[{}] must have len equal to the ndim "
             "of lhs and rhs, got {} for lhs and rhs shapes {} and {}.")
      raise TypeError(msg.format(position, len(spec), lhs_shape, rhs_shape))

  return ConvDimensionNumbers(*conv_general_permutations(dimension_numbers))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
2018-12-10 17:18:56 -08:00
|
|
|
def conv_general_permutations(dimension_numbers):
  """Utility for convolution dimension permutations relative to Conv HLO.

  Args:
    dimension_numbers: three specs `(lhs_spec, rhs_spec, out_spec)`. The lhs
      and out specs must each contain 'N' and 'C' exactly once, the rhs spec
      'O' and 'I' exactly once; all remaining (spatial) characters must be
      unique within a spec and the same set in all three specs.

  Returns:
    `(lhs_perm, rhs_perm, out_perm)`. Each permutation lists the positions of
    the spec's two special characters first and then its spatial positions;
    for lhs and out the spatial positions follow the order in which the
    spatial characters appear in `rhs_spec`.

  Raises:
    TypeError: if a spec violates the constraints above.
  """
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  specs = (lhs_spec, rhs_spec, out_spec)
  charpairs = (("N", "C"), ("O", "I"), ("N", "C"))

  for position, (spec, (first, second)) in enumerate(zip(specs, charpairs)):
    if spec.count(first) != 1 or spec.count(second) != 1:
      msg = ("convolution dimension_numbers[{}] must contain the characters "
             "'{}' and '{}' exactly once, got {}.")
      raise TypeError(msg.format(position, first, second, spec))
    if len(set(spec)) != len(spec):
      msg = ("convolution dimension_numbers[{}] cannot have duplicate "
             "characters, got {}.")
      raise TypeError(msg.format(position, spec))

  lhs_spatial, rhs_spatial, out_spatial = (
      set(spec) - set(pair) for spec, pair in zip(specs, charpairs))
  if not (lhs_spatial == rhs_spatial == out_spatial):
    msg = ("convolution dimension_numbers elements must each have the same "
           "set of spatial characters, got {}.")
    raise TypeError(msg.format(dimension_numbers))

  def _perm(spec, pair):
    spatial_positions = [pos for pos, ch in enumerate(spec) if ch not in pair]
    if spec is not rhs_spec:
      # Order lhs/out spatial dimensions the way the kernel spec lists them.
      spatial_positions.sort(key=lambda pos: rhs_spec.index(spec[pos]))
    return (spec.index(pair[0]), spec.index(pair[1]), *spatial_positions)

  return tuple(_perm(spec, pair) for spec, pair in zip(specs, charpairs))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def _conv_general_proto(dimension_numbers):
  """Build an XLA `ConvolutionDimensionNumbers` from a `ConvDimensionNumbers`.

  In each of the three specs, entries 0 and 1 fill the two named dimension
  fields (batch/feature, or output-feature/input-feature for the kernel) and
  the remaining entries fill the spatial dimensions.
  """
  assert type(dimension_numbers) is ConvDimensionNumbers
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  proto = xla_client.ConvolutionDimensionNumbers()

  # Input (lhs) layout.
  proto.input_batch_dimension = lhs_spec[0]
  proto.input_feature_dimension = lhs_spec[1]
  proto.input_spatial_dimensions.extend(lhs_spec[2:])

  # Kernel (rhs) layout.
  proto.kernel_output_feature_dimension = rhs_spec[0]
  proto.kernel_input_feature_dimension = rhs_spec[1]
  proto.kernel_spatial_dimensions.extend(rhs_spec[2:])

  # Output layout.
  proto.output_batch_dimension = out_spec[0]
  proto.output_feature_dimension = out_spec[1]
  proto.output_spatial_dimensions.extend(out_spec[2:])
  return proto
|
|
|
|
|
|
|
|
|
|
|
|
def _conv_general_vjp_lhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
  """Compute the `(before, after)` padding pairs used by the lhs conv VJP.

  With `L`, `R`, `O` the dilated input, window and output shapes and `lo` the
  low entries of `padding`: `before = R - lo - 1` and
  `after = L + R - 1 - O - before`, elementwise.
  """
  dilated_in = _dilate_shape(in_shape, lhs_dilation)
  dilated_window = _dilate_shape(window_dimensions, rhs_dilation)
  dilated_out = _dilate_shape(out_shape, window_strides)

  low_pads = [lo for lo, _ in padding]
  pad_before = np.subtract(dilated_window, low_pads) - 1
  full_extent = np.add(dilated_in, dilated_window) - 1
  pad_after = full_extent - dilated_out - pad_before
  return safe_zip(pad_before, pad_after)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def _conv_general_vjp_rhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation):
  """Compute the `(low, high)` padding pairs used by the rhs conv VJP.

  The low pad of each dimension is kept from `padding`; the high pad is the
  rest of the total `O + R - L - 1`, where `L`, `R`, `O` are the dilated
  input, window and output shapes.
  """
  dilated_in = _dilate_shape(in_shape, lhs_dilation)
  dilated_window = _dilate_shape(window_dimensions, rhs_dilation)
  dilated_out = _dilate_shape(out_shape, window_strides)
  total_in_pad = dilated_out + dilated_window - dilated_in - 1

  pairs = []
  for pad, total in zip(padding, total_in_pad):
    low = pad[0]
    pairs.append((low, total - low))
  return pairs
|
|
|
|
|
|
|
|
|
|
|
|
def _balanced_eq(x, z, y):
  """Elementwise weight: (1 where `x == z` else 0) / (2 where `y == z` else 1).

  Equality is tested with `_eq_meet`; all constants are shaped like `z`.
  """
  numerator = select(_eq_meet(x, z), _ones(z), _zeros(z))
  denominator = select(_eq_meet(y, z), _twos(z), _ones(z))
  return div(numerator, denominator)
|
|
|
|
|
|
|
|
|
|
|
|
def _eq_meet(a, b):
  """Compare `a` and `b` for equality after bringing them to one dtype.

  When the dtypes differ, the operand whose dtype equals the promoted dtype
  is converted to the other operand's dtype before `eq` is applied.
  """
  a_dtype = _dtype(a)
  b_dtype = _dtype(b)
  if a_dtype == b_dtype:
    return eq(a, b)

  higher_dtype = dtypes.promote_types(a_dtype, b_dtype)
  if higher_dtype == a_dtype:
    return eq(convert_element_type(a, b_dtype), b)
  return eq(a, convert_element_type(b, a_dtype))
|
|
|
|
|
|
|
|
|
|
|
|
def _abstractify(x):
  """Return the abstract value of `x`, raised with `raise_to_shaped`."""
  aval = core.get_aval(x)
  return raise_to_shaped(aval)
|
2019-08-22 09:22:57 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _check_user_dtype_supported(dtype, fun_name=None):
|
2020-07-14 13:05:31 -07:00
|
|
|
np_dtype = np.dtype(dtype)
|
|
|
|
if np_dtype.kind not in "biufc" and np_dtype.type != dtypes.bfloat16:
|
2020-04-29 14:14:49 -04:00
|
|
|
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
|
|
|
|
raise TypeError(msg)
|
2020-07-14 13:05:31 -07:00
|
|
|
if dtype is not None and np_dtype != dtypes.canonicalize_dtype(dtype):
|
2019-08-22 09:22:57 -07:00
|
|
|
msg = ("Explicitly requested dtype {} {} is not available, "
|
|
|
|
"and will be truncated to dtype {}. To enable more dtypes, set the "
|
|
|
|
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
|
|
|
|
"environment variable. "
|
|
|
|
"See https://github.com/google/jax#current-gotchas for more.")
|
|
|
|
fun_name = "requested in {}".format(fun_name) if fun_name else ""
|
2019-11-15 10:02:51 -05:00
|
|
|
truncated_dtype = dtypes.canonicalize_dtype(dtype).name
|
2019-08-22 09:22:57 -07:00
|
|
|
warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
|
2020-05-14 11:13:15 -04:00
|
|
|
|
|
|
|
|
|
|
|
def _canonicalize_axis(axis, num_dims):
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
|
|
|
axis = operator.index(axis)
|
|
|
|
if not -num_dims <= axis < num_dims:
|
2020-05-14 11:13:15 -04:00
|
|
|
raise ValueError(
|
|
|
|
"axis {} is out of bounds for array of dimension {}".format(
|
|
|
|
axis, num_dims))
|
Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which has a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
2020-05-28 19:12:50 -07:00
|
|
|
if axis < 0:
|
|
|
|
axis = axis + num_dims
|
2020-05-15 20:51:53 -07:00
|
|
|
return axis
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
|
|
|
|
@config.omnistaging_enablers.append
|
|
|
|
def omnistaging_enabler() -> None:
|
|
|
|
global _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
|
|
|
|
del _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
|