mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
make a lax package, revert control flow names (#607)
c.f. #597 pair=skyewm
This commit is contained in:
parent
4f0280fe36
commit
0cf14837c9
@ -307,10 +307,10 @@ Because `jit` aims to specialize Python functions only on shapes and dtypes
|
||||
during tracing, rather than on concrete values, Python control flow that depends
|
||||
on concrete values won’t be able to execute and will instead raise an error. If
|
||||
you want compiled control flow, use structured control flow primitives like
|
||||
lax_control_flow.cond and lax_control_flow.while_loop. Some indexing features,
|
||||
like slice-based indexing `A[i:i+5]` for argument-dependent `i`, or
|
||||
boolean-based indexing `A[bool_ind]` for argument-dependent `bool_ind`, produce
|
||||
abstract values of unknown shape and are thus unsupported in `jit` functions.
|
||||
lax.cond and lax.while_loop. Some indexing features, like slice-based indexing
|
||||
`A[i:i+5]` for argument-dependent `i`, or boolean-based indexing `A[bool_ind]`
|
||||
for argument-dependent `bool_ind`, produce abstract values of unknown shape and
|
||||
are thus unsupported in `jit` functions.
|
||||
|
||||
In general, JAX is intended to be used with a functional style of Python
|
||||
programming. Functions passed to transformations like `grad` and `jit` are
|
||||
|
@ -29,7 +29,7 @@ import matplotlib.pyplot as plt
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.config import config
|
||||
from jax import jit, grad, lax, lax_control_flow, random
|
||||
from jax import jit, grad, lax, random
|
||||
from jax.experimental import optimizers
|
||||
from jax.experimental import stax
|
||||
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
|
||||
@ -117,7 +117,7 @@ if __name__ == "__main__":
|
||||
loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
|
||||
g = grad(loss)(optimizers.get_params(opt_state))
|
||||
return opt_update(i, g, opt_state)
|
||||
return lax_control_flow.fori_loop(0, num_batches, body_fun, opt_state)
|
||||
return lax.fori_loop(0, num_batches, body_fun, opt_state)
|
||||
|
||||
@jit
|
||||
def evaluate(opt_state, images):
|
||||
|
22
jax/lax/__init__.py
Normal file
22
jax/lax/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from .lax import *
|
||||
from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
|
||||
_reduce_and, _reduce_window_sum, _reduce_window_max,
|
||||
_reduce_window_min, _reduce_window_prod, _float, _complex,
|
||||
_input_dtype, _const, _eq_meet)
|
||||
from .lax_control_flow import *
|
||||
from .lax_parallel import *
|
@ -35,27 +35,27 @@ from six.moves import builtins, xrange
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from .util import partial, prod
|
||||
from ..util import partial, prod
|
||||
|
||||
from . import core
|
||||
from . import ad_util
|
||||
from . import api
|
||||
from . import linear_util as lu
|
||||
from .config import flags
|
||||
from .core import Primitive
|
||||
from .abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
|
||||
array_types, make_shaped_array)
|
||||
from .api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
|
||||
pytree_fun_to_flatjaxtuple_fun, pytree_to_flatjaxtuple)
|
||||
from .interpreters import partial_eval as pe
|
||||
from .interpreters import xla
|
||||
from .interpreters import ad
|
||||
from .interpreters import batching
|
||||
from .interpreters import parallel
|
||||
from .util import curry, memoize, safe_zip, unzip2, prod
|
||||
from .tree_util import build_tree, tree_unflatten
|
||||
from .lib import xla_bridge
|
||||
from .lib.xla_bridge import xla_client
|
||||
from .. import core
|
||||
from .. import ad_util
|
||||
from .. import api
|
||||
from .. import linear_util as lu
|
||||
from ..config import flags
|
||||
from ..core import Primitive
|
||||
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
|
||||
array_types, make_shaped_array)
|
||||
from ..api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
|
||||
pytree_fun_to_flatjaxtuple_fun, pytree_to_flatjaxtuple)
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters import xla
|
||||
from ..interpreters import ad
|
||||
from ..interpreters import batching
|
||||
from ..interpreters import parallel
|
||||
from ..util import curry, memoize, safe_zip, unzip2, prod
|
||||
from ..tree_util import build_tree, tree_unflatten
|
||||
from ..lib import xla_bridge
|
||||
from ..lib.xla_bridge import xla_client
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
@ -3795,7 +3795,7 @@ _one = partial(full_like, shape=(), fill_value=1)
|
||||
_twos = partial(full_like, fill_value=2)
|
||||
_two = partial(full_like, shape=(), fill_value=2)
|
||||
|
||||
_dtype = onp.result_type
|
||||
_dtype = dtype = onp.result_type
|
||||
_iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating)
|
||||
|
||||
|
@ -23,7 +23,7 @@ import numpy as onp
|
||||
|
||||
from jax import api
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.lax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax.abstract_arrays import ConcreteArray, ShapedArray, UnshapedArray
|
||||
from jax.api_util import (
|
@ -15,7 +15,7 @@
|
||||
Parallelization primitives.
|
||||
"""
|
||||
|
||||
from jax import lax
|
||||
from jax.lax import lax
|
||||
from jax.abstract_arrays import ShapedArray
|
||||
from jax.core import Primitive
|
||||
from jax.interpreters import ad
|
@ -21,7 +21,6 @@ import numpy as onp
|
||||
from jax.numpy import lax_numpy as np
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import lax_control_flow
|
||||
from jax import ad_util
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import ad
|
||||
@ -300,7 +299,7 @@ def lu_jvp_rule(primals, tangents):
|
||||
|
||||
a_shape = np.shape(a)
|
||||
m, n = a_shape[-2:]
|
||||
dtype = lax._dtype(a)
|
||||
dtype = lax.dtype(a)
|
||||
k = min(m, n)
|
||||
|
||||
permutation = lu_pivots_to_permutation(pivots, m)
|
||||
@ -377,7 +376,7 @@ def lu_pivots_to_permutation(swaps, k):
|
||||
|
||||
n, = np.shape(swaps)
|
||||
permutation = np.arange(k)
|
||||
_, permutation = lax_control_flow.fori_loop(
|
||||
_, permutation = lax.fori_loop(
|
||||
onp.array(0, onp.int32), onp.array(n, onp.int32), body_fn, (swaps, permutation))
|
||||
return permutation
|
||||
|
||||
|
@ -33,7 +33,6 @@ from .. import core
|
||||
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
|
||||
from ..interpreters.xla import DeviceArray
|
||||
from .. import lax
|
||||
from .. import lax_control_flow
|
||||
from ..util import memoize, partial, get_module_functions, unzip2, prod as _prod
|
||||
from ..lib import xla_bridge
|
||||
from ..lib.xla_bridge import xla_client
|
||||
@ -93,7 +92,7 @@ result_type = onp.result_type
|
||||
shape = _shape = onp.shape
|
||||
ndim = _ndim = onp.ndim
|
||||
size = onp.size
|
||||
_dtype = lax._dtype
|
||||
_dtype = lax.dtype
|
||||
|
||||
bool_ = onp.bool_
|
||||
uint8 = onp.uint8
|
||||
@ -398,7 +397,7 @@ def power(x1, x2):
|
||||
x1 = asarray(x1)
|
||||
x2 = asarray(x2)
|
||||
x1, x2 = _promote_args_like(onp.power, x1, x2)
|
||||
dtype = lax._dtype(x1)
|
||||
dtype = _dtype(x1)
|
||||
if not issubdtype(dtype, integer):
|
||||
return lax.pow(x1, x2)
|
||||
|
||||
@ -2165,8 +2164,8 @@ kaiser = onp.kaiser # TODO: lower via lax to allow non-constant beta.
|
||||
|
||||
@_wraps(getattr(onp, "gcd", None))
|
||||
def gcd(x1, x2):
|
||||
if (not issubdtype(lax._dtype(x1), integer) or
|
||||
not issubdtype(lax._dtype(x2), integer)):
|
||||
if (not issubdtype(_dtype(x1), integer) or
|
||||
not issubdtype(_dtype(x2), integer)):
|
||||
raise ValueError("Arguments to gcd must be integers.")
|
||||
def cond_fn(xs):
|
||||
x1, x2 = xs
|
||||
@ -2178,7 +2177,7 @@ def gcd(x1, x2):
|
||||
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))
|
||||
x1, x2 = _promote_dtypes(lax.abs(x1), lax.abs(x2))
|
||||
x1, x2 = broadcast_arrays(x1, x2)
|
||||
gcd, _ = lax_control_flow.while_loop(cond_fn, body_fn, (x1, x2))
|
||||
gcd, _ = lax.while_loop(cond_fn, body_fn, (x1, x2))
|
||||
return gcd
|
||||
|
||||
|
||||
|
@ -63,7 +63,7 @@ def svd(a, full_matrices=True, compute_uv=True):
|
||||
@_wraps(onp.linalg.slogdet)
|
||||
def slogdet(a):
|
||||
a = _promote_arg_dtypes(np.asarray(a))
|
||||
dtype = lax._dtype(a)
|
||||
dtype = lax.dtype(a)
|
||||
a_shape = np.shape(a)
|
||||
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
|
||||
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
|
||||
@ -140,7 +140,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
|
||||
elif ord == -np.inf:
|
||||
return np.amin(np.abs(x), axis=axis, keepdims=keepdims)
|
||||
elif ord == 0:
|
||||
return np.sum(x != 0, dtype=np.finfo(lax._dtype(x)).dtype,
|
||||
return np.sum(x != 0, dtype=np.finfo(lax.dtype(x)).dtype,
|
||||
axis=axis, keepdims=keepdims)
|
||||
elif ord == 1:
|
||||
# Numpy has a special case for ord == 1 as an optimization. We don't
|
||||
@ -228,7 +228,7 @@ def solve(a, b):
|
||||
"b=[..., m, k] or b=[..., m]; got a={} and b={}")
|
||||
raise ValueError(msg.format(a_shape, b_shape))
|
||||
lu, pivots = lax_linalg.lu(a)
|
||||
dtype = lax._dtype(a)
|
||||
dtype = lax.dtype(a)
|
||||
|
||||
# TODO(phawkins): add unit_diagonal support to solve_triangular, use it here
|
||||
# instead of explicit masking of l.
|
||||
|
@ -43,7 +43,7 @@ def _scatter_update(x, idx, y, scatter_op):
|
||||
y = np.asarray(y)
|
||||
x_shape = np.shape(x)
|
||||
y_shape = np.shape(y)
|
||||
y = lax.convert_element_type(y, lax._dtype(x))
|
||||
y = lax.convert_element_type(y, lax.dtype(x))
|
||||
|
||||
if not isinstance(idx, tuple):
|
||||
idx = (idx,)
|
||||
|
@ -29,7 +29,6 @@ from functools import partial
|
||||
import numpy as onp
|
||||
|
||||
from . import lax
|
||||
from . import lax_control_flow
|
||||
from . import numpy as np
|
||||
from . import tree_util
|
||||
from .api import jit, vmap
|
||||
@ -77,7 +76,7 @@ def _make_rotate_left(dtype):
|
||||
nbits = onp.array(onp.iinfo(dtype).bits, dtype)
|
||||
|
||||
def _rotate_left(x, d):
|
||||
if lax._dtype(d) != lax._dtype(x):
|
||||
if lax.dtype(d) != lax.dtype(x):
|
||||
d = lax.convert_element_type(d, x.dtype)
|
||||
return (x << d) | lax.shift_right_logical(x, nbits - d)
|
||||
return _rotate_left
|
||||
@ -104,11 +103,11 @@ def threefry_2x32(keypair, count):
|
||||
"""
|
||||
# Based on ThreeFry2x32 by phawkins@ in //.../xla/client/lib/prng.cc
|
||||
key1, key2 = keypair
|
||||
if not lax._dtype(key1) == lax._dtype(key2) == lax._dtype(count) == onp.uint32:
|
||||
if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
|
||||
msg = "threefry_2x32 requires uint32 arguments, got {}"
|
||||
raise TypeError(msg.format([lax._dtype(x) for x in [key1, key2, count]]))
|
||||
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
|
||||
|
||||
rotate_left = _make_rotate_left(lax._dtype(count))
|
||||
rotate_left = _make_rotate_left(lax.dtype(count))
|
||||
|
||||
def apply_round(v, rot):
|
||||
v = v[:]
|
||||
@ -255,7 +254,7 @@ def _uniform(key, shape, dtype, minval, maxval):
|
||||
# bit-level transformation we use relies on Numpy and XLA having bit-for-bit
|
||||
# equivalent float representations, which might not be true on all platforms.
|
||||
float_bits = lax.bitwise_or(
|
||||
lax.shift_right_logical(bits, onp.array(nbits - nmant, lax._dtype(bits))),
|
||||
lax.shift_right_logical(bits, onp.array(nbits - nmant, lax.dtype(bits))),
|
||||
onp.array(1., dtype).view(onp.uint32 if nbits == 32 else onp.uint64))
|
||||
floats = lax.bitcast_convert_type(float_bits, dtype) - onp.array(1., dtype)
|
||||
return lax.max(
|
||||
@ -399,7 +398,7 @@ def bernoulli(key, mean=onp.float32(0.5), shape=()):
|
||||
@partial(jit, static_argnums=(2,))
|
||||
def _bernoulli(key, mean, shape):
|
||||
shape = shape or onp.shape(mean)
|
||||
if not onp.issubdtype(lax._dtype(mean), onp.float32):
|
||||
if not onp.issubdtype(lax.dtype(mean), onp.float32):
|
||||
mean = lax.convert_element_type(mean, onp.float32)
|
||||
if onp.shape(mean) != shape:
|
||||
mean = np.broadcast_to(mean, shape)
|
||||
@ -457,7 +456,7 @@ def _gamma_one(key, alpha):
|
||||
one_over_two = _constant_like(alpha, 0.5)
|
||||
one_over_three = _constant_like(alpha, 1. / 3.)
|
||||
squeeze_const = _constant_like(alpha, 0.0331)
|
||||
dtype = lax._dtype(alpha)
|
||||
dtype = lax.dtype(alpha)
|
||||
|
||||
key, subkey = split(key)
|
||||
# for alpha < 1, we boost alpha to alpha + 1 and get a sample according to
|
||||
@ -472,7 +471,7 @@ def _gamma_one(key, alpha):
|
||||
|
||||
def _cond_fn(kXVU):
|
||||
_, X, V, U = kXVU
|
||||
# TODO: use lax_control_flow.cond when its batching rule is supported
|
||||
# TODO: use lax.cond when its batching rule is supported
|
||||
# The reason is to avoid evaluating second condition which involves log+log
|
||||
# if the first condition is satisfied
|
||||
cond = lax.bitwise_and(lax.ge(U, lax.sub(one, lax.mul(squeeze_const, lax.mul(X, X)))),
|
||||
@ -492,7 +491,7 @@ def _gamma_one(key, alpha):
|
||||
return key, X, V, U
|
||||
|
||||
# initial state is chosen such that _cond_fn will return True
|
||||
_, _, V, _ = lax_control_flow.while_loop(
|
||||
_, _, V, _ = lax.while_loop(
|
||||
_cond_fn, _body_fn, (key, zero, _constant_like(alpha, -1), zero))
|
||||
z = lax.mul(lax.mul(d, V), boost)
|
||||
return lax.select(lax.eq(z, zero), onp.finfo(z.dtype).tiny, z)
|
||||
|
@ -129,7 +129,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
||||
del overwrite_a, check_finite
|
||||
a = np_linalg._promote_arg_dtypes(np.asarray(a))
|
||||
lu, pivots = lax_linalg.lu(a)
|
||||
dtype = lax._dtype(a)
|
||||
dtype = lax.dtype(a)
|
||||
m, n = np.shape(a)
|
||||
permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
|
||||
p = np.real(np.array(permutation == np.arange(m)[:, None], dtype=dtype))
|
||||
|
@ -165,7 +165,7 @@ def ndtr(x):
|
||||
TypeError: if `x` is not floating-type.
|
||||
"""
|
||||
x = np.asarray(x)
|
||||
dtype = lax._dtype(x)
|
||||
dtype = lax.dtype(x)
|
||||
if dtype not in (np.float32, np.float64):
|
||||
raise TypeError(
|
||||
"x.dtype={} is not supported, see docstring for supported types."
|
||||
@ -175,7 +175,7 @@ def ndtr(x):
|
||||
|
||||
def _ndtr(x):
|
||||
"""Implements ndtr core logic."""
|
||||
dtype = lax._dtype(x).type
|
||||
dtype = lax.dtype(x).type
|
||||
half_sqrt_2 = dtype(0.5) * onp.sqrt(2., dtype=dtype)
|
||||
w = x * half_sqrt_2
|
||||
z = lax.abs(w)
|
||||
@ -206,7 +206,7 @@ def ndtri(p):
|
||||
TypeError: if `p` is not floating-type.
|
||||
"""
|
||||
x = np.asarray(p)
|
||||
dtype = lax._dtype(p)
|
||||
dtype = lax.dtype(p)
|
||||
if dtype not in (np.float32, np.float64):
|
||||
raise TypeError(
|
||||
"x.dtype={} is not supported, see docstring for supported types."
|
||||
@ -271,7 +271,7 @@ def _ndtri(p):
|
||||
2.89247864745380683936E-6,
|
||||
6.79019408009981274425E-9]))
|
||||
|
||||
dtype = lax._dtype(p).type
|
||||
dtype = lax.dtype(p).type
|
||||
shape = np.shape(p)
|
||||
|
||||
def _create_polynomial(var, coeffs):
|
||||
@ -391,7 +391,7 @@ def log_ndtr(x, series_order=3):
|
||||
raise ValueError("series_order must be <= 30.")
|
||||
|
||||
x = np.asarray(x)
|
||||
dtype = lax._dtype(x)
|
||||
dtype = lax.dtype(x)
|
||||
|
||||
if dtype == np.float64:
|
||||
lower_segment = _LOGNDTR_FLOAT64_LOWER
|
||||
@ -427,7 +427,7 @@ def log_ndtr(x, series_order=3):
|
||||
|
||||
def _log_ndtr_lower(x, series_order):
|
||||
"""Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
|
||||
dtype = lax._dtype(x).type
|
||||
dtype = lax.dtype(x).type
|
||||
x_2 = lax.square(x)
|
||||
# Log of the term multiplying (1 + sum)
|
||||
log_scale = -dtype(0.5) * x_2 - lax.log(-x) - dtype(0.5 * onp.log(2. * onp.pi))
|
||||
@ -436,7 +436,7 @@ def _log_ndtr_lower(x, series_order):
|
||||
|
||||
def _log_ndtr_asymptotic_series(x, series_order):
|
||||
"""Calculates the asymptotic series used in log_ndtr."""
|
||||
dtype = lax._dtype(x).type
|
||||
dtype = lax.dtype(x).type
|
||||
if series_order <= 0:
|
||||
return onp.array(1, dtype)
|
||||
x_2 = lax.square(x)
|
||||
|
@ -233,7 +233,7 @@
|
||||
"\n",
|
||||
"__NB__: _Fancy Indexing_ is __not__ yet supported, but will likely be added to JAX soon.\n",
|
||||
"\n",
|
||||
"️⚠️ inside `jit`'d code and `lax_control_flow.while_loop` or `lax_control_flow.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation."
|
||||
"️⚠️ inside `jit`'d code and `lax.while_loop` or `lax.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1166,10 +1166,10 @@
|
||||
"## Structured control flow primitives\n",
|
||||
"\n",
|
||||
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n",
|
||||
" - `lax_control_flow.cond` _will be differentiable soon_\n",
|
||||
" - `lax_control_flow.while_loop` __non-differentiable__*\n",
|
||||
" - `lax_control_flow.fori_loop` __non-differentiable__*\n",
|
||||
" - `lax_control_flow.scan` _will be differentiable soon_\n",
|
||||
" - `lax.cond` _will be differentiable soon_\n",
|
||||
" - `lax.while_loop` __non-differentiable__*\n",
|
||||
" - `lax.fori_loop` __non-differentiable__*\n",
|
||||
" - `lax.scan` _will be differentiable soon_\n",
|
||||
"\n",
|
||||
"*_these can in principle be made to be __forward__-differentiable, but this isn't on the current roadmap._"
|
||||
]
|
||||
@ -1208,9 +1208,9 @@
|
||||
"from jax import lax\n",
|
||||
"\n",
|
||||
"operand = np.array([0.])\n",
|
||||
"lax_control_flow.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
||||
"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
||||
"# --> array([1.], dtype=float32)\n",
|
||||
"lax_control_flow.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
||||
"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
||||
"# --> array([-1.], dtype=float32)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -1263,7 +1263,7 @@
|
||||
"init_val = 0\n",
|
||||
"cond_fun = lambda x: x<10\n",
|
||||
"body_fun = lambda x: x+1\n",
|
||||
"lax_control_flow.while_loop(cond_fun, body_fun, init_val)\n",
|
||||
"lax.while_loop(cond_fun, body_fun, init_val)\n",
|
||||
"# --> array(10, dtype=int32)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -1316,7 +1316,7 @@
|
||||
"start = 0\n",
|
||||
"stop = 10\n",
|
||||
"body_fun = lambda i,x: x+i\n",
|
||||
"lax_control_flow.fori_loop(start, stop, body_fun, init_val)\n",
|
||||
"lax.fori_loop(start, stop, body_fun, init_val)\n",
|
||||
"# --> array(45, dtype=int32)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -1354,10 +1354,10 @@
|
||||
"\\textrm{if} & ❌ & ✔ \\\\\n",
|
||||
"\\textrm{for} & ✔* & ✔\\\\\n",
|
||||
"\\textrm{while} & ✔* & ✔\\\\\n",
|
||||
"\\textrm{lax_control_flow.cond} & ✔ & \\textrm{soon!}\\\\\n",
|
||||
"\\textrm{lax_control_flow.while_loop} & ✔ & ❌\\\\\n",
|
||||
"\\textrm{lax_control_flow.fori_loop} & ✔ & ❌\\\\\n",
|
||||
"\\textrm{lax_control_flow.scan} & \\textrm{soon!} & \\textrm{soon!}\\\\\n",
|
||||
"\\textrm{lax.cond} & ✔ & \\textrm{soon!}\\\\\n",
|
||||
"\\textrm{lax.while_loop} & ✔ & ❌\\\\\n",
|
||||
"\\textrm{lax.fori_loop} & ✔ & ❌\\\\\n",
|
||||
"\\textrm{lax.scan} & \\textrm{soon!} & \\textrm{soon!}\\\\\n",
|
||||
"\\hline\n",
|
||||
"\\end{array}\n",
|
||||
"$$\n",
|
||||
|
@ -24,7 +24,6 @@ import jax.numpy as np
|
||||
from jax import test_util as jtu
|
||||
from jax.abstract_arrays import ShapedArray
|
||||
from jax import lax
|
||||
from jax import lax_control_flow
|
||||
from jax import lax_linalg
|
||||
from jax import random
|
||||
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr, jacfwd, jacrev, hessian
|
||||
@ -875,7 +874,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
def testWhileLoop(self):
|
||||
def fun(x):
|
||||
return lax_control_flow.while_loop(lambda x: x < 3, lambda x: x + 2, x)
|
||||
return lax.while_loop(lambda x: x < 3, lambda x: x + 2, x)
|
||||
|
||||
ans = vmap(fun)(onp.array([0, 1, 2, 3]))
|
||||
expected = onp.array([4, 3, 4, 3])
|
||||
@ -888,7 +887,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
def testWhileLoopCondConstsBatched(self):
|
||||
def fun(x, y):
|
||||
return lax_control_flow.while_loop(lambda x: x < y, lambda x: x + 2, x)
|
||||
return lax.while_loop(lambda x: x < y, lambda x: x + 2, x)
|
||||
|
||||
ans = vmap(fun, in_axes=(None, 0))(0, onp.array([2, 3]))
|
||||
expected = onp.array([2, 4])
|
||||
@ -896,7 +895,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
def testWhileLoopBodyConstsBatched(self):
|
||||
def fun(x, y):
|
||||
return lax_control_flow.while_loop(lambda x: x < 3, lambda x: x + y, x)
|
||||
return lax.while_loop(lambda x: x < 3, lambda x: x + y, x)
|
||||
|
||||
ans = vmap(fun, in_axes=(None, 0))(0, onp.array([2, 3]))
|
||||
expected = onp.array([4, 3])
|
||||
@ -913,7 +912,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
return x, y
|
||||
|
||||
def fun(x, y):
|
||||
return lax_control_flow.while_loop(cond_fun, body_fun, (x, y))
|
||||
return lax.while_loop(cond_fun, body_fun, (x, y))
|
||||
|
||||
ans = vmap(fun)(onp.array([0, 0]), onp.array([1, 2]))
|
||||
expected = (onp.array([4, 3]), onp.array([1, 2]))
|
||||
@ -927,7 +926,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
return x, y
|
||||
|
||||
def fun(x):
|
||||
return lax_control_flow.fori_loop(0, 10, body_fun, (x, 0))
|
||||
return lax.fori_loop(0, 10, body_fun, (x, 0))
|
||||
|
||||
ans = vmap(fun)(onp.array([0, 1]))
|
||||
expected = (onp.array([10, 11]), onp.array([20, 20]))
|
||||
@ -941,7 +940,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
key, _ = random.split(key)
|
||||
return u, key
|
||||
|
||||
u, _ = lax_control_flow.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
|
||||
u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
|
||||
return u
|
||||
|
||||
print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash
|
||||
|
@ -25,7 +25,6 @@ import numpy.random as npr
|
||||
|
||||
from jax import api
|
||||
from jax import lax
|
||||
from jax import lax_control_flow
|
||||
from jax import test_util as jtu
|
||||
|
||||
|
||||
@ -43,7 +42,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return (lax.add(pos, 1), lax.add(count, 1))
|
||||
|
||||
def loop(init):
|
||||
result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
|
||||
result = lax.while_loop(loop_cond, loop_body, (init, 0))
|
||||
_, count = result
|
||||
return count
|
||||
|
||||
@ -66,7 +65,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return (num, lax.add(i, 1), inner_loop(i, count))
|
||||
|
||||
init_val = (num, 0, 0)
|
||||
_, i, count = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
|
||||
_, i, count = lax.while_loop(cond_fun, body_fun, init_val)
|
||||
return (i, count)
|
||||
|
||||
def inner_loop(i, count): # pylint: disable=missing-docstring
|
||||
@ -79,7 +78,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return (i, lax.add(j, 1), lax.add(count, 1))
|
||||
|
||||
init_val = (i, 0, count)
|
||||
_, _, count = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
|
||||
_, _, count = lax.while_loop(cond_fun, body_fun, init_val)
|
||||
return count
|
||||
|
||||
cloop = api.jit(outer_loop)
|
||||
@ -103,7 +102,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
pos, count = state
|
||||
return (lax.add(pos, 1), lax.add(count, inc))
|
||||
|
||||
result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
|
||||
result = lax.while_loop(loop_cond, loop_body, (init, 0))
|
||||
_, count = result
|
||||
return count
|
||||
|
||||
@ -135,7 +134,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
f = lambda pos, inc: (lax.add(pos, 1), lax.add(count, inc))
|
||||
return api.jit(f)(pos, inc)
|
||||
|
||||
result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
|
||||
result = lax.while_loop(loop_cond, loop_body, (init, 0))
|
||||
_, count = result
|
||||
return count
|
||||
|
||||
@ -172,7 +171,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
out = onp.zeros(arr.shape, dtype=arr.dtype)
|
||||
init_val = (0, num, arr, out)
|
||||
_, _, _, out = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
|
||||
_, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
|
||||
return out
|
||||
|
||||
def inner_loop(i, arr, out): # pylint: disable=missing-docstring
|
||||
@ -189,7 +188,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return (i, lax.add(j, 1), arr, out)
|
||||
|
||||
init_val = (i, 0, arr, out)
|
||||
_, _, _, out = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
|
||||
_, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
|
||||
return out
|
||||
|
||||
cloop = api.jit(outer_loop)
|
||||
@ -210,7 +209,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return (arr, num, lax.add(i, 1), lax.add(total, arr_i))
|
||||
|
||||
init_val = (arr, num, 0, 0.)
|
||||
_, _, _, total = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
|
||||
_, _, _, total = lax.while_loop(cond_fun, body_fun, init_val)
|
||||
return total
|
||||
|
||||
cfun = api.jit(sum_first_n)
|
||||
@ -226,7 +225,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def count(num):
|
||||
def body_fun(i, tot):
|
||||
return lax.add(tot, i)
|
||||
return lax_control_flow.fori_loop(0, num, body_fun, 0)
|
||||
return lax.fori_loop(0, num, body_fun, 0)
|
||||
|
||||
cfun = api.jit(count)
|
||||
|
||||
@ -241,7 +240,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def count(num):
|
||||
def body_fun(i, tot):
|
||||
return lax.add(num, lax.add(tot, i))
|
||||
return lax_control_flow.fori_loop(0, num, body_fun, 0)
|
||||
return lax.fori_loop(0, num, body_fun, 0)
|
||||
|
||||
cfun = api.jit(count)
|
||||
|
||||
@ -260,7 +259,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return (arr, lax.add(total, arr_i))
|
||||
|
||||
init_val = (arr, 0.)
|
||||
_, total = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
|
||||
_, total = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
|
||||
init_val)
|
||||
return total
|
||||
|
||||
@ -281,7 +280,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return {'arr': arr, 'total': lax.add(total, arr_i)}
|
||||
|
||||
init_val = {'arr': arr, 'total': 0.}
|
||||
out_val = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
|
||||
out_val = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
|
||||
return out_val['total']
|
||||
|
||||
cfun = api.jit(sum_first_n)
|
||||
@ -301,7 +300,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return (arr, lax.add(total, arr_i), ())
|
||||
|
||||
init_val = (arr, 0., ())
|
||||
_, tot, _ = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
|
||||
_, tot, _ = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
|
||||
return tot
|
||||
|
||||
cfun = api.jit(sum_first_n)
|
||||
@ -326,7 +325,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def false_fun(x):
|
||||
y = lax.mul(2, x)
|
||||
return y, lax.mul(2, y)
|
||||
return lax_control_flow.cond(lax.lt(x, 3), x, lambda x: (x, x), x, false_fun)
|
||||
return lax.cond(lax.lt(x, 3), x, lambda x: (x, x), x, false_fun)
|
||||
|
||||
self.assertEqual(fun(0), cfun(0))
|
||||
self.assertEqual(fun(0), (0, 0))
|
||||
@ -351,10 +350,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@api.jit
|
||||
def cfun(x):
|
||||
return lax_control_flow.cond(
|
||||
return lax.cond(
|
||||
lax.lt(x, 2),
|
||||
x, lambda x: lax.mul(2, x),
|
||||
x, lambda x: lax_control_flow.cond(lax.lt(x, 5),
|
||||
x, lambda x: lax.cond(lax.lt(x, 5),
|
||||
x, lambda x: lax.mul(3, x),
|
||||
4, lambda y: lax.mul(y, x)))
|
||||
|
||||
@ -374,7 +373,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@api.jit
|
||||
def cfun(x):
|
||||
return lax_control_flow.cond(lax.lt(x, 3), x, lambda x: 5, x, lambda x: x)
|
||||
return lax.cond(lax.lt(x, 3), x, lambda x: 5, x, lambda x: x)
|
||||
|
||||
self.assertEqual(fun(0), cfun(0))
|
||||
self.assertEqual(cfun(0), 5)
|
||||
@ -390,7 +389,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@api.jit
|
||||
def cfun(x):
|
||||
return lax_control_flow.cond(lax.lt(x, 3),
|
||||
return lax.cond(lax.lt(x, 3),
|
||||
x, lambda x: (1, 2., 3.),
|
||||
x, lambda x: (x, 2., 4.))
|
||||
|
||||
@ -401,7 +400,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
def testIssue514(self):
|
||||
# just check this doesn't crash
|
||||
lax_control_flow.cond(True,
|
||||
lax.cond(True,
|
||||
(0, 0), lambda x: (x[0], 0),
|
||||
(1, 1), lambda x: x)
|
||||
|
||||
|
@ -25,7 +25,6 @@ from absl.testing import parameterized
|
||||
import jax.numpy as np
|
||||
from jax import test_util as jtu
|
||||
from jax import lax
|
||||
from jax import lax_parallel
|
||||
from jax.api import _serial_pmap, _papply, jit, make_jaxpr
|
||||
from jax.linear_util import wrap_init
|
||||
|
||||
@ -42,40 +41,40 @@ class SerialPmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testReduceSum(self):
|
||||
f = lambda x: lax_parallel.psum(x, 'i')
|
||||
f = lambda x: lax.psum(x, 'i')
|
||||
ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
|
||||
expected = 4 * onp.ones(4)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testReduceMax(self):
|
||||
f = lambda x: lax_parallel.pmax(x, 'i')
|
||||
f = lambda x: lax.pmax(x, 'i')
|
||||
ans = _serial_pmap(f, axis_name='i')(onp.arange(4))
|
||||
expected = 3 * onp.ones(4)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testPsplit(self):
|
||||
f = lambda x: lax_parallel.psplit(x, 'i', 2)
|
||||
f = lambda x: lax.psplit(x, 'i', 2)
|
||||
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
|
||||
ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
|
||||
expected = arg
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testPsplitLike(self):
|
||||
f = lambda x, y: lax_parallel.psplit_like(x, y, 'i')
|
||||
f = lambda x, y: lax.psplit_like(x, y, 'i')
|
||||
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
|
||||
ans = _serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
|
||||
expected = arg
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testLogSoftmax(self):
|
||||
f = lambda x: x - np.log(lax_parallel.psum(np.exp(x), 'i'))
|
||||
f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
|
||||
x = onp.log(onp.arange(1., 10., dtype=onp.float32))
|
||||
ans = _serial_pmap(f, axis_name='i')(x)
|
||||
expected = x - onp.log(onp.sum(onp.exp(x)))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testNested(self):
|
||||
f = lambda x: lax_parallel.psum(lax_parallel.psum(x, 'i'), 'j')
|
||||
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
|
||||
x = onp.ones((2, 2))
|
||||
ans1 = _serial_pmap(_serial_pmap(f, 'i'), 'j')(x)
|
||||
ans2 = _serial_pmap(_serial_pmap(f, 'j'), 'i')(x)
|
||||
@ -103,7 +102,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
|
||||
jaxpr = make_jaxpr(pfun)(onp.ones(3))
|
||||
expected_jaxpr = make_jaxpr(
|
||||
lambda x: lax_parallel.psum(x, axis_name))(onp.zeros((5, 3)))
|
||||
lambda x: lax.psum(x, axis_name))(onp.zeros((5, 3)))
|
||||
assert repr(jaxpr) == repr(expected_jaxpr)
|
||||
|
||||
arg = onp.arange(15.).reshape((5, 3))
|
||||
@ -116,7 +115,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
|
||||
jaxpr = make_jaxpr(pfun)(onp.ones(3))
|
||||
expected_jaxpr = make_jaxpr(
|
||||
lambda x: lax_parallel.pmax(x, axis_name))(onp.zeros((5, 3)))
|
||||
lambda x: lax.pmax(x, axis_name))(onp.zeros((5, 3)))
|
||||
assert repr(jaxpr) == repr(expected_jaxpr)
|
||||
|
||||
arg = onp.arange(15.).reshape((5, 3))
|
||||
@ -135,9 +134,9 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
|
||||
def expected_spmd(p, t, f):
|
||||
return lax.select(
|
||||
lax_parallel.psplit_like(p, t, axis_name),
|
||||
lax.psplit_like(p, t, axis_name),
|
||||
t,
|
||||
lax_parallel.psplit_like(f, t, axis_name))
|
||||
lax.psplit_like(f, t, axis_name))
|
||||
|
||||
expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
|
||||
assert repr(jaxpr) == repr(expected_jaxpr)
|
||||
@ -156,7 +155,7 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
|
||||
jaxpr = make_jaxpr(pfun)(onp.zeros(5))
|
||||
expected_jaxpr = make_jaxpr(
|
||||
lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(onp.zeros(5))
|
||||
lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
|
||||
assert repr(jaxpr) == repr(expected_jaxpr)
|
||||
|
||||
ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
|
||||
|
@ -27,7 +27,6 @@ import jax.numpy as np
|
||||
from jax import test_util as jtu
|
||||
from jax import lax
|
||||
from jax.api import pmap, vmap, jvp, grad, make_jaxpr, linearize, device_put
|
||||
from jax.lax_parallel import psum
|
||||
from jax.lib import xla_bridge
|
||||
from jax.util import prod
|
||||
from jax.interpreters import pxla
|
||||
@ -54,7 +53,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
return device_mesh_shape
|
||||
|
||||
def testBasic(self):
|
||||
f = pmap(lambda x: x - psum(x, 'i'), axis_name='i')
|
||||
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
|
||||
|
||||
shape = (xla_bridge.device_count(), 4)
|
||||
x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
|
||||
@ -64,7 +63,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testNestedBasic(self):
|
||||
f = lambda x: psum(psum(x, 'i'), 'j')
|
||||
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
|
||||
f = pmap(pmap(f, 'i'), 'j')
|
||||
|
||||
def sum_and_broadcast(x, axis):
|
||||
@ -145,7 +144,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
def testTwoArgsGrad(self):
|
||||
def f(x, y):
|
||||
return psum(5. * np.cos(x) * np.sin(y), 'i')
|
||||
return lax.psum(5. * np.cos(x) * np.sin(y), 'i')
|
||||
f = pmap(f, 'i')
|
||||
|
||||
def g(x, y):
|
||||
@ -223,7 +222,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)
|
||||
|
||||
def testPsumMultiple(self):
|
||||
f = lambda x: psum(x, ('i', 'j'))
|
||||
f = lambda x: lax.psum(x, ('i', 'j'))
|
||||
f = pmap(pmap(f, 'i'), 'j')
|
||||
|
||||
def sum_and_broadcast(x, axis):
|
||||
|
Loading…
x
Reference in New Issue
Block a user