make a lax package, revert control flow names (#607)

c.f. #597
pair=skyewm
This commit is contained in:
Matthew Johnson 2019-04-12 16:28:40 -07:00 committed by Skye Wanderman-Milne
parent 4f0280fe36
commit 0cf14837c9
18 changed files with 132 additions and 117 deletions

View File

@ -307,10 +307,10 @@ Because `jit` aims to specialize Python functions only on shapes and dtypes
during tracing, rather than on concrete values, Python control flow that depends
on concrete values wont be able to execute and will instead raise an error. If
you want compiled control flow, use structured control flow primitives like
lax_control_flow.cond and lax_control_flow.while_loop. Some indexing features,
like slice-based indexing `A[i:i+5]` for argument-dependent `i`, or
boolean-based indexing `A[bool_ind]` for argument-dependent `bool_ind`, produce
abstract values of unknown shape and are thus unsupported in `jit` functions.
lax.cond and lax.while_loop. Some indexing features, like slice-based indexing
`A[i:i+5]` for argument-dependent `i`, or boolean-based indexing `A[bool_ind]`
for argument-dependent `bool_ind`, produce abstract values of unknown shape and
are thus unsupported in `jit` functions.
In general, JAX is intended to be used with a functional style of Python
programming. Functions passed to transformations like `grad` and `jit` are

View File

@ -29,7 +29,7 @@ import matplotlib.pyplot as plt
import jax.numpy as np
from jax.config import config
from jax import jit, grad, lax, lax_control_flow, random
from jax import jit, grad, lax, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, FanOut, Relu, Softplus
@ -117,7 +117,7 @@ if __name__ == "__main__":
loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
g = grad(loss)(optimizers.get_params(opt_state))
return opt_update(i, g, opt_state)
return lax_control_flow.fori_loop(0, num_batches, body_fun, opt_state)
return lax.fori_loop(0, num_batches, body_fun, opt_state)
@jit
def evaluate(opt_state, images):

22
jax/lax/__init__.py Normal file
View File

@ -0,0 +1,22 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .lax import *
from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
_reduce_and, _reduce_window_sum, _reduce_window_max,
_reduce_window_min, _reduce_window_prod, _float, _complex,
_input_dtype, _const, _eq_meet)
from .lax_control_flow import *
from .lax_parallel import *

View File

@ -35,27 +35,27 @@ from six.moves import builtins, xrange
import numpy as onp
from .util import partial, prod
from ..util import partial, prod
from . import core
from . import ad_util
from . import api
from . import linear_util as lu
from .config import flags
from .core import Primitive
from .abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
array_types, make_shaped_array)
from .api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
pytree_fun_to_flatjaxtuple_fun, pytree_to_flatjaxtuple)
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
from .util import curry, memoize, safe_zip, unzip2, prod
from .tree_util import build_tree, tree_unflatten
from .lib import xla_bridge
from .lib.xla_bridge import xla_client
from .. import core
from .. import ad_util
from .. import api
from .. import linear_util as lu
from ..config import flags
from ..core import Primitive
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
array_types, make_shaped_array)
from ..api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
pytree_fun_to_flatjaxtuple_fun, pytree_to_flatjaxtuple)
from ..interpreters import partial_eval as pe
from ..interpreters import xla
from ..interpreters import ad
from ..interpreters import batching
from ..interpreters import parallel
from ..util import curry, memoize, safe_zip, unzip2, prod
from ..tree_util import build_tree, tree_unflatten
from ..lib import xla_bridge
from ..lib.xla_bridge import xla_client
FLAGS = flags.FLAGS
@ -3795,7 +3795,7 @@ _one = partial(full_like, shape=(), fill_value=1)
_twos = partial(full_like, fill_value=2)
_two = partial(full_like, shape=(), fill_value=2)
_dtype = onp.result_type
_dtype = dtype = onp.result_type
_iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating)

View File

@ -23,7 +23,7 @@ import numpy as onp
from jax import api
from jax import core
from jax import lax
from jax.lax import lax
from jax import linear_util as lu
from jax.abstract_arrays import ConcreteArray, ShapedArray, UnshapedArray
from jax.api_util import (

View File

@ -15,7 +15,7 @@
Parallelization primitives.
"""
from jax import lax
from jax.lax import lax
from jax.abstract_arrays import ShapedArray
from jax.core import Primitive
from jax.interpreters import ad

View File

@ -21,7 +21,6 @@ import numpy as onp
from jax.numpy import lax_numpy as np
from jax import core
from jax import lax
from jax import lax_control_flow
from jax import ad_util
from jax.interpreters import xla
from jax.interpreters import ad
@ -300,7 +299,7 @@ def lu_jvp_rule(primals, tangents):
a_shape = np.shape(a)
m, n = a_shape[-2:]
dtype = lax._dtype(a)
dtype = lax.dtype(a)
k = min(m, n)
permutation = lu_pivots_to_permutation(pivots, m)
@ -377,7 +376,7 @@ def lu_pivots_to_permutation(swaps, k):
n, = np.shape(swaps)
permutation = np.arange(k)
_, permutation = lax_control_flow.fori_loop(
_, permutation = lax.fori_loop(
onp.array(0, onp.int32), onp.array(n, onp.int32), body_fn, (swaps, permutation))
return permutation

View File

@ -33,7 +33,6 @@ from .. import core
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
from ..interpreters.xla import DeviceArray
from .. import lax
from .. import lax_control_flow
from ..util import memoize, partial, get_module_functions, unzip2, prod as _prod
from ..lib import xla_bridge
from ..lib.xla_bridge import xla_client
@ -93,7 +92,7 @@ result_type = onp.result_type
shape = _shape = onp.shape
ndim = _ndim = onp.ndim
size = onp.size
_dtype = lax._dtype
_dtype = lax.dtype
bool_ = onp.bool_
uint8 = onp.uint8
@ -398,7 +397,7 @@ def power(x1, x2):
x1 = asarray(x1)
x2 = asarray(x2)
x1, x2 = _promote_args_like(onp.power, x1, x2)
dtype = lax._dtype(x1)
dtype = _dtype(x1)
if not issubdtype(dtype, integer):
return lax.pow(x1, x2)
@ -2165,8 +2164,8 @@ kaiser = onp.kaiser # TODO: lower via lax to allow non-constant beta.
@_wraps(getattr(onp, "gcd", None))
def gcd(x1, x2):
if (not issubdtype(lax._dtype(x1), integer) or
not issubdtype(lax._dtype(x2), integer)):
if (not issubdtype(_dtype(x1), integer) or
not issubdtype(_dtype(x2), integer)):
raise ValueError("Arguments to gcd must be integers.")
def cond_fn(xs):
x1, x2 = xs
@ -2178,7 +2177,7 @@ def gcd(x1, x2):
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))
x1, x2 = _promote_dtypes(lax.abs(x1), lax.abs(x2))
x1, x2 = broadcast_arrays(x1, x2)
gcd, _ = lax_control_flow.while_loop(cond_fn, body_fn, (x1, x2))
gcd, _ = lax.while_loop(cond_fn, body_fn, (x1, x2))
return gcd

View File

@ -63,7 +63,7 @@ def svd(a, full_matrices=True, compute_uv=True):
@_wraps(onp.linalg.slogdet)
def slogdet(a):
a = _promote_arg_dtypes(np.asarray(a))
dtype = lax._dtype(a)
dtype = lax.dtype(a)
a_shape = np.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
msg = "Argument to slogdet() must have shape [..., n, n], got {}"
@ -140,7 +140,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
elif ord == -np.inf:
return np.amin(np.abs(x), axis=axis, keepdims=keepdims)
elif ord == 0:
return np.sum(x != 0, dtype=np.finfo(lax._dtype(x)).dtype,
return np.sum(x != 0, dtype=np.finfo(lax.dtype(x)).dtype,
axis=axis, keepdims=keepdims)
elif ord == 1:
# Numpy has a special case for ord == 1 as an optimization. We don't
@ -228,7 +228,7 @@ def solve(a, b):
"b=[..., m, k] or b=[..., m]; got a={} and b={}")
raise ValueError(msg.format(a_shape, b_shape))
lu, pivots = lax_linalg.lu(a)
dtype = lax._dtype(a)
dtype = lax.dtype(a)
# TODO(phawkins): add unit_diagonal support to solve_triangular, use it here
# instead of explicit masking of l.

View File

@ -43,7 +43,7 @@ def _scatter_update(x, idx, y, scatter_op):
y = np.asarray(y)
x_shape = np.shape(x)
y_shape = np.shape(y)
y = lax.convert_element_type(y, lax._dtype(x))
y = lax.convert_element_type(y, lax.dtype(x))
if not isinstance(idx, tuple):
idx = (idx,)

View File

@ -29,7 +29,6 @@ from functools import partial
import numpy as onp
from . import lax
from . import lax_control_flow
from . import numpy as np
from . import tree_util
from .api import jit, vmap
@ -77,7 +76,7 @@ def _make_rotate_left(dtype):
nbits = onp.array(onp.iinfo(dtype).bits, dtype)
def _rotate_left(x, d):
if lax._dtype(d) != lax._dtype(x):
if lax.dtype(d) != lax.dtype(x):
d = lax.convert_element_type(d, x.dtype)
return (x << d) | lax.shift_right_logical(x, nbits - d)
return _rotate_left
@ -104,11 +103,11 @@ def threefry_2x32(keypair, count):
"""
# Based on ThreeFry2x32 by phawkins@ in //.../xla/client/lib/prng.cc
key1, key2 = keypair
if not lax._dtype(key1) == lax._dtype(key2) == lax._dtype(count) == onp.uint32:
if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
msg = "threefry_2x32 requires uint32 arguments, got {}"
raise TypeError(msg.format([lax._dtype(x) for x in [key1, key2, count]]))
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
rotate_left = _make_rotate_left(lax._dtype(count))
rotate_left = _make_rotate_left(lax.dtype(count))
def apply_round(v, rot):
v = v[:]
@ -255,7 +254,7 @@ def _uniform(key, shape, dtype, minval, maxval):
# bit-level transformation we use relies on Numpy and XLA having bit-for-bit
# equivalent float representations, which might not be true on all platforms.
float_bits = lax.bitwise_or(
lax.shift_right_logical(bits, onp.array(nbits - nmant, lax._dtype(bits))),
lax.shift_right_logical(bits, onp.array(nbits - nmant, lax.dtype(bits))),
onp.array(1., dtype).view(onp.uint32 if nbits == 32 else onp.uint64))
floats = lax.bitcast_convert_type(float_bits, dtype) - onp.array(1., dtype)
return lax.max(
@ -399,7 +398,7 @@ def bernoulli(key, mean=onp.float32(0.5), shape=()):
@partial(jit, static_argnums=(2,))
def _bernoulli(key, mean, shape):
shape = shape or onp.shape(mean)
if not onp.issubdtype(lax._dtype(mean), onp.float32):
if not onp.issubdtype(lax.dtype(mean), onp.float32):
mean = lax.convert_element_type(mean, onp.float32)
if onp.shape(mean) != shape:
mean = np.broadcast_to(mean, shape)
@ -457,7 +456,7 @@ def _gamma_one(key, alpha):
one_over_two = _constant_like(alpha, 0.5)
one_over_three = _constant_like(alpha, 1. / 3.)
squeeze_const = _constant_like(alpha, 0.0331)
dtype = lax._dtype(alpha)
dtype = lax.dtype(alpha)
key, subkey = split(key)
# for alpha < 1, we boost alpha to alpha + 1 and get a sample according to
@ -472,7 +471,7 @@ def _gamma_one(key, alpha):
def _cond_fn(kXVU):
_, X, V, U = kXVU
# TODO: use lax_control_flow.cond when its batching rule is supported
# TODO: use lax.cond when its batching rule is supported
# The reason is to avoid evaluating second condition which involves log+log
# if the first condition is satisfied
cond = lax.bitwise_and(lax.ge(U, lax.sub(one, lax.mul(squeeze_const, lax.mul(X, X)))),
@ -492,7 +491,7 @@ def _gamma_one(key, alpha):
return key, X, V, U
# initial state is chosen such that _cond_fn will return True
_, _, V, _ = lax_control_flow.while_loop(
_, _, V, _ = lax.while_loop(
_cond_fn, _body_fn, (key, zero, _constant_like(alpha, -1), zero))
z = lax.mul(lax.mul(d, V), boost)
return lax.select(lax.eq(z, zero), onp.finfo(z.dtype).tiny, z)

View File

@ -129,7 +129,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
del overwrite_a, check_finite
a = np_linalg._promote_arg_dtypes(np.asarray(a))
lu, pivots = lax_linalg.lu(a)
dtype = lax._dtype(a)
dtype = lax.dtype(a)
m, n = np.shape(a)
permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
p = np.real(np.array(permutation == np.arange(m)[:, None], dtype=dtype))

View File

@ -165,7 +165,7 @@ def ndtr(x):
TypeError: if `x` is not floating-type.
"""
x = np.asarray(x)
dtype = lax._dtype(x)
dtype = lax.dtype(x)
if dtype not in (np.float32, np.float64):
raise TypeError(
"x.dtype={} is not supported, see docstring for supported types."
@ -175,7 +175,7 @@ def ndtr(x):
def _ndtr(x):
"""Implements ndtr core logic."""
dtype = lax._dtype(x).type
dtype = lax.dtype(x).type
half_sqrt_2 = dtype(0.5) * onp.sqrt(2., dtype=dtype)
w = x * half_sqrt_2
z = lax.abs(w)
@ -206,7 +206,7 @@ def ndtri(p):
TypeError: if `p` is not floating-type.
"""
x = np.asarray(p)
dtype = lax._dtype(p)
dtype = lax.dtype(p)
if dtype not in (np.float32, np.float64):
raise TypeError(
"x.dtype={} is not supported, see docstring for supported types."
@ -271,7 +271,7 @@ def _ndtri(p):
2.89247864745380683936E-6,
6.79019408009981274425E-9]))
dtype = lax._dtype(p).type
dtype = lax.dtype(p).type
shape = np.shape(p)
def _create_polynomial(var, coeffs):
@ -391,7 +391,7 @@ def log_ndtr(x, series_order=3):
raise ValueError("series_order must be <= 30.")
x = np.asarray(x)
dtype = lax._dtype(x)
dtype = lax.dtype(x)
if dtype == np.float64:
lower_segment = _LOGNDTR_FLOAT64_LOWER
@ -427,7 +427,7 @@ def log_ndtr(x, series_order=3):
def _log_ndtr_lower(x, series_order):
"""Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
dtype = lax._dtype(x).type
dtype = lax.dtype(x).type
x_2 = lax.square(x)
# Log of the term multiplying (1 + sum)
log_scale = -dtype(0.5) * x_2 - lax.log(-x) - dtype(0.5 * onp.log(2. * onp.pi))
@ -436,7 +436,7 @@ def _log_ndtr_lower(x, series_order):
def _log_ndtr_asymptotic_series(x, series_order):
"""Calculates the asymptotic series used in log_ndtr."""
dtype = lax._dtype(x).type
dtype = lax.dtype(x).type
if series_order <= 0:
return onp.array(1, dtype)
x_2 = lax.square(x)

View File

@ -233,7 +233,7 @@
"\n",
"__NB__: _Fancy Indexing_ is __not__ yet supported, but will likely be added to JAX soon.\n",
"\n",
"️⚠️ inside `jit`'d code and `lax_control_flow.while_loop` or `lax_control_flow.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation."
"️⚠️ inside `jit`'d code and `lax.while_loop` or `lax.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation."
]
},
{
@ -1166,10 +1166,10 @@
"## Structured control flow primitives\n",
"\n",
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. then you can use these 4 structured control flow primitives:\n",
" - `lax_control_flow.cond` _will be differentiable soon_\n",
" - `lax_control_flow.while_loop` __non-differentiable__*\n",
" - `lax_control_flow.fori_loop` __non-differentiable__*\n",
" - `lax_control_flow.scan` _will be differentiable soon_\n",
" - `lax.cond` _will be differentiable soon_\n",
" - `lax.while_loop` __non-differentiable__*\n",
" - `lax.fori_loop` __non-differentiable__*\n",
" - `lax.scan` _will be differentiable soon_\n",
"\n",
"*_these can in principle be made to be __forward__-differentiable, but this isn't on the current roadmap._"
]
@ -1208,9 +1208,9 @@
"from jax import lax\n",
"\n",
"operand = np.array([0.])\n",
"lax_control_flow.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"# --> array([1.], dtype=float32)\n",
"lax_control_flow.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"# --> array([-1.], dtype=float32)"
],
"execution_count": 0,
@ -1263,7 +1263,7 @@
"init_val = 0\n",
"cond_fun = lambda x: x<10\n",
"body_fun = lambda x: x+1\n",
"lax_control_flow.while_loop(cond_fun, body_fun, init_val)\n",
"lax.while_loop(cond_fun, body_fun, init_val)\n",
"# --> array(10, dtype=int32)"
],
"execution_count": 0,
@ -1316,7 +1316,7 @@
"start = 0\n",
"stop = 10\n",
"body_fun = lambda i,x: x+i\n",
"lax_control_flow.fori_loop(start, stop, body_fun, init_val)\n",
"lax.fori_loop(start, stop, body_fun, init_val)\n",
"# --> array(45, dtype=int32)"
],
"execution_count": 0,
@ -1354,10 +1354,10 @@
"\\textrm{if} & ❌ & ✔ \\\\\n",
"\\textrm{for} & ✔* & ✔\\\\\n",
"\\textrm{while} & ✔* & ✔\\\\\n",
"\\textrm{lax_control_flow.cond} & ✔ & \\textrm{soon!}\\\\\n",
"\\textrm{lax_control_flow.while_loop} & ✔ & ❌\\\\\n",
"\\textrm{lax_control_flow.fori_loop} & ✔ & ❌\\\\\n",
"\\textrm{lax_control_flow.scan} & \\textrm{soon!} & \\textrm{soon!}\\\\\n",
"\\textrm{lax.cond} & ✔ & \\textrm{soon!}\\\\\n",
"\\textrm{lax.while_loop} & ✔ & ❌\\\\\n",
"\\textrm{lax.fori_loop} & ✔ & ❌\\\\\n",
"\\textrm{lax.scan} & \\textrm{soon!} & \\textrm{soon!}\\\\\n",
"\\hline\n",
"\\end{array}\n",
"$$\n",

View File

@ -24,7 +24,6 @@ import jax.numpy as np
from jax import test_util as jtu
from jax.abstract_arrays import ShapedArray
from jax import lax
from jax import lax_control_flow
from jax import lax_linalg
from jax import random
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr, jacfwd, jacrev, hessian
@ -875,7 +874,7 @@ class BatchingTest(jtu.JaxTestCase):
def testWhileLoop(self):
def fun(x):
return lax_control_flow.while_loop(lambda x: x < 3, lambda x: x + 2, x)
return lax.while_loop(lambda x: x < 3, lambda x: x + 2, x)
ans = vmap(fun)(onp.array([0, 1, 2, 3]))
expected = onp.array([4, 3, 4, 3])
@ -888,7 +887,7 @@ class BatchingTest(jtu.JaxTestCase):
def testWhileLoopCondConstsBatched(self):
def fun(x, y):
return lax_control_flow.while_loop(lambda x: x < y, lambda x: x + 2, x)
return lax.while_loop(lambda x: x < y, lambda x: x + 2, x)
ans = vmap(fun, in_axes=(None, 0))(0, onp.array([2, 3]))
expected = onp.array([2, 4])
@ -896,7 +895,7 @@ class BatchingTest(jtu.JaxTestCase):
def testWhileLoopBodyConstsBatched(self):
def fun(x, y):
return lax_control_flow.while_loop(lambda x: x < 3, lambda x: x + y, x)
return lax.while_loop(lambda x: x < 3, lambda x: x + y, x)
ans = vmap(fun, in_axes=(None, 0))(0, onp.array([2, 3]))
expected = onp.array([4, 3])
@ -913,7 +912,7 @@ class BatchingTest(jtu.JaxTestCase):
return x, y
def fun(x, y):
return lax_control_flow.while_loop(cond_fun, body_fun, (x, y))
return lax.while_loop(cond_fun, body_fun, (x, y))
ans = vmap(fun)(onp.array([0, 0]), onp.array([1, 2]))
expected = (onp.array([4, 3]), onp.array([1, 2]))
@ -927,7 +926,7 @@ class BatchingTest(jtu.JaxTestCase):
return x, y
def fun(x):
return lax_control_flow.fori_loop(0, 10, body_fun, (x, 0))
return lax.fori_loop(0, 10, body_fun, (x, 0))
ans = vmap(fun)(onp.array([0, 1]))
expected = (onp.array([10, 11]), onp.array([20, 20]))
@ -941,7 +940,7 @@ class BatchingTest(jtu.JaxTestCase):
key, _ = random.split(key)
return u, key
u, _ = lax_control_flow.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
return u
print(vmap(f)(random.split(random.PRNGKey(0), 2))) # no crash

View File

@ -25,7 +25,6 @@ import numpy.random as npr
from jax import api
from jax import lax
from jax import lax_control_flow
from jax import test_util as jtu
@ -43,7 +42,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return (lax.add(pos, 1), lax.add(count, 1))
def loop(init):
result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
result = lax.while_loop(loop_cond, loop_body, (init, 0))
_, count = result
return count
@ -66,7 +65,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return (num, lax.add(i, 1), inner_loop(i, count))
init_val = (num, 0, 0)
_, i, count = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
_, i, count = lax.while_loop(cond_fun, body_fun, init_val)
return (i, count)
def inner_loop(i, count): # pylint: disable=missing-docstring
@ -79,7 +78,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return (i, lax.add(j, 1), lax.add(count, 1))
init_val = (i, 0, count)
_, _, count = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
_, _, count = lax.while_loop(cond_fun, body_fun, init_val)
return count
cloop = api.jit(outer_loop)
@ -103,7 +102,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
pos, count = state
return (lax.add(pos, 1), lax.add(count, inc))
result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
result = lax.while_loop(loop_cond, loop_body, (init, 0))
_, count = result
return count
@ -135,7 +134,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
f = lambda pos, inc: (lax.add(pos, 1), lax.add(count, inc))
return api.jit(f)(pos, inc)
result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
result = lax.while_loop(loop_cond, loop_body, (init, 0))
_, count = result
return count
@ -172,7 +171,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
out = onp.zeros(arr.shape, dtype=arr.dtype)
init_val = (0, num, arr, out)
_, _, _, out = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
_, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
return out
def inner_loop(i, arr, out): # pylint: disable=missing-docstring
@ -189,7 +188,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return (i, lax.add(j, 1), arr, out)
init_val = (i, 0, arr, out)
_, _, _, out = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
_, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
return out
cloop = api.jit(outer_loop)
@ -210,7 +209,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return (arr, num, lax.add(i, 1), lax.add(total, arr_i))
init_val = (arr, num, 0, 0.)
_, _, _, total = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
_, _, _, total = lax.while_loop(cond_fun, body_fun, init_val)
return total
cfun = api.jit(sum_first_n)
@ -226,7 +225,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def count(num):
def body_fun(i, tot):
return lax.add(tot, i)
return lax_control_flow.fori_loop(0, num, body_fun, 0)
return lax.fori_loop(0, num, body_fun, 0)
cfun = api.jit(count)
@ -241,7 +240,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def count(num):
def body_fun(i, tot):
return lax.add(num, lax.add(tot, i))
return lax_control_flow.fori_loop(0, num, body_fun, 0)
return lax.fori_loop(0, num, body_fun, 0)
cfun = api.jit(count)
@ -260,7 +259,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return (arr, lax.add(total, arr_i))
init_val = (arr, 0.)
_, total = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
_, total = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
init_val)
return total
@ -281,7 +280,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return {'arr': arr, 'total': lax.add(total, arr_i)}
init_val = {'arr': arr, 'total': 0.}
out_val = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
out_val = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
return out_val['total']
cfun = api.jit(sum_first_n)
@ -301,7 +300,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return (arr, lax.add(total, arr_i), ())
init_val = (arr, 0., ())
_, tot, _ = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
_, tot, _ = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
return tot
cfun = api.jit(sum_first_n)
@ -326,7 +325,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def false_fun(x):
y = lax.mul(2, x)
return y, lax.mul(2, y)
return lax_control_flow.cond(lax.lt(x, 3), x, lambda x: (x, x), x, false_fun)
return lax.cond(lax.lt(x, 3), x, lambda x: (x, x), x, false_fun)
self.assertEqual(fun(0), cfun(0))
self.assertEqual(fun(0), (0, 0))
@ -351,10 +350,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
@api.jit
def cfun(x):
return lax_control_flow.cond(
return lax.cond(
lax.lt(x, 2),
x, lambda x: lax.mul(2, x),
x, lambda x: lax_control_flow.cond(lax.lt(x, 5),
x, lambda x: lax.cond(lax.lt(x, 5),
x, lambda x: lax.mul(3, x),
4, lambda y: lax.mul(y, x)))
@ -374,7 +373,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
@api.jit
def cfun(x):
return lax_control_flow.cond(lax.lt(x, 3), x, lambda x: 5, x, lambda x: x)
return lax.cond(lax.lt(x, 3), x, lambda x: 5, x, lambda x: x)
self.assertEqual(fun(0), cfun(0))
self.assertEqual(cfun(0), 5)
@ -390,7 +389,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
@api.jit
def cfun(x):
return lax_control_flow.cond(lax.lt(x, 3),
return lax.cond(lax.lt(x, 3),
x, lambda x: (1, 2., 3.),
x, lambda x: (x, 2., 4.))
@ -401,7 +400,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def testIssue514(self):
# just check this doesn't crash
lax_control_flow.cond(True,
lax.cond(True,
(0, 0), lambda x: (x[0], 0),
(1, 1), lambda x: x)

View File

@ -25,7 +25,6 @@ from absl.testing import parameterized
import jax.numpy as np
from jax import test_util as jtu
from jax import lax
from jax import lax_parallel
from jax.api import _serial_pmap, _papply, jit, make_jaxpr
from jax.linear_util import wrap_init
@ -42,40 +41,40 @@ class SerialPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceSum(self):
f = lambda x: lax_parallel.psum(x, 'i')
f = lambda x: lax.psum(x, 'i')
ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
expected = 4 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceMax(self):
f = lambda x: lax_parallel.pmax(x, 'i')
f = lambda x: lax.pmax(x, 'i')
ans = _serial_pmap(f, axis_name='i')(onp.arange(4))
expected = 3 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsplit(self):
f = lambda x: lax_parallel.psplit(x, 'i', 2)
f = lambda x: lax.psplit(x, 'i', 2)
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
expected = arg
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsplitLike(self):
f = lambda x, y: lax_parallel.psplit_like(x, y, 'i')
f = lambda x, y: lax.psplit_like(x, y, 'i')
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
ans = _serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
expected = arg
self.assertAllClose(ans, expected, check_dtypes=False)
def testLogSoftmax(self):
f = lambda x: x - np.log(lax_parallel.psum(np.exp(x), 'i'))
f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
x = onp.log(onp.arange(1., 10., dtype=onp.float32))
ans = _serial_pmap(f, axis_name='i')(x)
expected = x - onp.log(onp.sum(onp.exp(x)))
self.assertAllClose(ans, expected, check_dtypes=False)
def testNested(self):
f = lambda x: lax_parallel.psum(lax_parallel.psum(x, 'i'), 'j')
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
x = onp.ones((2, 2))
ans1 = _serial_pmap(_serial_pmap(f, 'i'), 'j')(x)
ans2 = _serial_pmap(_serial_pmap(f, 'j'), 'i')(x)
@ -103,7 +102,7 @@ class PapplyTest(jtu.JaxTestCase):
jaxpr = make_jaxpr(pfun)(onp.ones(3))
expected_jaxpr = make_jaxpr(
lambda x: lax_parallel.psum(x, axis_name))(onp.zeros((5, 3)))
lambda x: lax.psum(x, axis_name))(onp.zeros((5, 3)))
assert repr(jaxpr) == repr(expected_jaxpr)
arg = onp.arange(15.).reshape((5, 3))
@ -116,7 +115,7 @@ class PapplyTest(jtu.JaxTestCase):
jaxpr = make_jaxpr(pfun)(onp.ones(3))
expected_jaxpr = make_jaxpr(
lambda x: lax_parallel.pmax(x, axis_name))(onp.zeros((5, 3)))
lambda x: lax.pmax(x, axis_name))(onp.zeros((5, 3)))
assert repr(jaxpr) == repr(expected_jaxpr)
arg = onp.arange(15.).reshape((5, 3))
@ -135,9 +134,9 @@ class PapplyTest(jtu.JaxTestCase):
def expected_spmd(p, t, f):
return lax.select(
lax_parallel.psplit_like(p, t, axis_name),
lax.psplit_like(p, t, axis_name),
t,
lax_parallel.psplit_like(f, t, axis_name))
lax.psplit_like(f, t, axis_name))
expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
assert repr(jaxpr) == repr(expected_jaxpr)
@ -156,7 +155,7 @@ class PapplyTest(jtu.JaxTestCase):
jaxpr = make_jaxpr(pfun)(onp.zeros(5))
expected_jaxpr = make_jaxpr(
lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(onp.zeros(5))
lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
assert repr(jaxpr) == repr(expected_jaxpr)
ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))

View File

@ -27,7 +27,6 @@ import jax.numpy as np
from jax import test_util as jtu
from jax import lax
from jax.api import pmap, vmap, jvp, grad, make_jaxpr, linearize, device_put
from jax.lax_parallel import psum
from jax.lib import xla_bridge
from jax.util import prod
from jax.interpreters import pxla
@ -54,7 +53,7 @@ class PmapTest(jtu.JaxTestCase):
return device_mesh_shape
def testBasic(self):
f = pmap(lambda x: x - psum(x, 'i'), axis_name='i')
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
@ -64,7 +63,7 @@ class PmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testNestedBasic(self):
f = lambda x: psum(psum(x, 'i'), 'j')
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
f = pmap(pmap(f, 'i'), 'j')
def sum_and_broadcast(x, axis):
@ -145,7 +144,7 @@ class PmapTest(jtu.JaxTestCase):
def testTwoArgsGrad(self):
def f(x, y):
return psum(5. * np.cos(x) * np.sin(y), 'i')
return lax.psum(5. * np.cos(x) * np.sin(y), 'i')
f = pmap(f, 'i')
def g(x, y):
@ -223,7 +222,7 @@ class PmapTest(jtu.JaxTestCase):
self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)
def testPsumMultiple(self):
f = lambda x: psum(x, ('i', 'j'))
f = lambda x: lax.psum(x, ('i', 'j'))
f = pmap(pmap(f, 'i'), 'j')
def sum_and_broadcast(x, axis):