diff --git a/README.md b/README.md
index 7ea5003ff..a7cbd5e23 100644
--- a/README.md
+++ b/README.md
@@ -307,10 +307,10 @@ Because `jit` aims to specialize Python functions only on shapes and dtypes
 during tracing, rather than on concrete values, Python control flow that
 depends on concrete values won’t be able to execute and will instead raise an
 error. If you want compiled control flow, use structured control flow primitives like
-lax_control_flow.cond and lax_control_flow.while_loop. Some indexing features,
-like slice-based indexing `A[i:i+5]` for argument-dependent `i`, or
-boolean-based indexing `A[bool_ind]` for argument-dependent `bool_ind`, produce
-abstract values of unknown shape and are thus unsupported in `jit` functions.
+lax.cond and lax.while_loop. Some indexing features, like slice-based indexing
+`A[i:i+5]` for argument-dependent `i`, or boolean-based indexing `A[bool_ind]`
+for argument-dependent `bool_ind`, produce abstract values of unknown shape and
+are thus unsupported in `jit` functions.
 
 In general, JAX is intended to be used with a functional style of Python
 programming. Functions passed to transformations like `grad` and `jit` are
diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py
index a8b9315de..ca7b8e511 100644
--- a/examples/mnist_vae.py
+++ b/examples/mnist_vae.py
@@ -29,7 +29,7 @@ import matplotlib.pyplot as plt
 
 import jax.numpy as np
 from jax.config import config
-from jax import jit, grad, lax, lax_control_flow, random
+from jax import jit, grad, lax, random
 from jax.experimental import optimizers
 from jax.experimental import stax
 from jax.experimental.stax import Dense, FanOut, Relu, Softplus
@@ -117,7 +117,7 @@ if __name__ == "__main__":
       loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
       g = grad(loss)(optimizers.get_params(opt_state))
       return opt_update(i, g, opt_state)
-    return lax_control_flow.fori_loop(0, num_batches, body_fun, opt_state)
+    return lax.fori_loop(0, num_batches, body_fun, opt_state)
 
   @jit
   def evaluate(opt_state, images):
diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py
new file mode 100644
index 000000000..bbf0ad7c5
--- /dev/null
+++ b/jax/lax/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from .lax import *
+from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
+                  _reduce_and, _reduce_window_sum, _reduce_window_max,
+                  _reduce_window_min, _reduce_window_prod, _float, _complex,
+                  _input_dtype, _const, _eq_meet)
+from .lax_control_flow import *
+from .lax_parallel import *
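Note on the new package above: `jax/lax/__init__.py` re-exports the public names of all three submodules (plus a handful of private helpers that other modules rely on), so numeric primitives, structured control flow, and the parallel collectives now share one flat `jax.lax` namespace. A minimal smoke test of that namespace, assuming a build that includes this patch:

```python
# Sketch only: checks that control flow, formerly jax.lax_control_flow,
# is now reachable directly from jax.lax.
from jax import lax

total = lax.fori_loop(0, 10, lambda i, tot: tot + i, 0)
print(total)          # --> 45, the sum 0 + 1 + ... + 9
print(lax.add(1, 2))  # numeric primitives are re-exported unchanged
```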
diff --git a/jax/lax.py b/jax/lax/lax.py
similarity index 99%
rename from jax/lax.py
rename to jax/lax/lax.py
index 2c1bd56b0..885f748b0 100644
--- a/jax/lax.py
+++ b/jax/lax/lax.py
@@ -35,27 +35,27 @@ from six.moves import builtins, xrange
 
 import numpy as onp
 
-from .util import partial, prod
+from ..util import partial, prod
 
-from . import core
-from . import ad_util
-from . import api
-from . import linear_util as lu
-from .config import flags
-from .core import Primitive
-from .abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
-                              array_types, make_shaped_array)
-from .api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
-                       pytree_fun_to_flatjaxtuple_fun, pytree_to_flatjaxtuple)
-from .interpreters import partial_eval as pe
-from .interpreters import xla
-from .interpreters import ad
-from .interpreters import batching
-from .interpreters import parallel
-from .util import curry, memoize, safe_zip, unzip2, prod
-from .tree_util import build_tree, tree_unflatten
-from .lib import xla_bridge
-from .lib.xla_bridge import xla_client
+from .. import core
+from .. import ad_util
+from .. import api
+from .. import linear_util as lu
+from ..config import flags
+from ..core import Primitive
+from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
+                               array_types, make_shaped_array)
+from ..api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
+                        pytree_fun_to_flatjaxtuple_fun, pytree_to_flatjaxtuple)
+from ..interpreters import partial_eval as pe
+from ..interpreters import xla
+from ..interpreters import ad
+from ..interpreters import batching
+from ..interpreters import parallel
+from ..util import curry, memoize, safe_zip, unzip2, prod
+from ..tree_util import build_tree, tree_unflatten
+from ..lib import xla_bridge
+from ..lib.xla_bridge import xla_client
 
 FLAGS = flags.FLAGS
@@ -3795,7 +3795,7 @@ _one = partial(full_like, shape=(), fill_value=1)
 _twos = partial(full_like, fill_value=2)
 _two = partial(full_like, shape=(), fill_value=2)
 
-_dtype = onp.result_type
+_dtype = dtype = onp.result_type
 
 _iscomplex = lambda x: onp.issubdtype(_dtype(x), onp.complexfloating)
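Besides the package move, the hunk at line 3795 makes the dtype helper public: `_dtype = dtype = onp.result_type` keeps the private alias for in-module callers while exposing `lax.dtype`, which every remaining file in this patch switches to in place of `lax._dtype`. A small sketch of what the helper is (nothing more than an alias for `numpy.result_type`):

```python
import numpy as onp
from jax import lax

print(lax.dtype(1.))            # --> float64, via onp.result_type
print(lax.dtype(onp.int32(3)))  # --> int32
assert lax.dtype is onp.result_type  # the alias introduced above
```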
""" -from jax import lax +from jax.lax import lax from jax.abstract_arrays import ShapedArray from jax.core import Primitive from jax.interpreters import ad diff --git a/jax/lax_linalg.py b/jax/lax_linalg.py index c4ccf86d4..75bde6362 100644 --- a/jax/lax_linalg.py +++ b/jax/lax_linalg.py @@ -21,7 +21,6 @@ import numpy as onp from jax.numpy import lax_numpy as np from jax import core from jax import lax -from jax import lax_control_flow from jax import ad_util from jax.interpreters import xla from jax.interpreters import ad @@ -300,7 +299,7 @@ def lu_jvp_rule(primals, tangents): a_shape = np.shape(a) m, n = a_shape[-2:] - dtype = lax._dtype(a) + dtype = lax.dtype(a) k = min(m, n) permutation = lu_pivots_to_permutation(pivots, m) @@ -377,7 +376,7 @@ def lu_pivots_to_permutation(swaps, k): n, = np.shape(swaps) permutation = np.arange(k) - _, permutation = lax_control_flow.fori_loop( + _, permutation = lax.fori_loop( onp.array(0, onp.int32), onp.array(n, onp.int32), body_fn, (swaps, permutation)) return permutation diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index db7d811a9..593602556 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -33,7 +33,6 @@ from .. import core from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray from ..interpreters.xla import DeviceArray from .. import lax -from .. import lax_control_flow from ..util import memoize, partial, get_module_functions, unzip2, prod as _prod from ..lib import xla_bridge from ..lib.xla_bridge import xla_client @@ -93,7 +92,7 @@ result_type = onp.result_type shape = _shape = onp.shape ndim = _ndim = onp.ndim size = onp.size -_dtype = lax._dtype +_dtype = lax.dtype bool_ = onp.bool_ uint8 = onp.uint8 @@ -398,7 +397,7 @@ def power(x1, x2): x1 = asarray(x1) x2 = asarray(x2) x1, x2 = _promote_args_like(onp.power, x1, x2) - dtype = lax._dtype(x1) + dtype = _dtype(x1) if not issubdtype(dtype, integer): return lax.pow(x1, x2) @@ -2165,8 +2164,8 @@ kaiser = onp.kaiser # TODO: lower via lax to allow non-constant beta. @_wraps(getattr(onp, "gcd", None)) def gcd(x1, x2): - if (not issubdtype(lax._dtype(x1), integer) or - not issubdtype(lax._dtype(x2), integer)): + if (not issubdtype(_dtype(x1), integer) or + not issubdtype(_dtype(x2), integer)): raise ValueError("Arguments to gcd must be integers.") def cond_fn(xs): x1, x2 = xs @@ -2178,7 +2177,7 @@ def gcd(x1, x2): return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) x1, x2 = _promote_dtypes(lax.abs(x1), lax.abs(x2)) x1, x2 = broadcast_arrays(x1, x2) - gcd, _ = lax_control_flow.while_loop(cond_fn, body_fn, (x1, x2)) + gcd, _ = lax.while_loop(cond_fn, body_fn, (x1, x2)) return gcd diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index 1b08b586f..f0bfd772d 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -63,7 +63,7 @@ def svd(a, full_matrices=True, compute_uv=True): @_wraps(onp.linalg.slogdet) def slogdet(a): a = _promote_arg_dtypes(np.asarray(a)) - dtype = lax._dtype(a) + dtype = lax.dtype(a) a_shape = np.shape(a) if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: msg = "Argument to slogdet() must have shape [..., n, n], got {}" @@ -140,7 +140,7 @@ def norm(x, ord=None, axis=None, keepdims=False): elif ord == -np.inf: return np.amin(np.abs(x), axis=axis, keepdims=keepdims) elif ord == 0: - return np.sum(x != 0, dtype=np.finfo(lax._dtype(x)).dtype, + return np.sum(x != 0, dtype=np.finfo(lax.dtype(x)).dtype, axis=axis, keepdims=keepdims) elif ord == 1: # Numpy has a special case for ord == 1 as an optimization. 
diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py
index 1b08b586f..f0bfd772d 100644
--- a/jax/numpy/linalg.py
+++ b/jax/numpy/linalg.py
@@ -63,7 +63,7 @@ def svd(a, full_matrices=True, compute_uv=True):
 @_wraps(onp.linalg.slogdet)
 def slogdet(a):
   a = _promote_arg_dtypes(np.asarray(a))
-  dtype = lax._dtype(a)
+  dtype = lax.dtype(a)
   a_shape = np.shape(a)
   if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
     msg = "Argument to slogdet() must have shape [..., n, n], got {}"
@@ -140,7 +140,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
     elif ord == -np.inf:
       return np.amin(np.abs(x), axis=axis, keepdims=keepdims)
     elif ord == 0:
-      return np.sum(x != 0, dtype=np.finfo(lax._dtype(x)).dtype,
+      return np.sum(x != 0, dtype=np.finfo(lax.dtype(x)).dtype,
                     axis=axis, keepdims=keepdims)
     elif ord == 1:
       # Numpy has a special case for ord == 1 as an optimization. We don't
@@ -228,7 +228,7 @@ def solve(a, b):
            "b=[..., m, k] or b=[..., m]; got a={} and b={}")
     raise ValueError(msg.format(a_shape, b_shape))
   lu, pivots = lax_linalg.lu(a)
-  dtype = lax._dtype(a)
+  dtype = lax.dtype(a)
 
   # TODO(phawkins): add unit_diagonal support to solve_triangular, use it here
   # instead of explicit masking of l.
diff --git a/jax/ops/scatter.py b/jax/ops/scatter.py
index 380d7dd0c..4e85fad61 100644
--- a/jax/ops/scatter.py
+++ b/jax/ops/scatter.py
@@ -43,7 +43,7 @@ def _scatter_update(x, idx, y, scatter_op):
   y = np.asarray(y)
   x_shape = np.shape(x)
   y_shape = np.shape(y)
-  y = lax.convert_element_type(y, lax._dtype(x))
+  y = lax.convert_element_type(y, lax.dtype(x))
 
   if not isinstance(idx, tuple):
     idx = (idx,)
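`_scatter_update` now casts the update operand with the public `lax.dtype` before binding the scatter primitive. For context, a hedged usage sketch of the functional update helpers built on `_scatter_update` (assuming `index_update` and `index_add` are the wrappers exported from `jax.ops` at this revision):

```python
import jax.numpy as np
from jax.ops import index_add, index_update

x = np.zeros((3, 3))
y = index_update(x, 1, 7.)  # functional analogue of x[1] = 7.
z = index_add(x, 1, 7.)     # functional analogue of x[1] += 7.
# x itself is unchanged; y and z are fresh arrays.
```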
diff --git a/jax/random.py b/jax/random.py
index 0adb6b04f..80a169e3a 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -29,7 +29,6 @@ from functools import partial
 
 import numpy as onp
 
 from . import lax
-from . import lax_control_flow
 from . import numpy as np
 from . import tree_util
 from .api import jit, vmap
@@ -77,7 +76,7 @@ def _make_rotate_left(dtype):
   nbits = onp.array(onp.iinfo(dtype).bits, dtype)
 
   def _rotate_left(x, d):
-    if lax._dtype(d) != lax._dtype(x):
+    if lax.dtype(d) != lax.dtype(x):
       d = lax.convert_element_type(d, x.dtype)
     return (x << d) | lax.shift_right_logical(x, nbits - d)
   return _rotate_left
@@ -104,11 +103,11 @@ def threefry_2x32(keypair, count):
   """
   # Based on ThreeFry2x32 by phawkins@ in //.../xla/client/lib/prng.cc
   key1, key2 = keypair
-  if not lax._dtype(key1) == lax._dtype(key2) == lax._dtype(count) == onp.uint32:
+  if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
     msg = "threefry_2x32 requires uint32 arguments, got {}"
-    raise TypeError(msg.format([lax._dtype(x) for x in [key1, key2, count]]))
+    raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
 
-  rotate_left = _make_rotate_left(lax._dtype(count))
+  rotate_left = _make_rotate_left(lax.dtype(count))
 
   def apply_round(v, rot):
     v = v[:]
@@ -255,7 +254,7 @@ def _uniform(key, shape, dtype, minval, maxval):
   # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
   # equivalent float representations, which might not be true on all platforms.
   float_bits = lax.bitwise_or(
-      lax.shift_right_logical(bits, onp.array(nbits - nmant, lax._dtype(bits))),
+      lax.shift_right_logical(bits, onp.array(nbits - nmant, lax.dtype(bits))),
       onp.array(1., dtype).view(onp.uint32 if nbits == 32 else onp.uint64))
   floats = lax.bitcast_convert_type(float_bits, dtype) - onp.array(1., dtype)
   return lax.max(
@@ -399,7 +398,7 @@ def bernoulli(key, mean=onp.float32(0.5), shape=()):
 @partial(jit, static_argnums=(2,))
 def _bernoulli(key, mean, shape):
   shape = shape or onp.shape(mean)
-  if not onp.issubdtype(lax._dtype(mean), onp.float32):
+  if not onp.issubdtype(lax.dtype(mean), onp.float32):
     mean = lax.convert_element_type(mean, onp.float32)
   if onp.shape(mean) != shape:
     mean = np.broadcast_to(mean, shape)
@@ -457,7 +456,7 @@ def _gamma_one(key, alpha):
   one_over_two = _constant_like(alpha, 0.5)
   one_over_three = _constant_like(alpha, 1. / 3.)
   squeeze_const = _constant_like(alpha, 0.0331)
-  dtype = lax._dtype(alpha)
+  dtype = lax.dtype(alpha)
 
   key, subkey = split(key)
   # for alpha < 1, we boost alpha to alpha + 1 and get a sample according to
@@ -472,7 +471,7 @@ def _gamma_one(key, alpha):
 
   def _cond_fn(kXVU):
     _, X, V, U = kXVU
-    # TODO: use lax_control_flow.cond when its batching rule is supported
+    # TODO: use lax.cond when its batching rule is supported
     # The reason is to avoid evaluating second condition which involves log+log
     # if the first condition is satisfied
     cond = lax.bitwise_and(lax.ge(U, lax.sub(one, lax.mul(squeeze_const, lax.mul(X, X)))),
@@ -492,7 +491,7 @@ def _gamma_one(key, alpha):
     return key, X, V, U
 
   # initial state is chosen such that _cond_fn will return True
-  _, _, V, _ = lax_control_flow.while_loop(
+  _, _, V, _ = lax.while_loop(
       _cond_fn, _body_fn, (key, zero, _constant_like(alpha, -1), zero))
   z = lax.mul(lax.mul(d, V), boost)
   return lax.select(lax.eq(z, zero), onp.finfo(z.dtype).tiny, z)
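The `jax/random.py` edits are mechanical (`lax._dtype` to `lax.dtype`, `lax_control_flow.while_loop` to `lax.while_loop`), but they touch the whole sampling pipeline: threefry key hashing, the bit-level uniform construction, and the `while_loop`-driven gamma rejection sampler. A short usage sketch of that pipeline; the shapes and the Bernoulli mean are illustrative only:

```python
from jax import random

key = random.PRNGKey(0)
key, subkey = random.split(key)         # threefry_2x32 under the hood
u = random.uniform(subkey, shape=(3,))  # bit-twiddled floats in [0, 1)
b = random.bernoulli(key, 0.3, shape=(3,))  # jit-compiled via _bernoulli
```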
diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py
index 240984a6d..8514c3df7 100644
--- a/jax/scipy/linalg.py
+++ b/jax/scipy/linalg.py
@@ -129,7 +129,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
   del overwrite_a, check_finite
   a = np_linalg._promote_arg_dtypes(np.asarray(a))
   lu, pivots = lax_linalg.lu(a)
-  dtype = lax._dtype(a)
+  dtype = lax.dtype(a)
   m, n = np.shape(a)
   permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
   p = np.real(np.array(permutation == np.arange(m)[:, None], dtype=dtype))
diff --git a/jax/scipy/special.py b/jax/scipy/special.py
index ed92c4050..ab6985be3 100644
--- a/jax/scipy/special.py
+++ b/jax/scipy/special.py
@@ -165,7 +165,7 @@ def ndtr(x):
     TypeError: if `x` is not floating-type.
   """
   x = np.asarray(x)
-  dtype = lax._dtype(x)
+  dtype = lax.dtype(x)
   if dtype not in (np.float32, np.float64):
     raise TypeError(
         "x.dtype={} is not supported, see docstring for supported types."
@@ -175,7 +175,7 @@ def ndtr(x):
 
 def _ndtr(x):
   """Implements ndtr core logic."""
-  dtype = lax._dtype(x).type
+  dtype = lax.dtype(x).type
   half_sqrt_2 = dtype(0.5) * onp.sqrt(2., dtype=dtype)
   w = x * half_sqrt_2
   z = lax.abs(w)
@@ -206,7 +206,7 @@ def ndtri(p):
     TypeError: if `p` is not floating-type.
   """
   x = np.asarray(p)
-  dtype = lax._dtype(p)
+  dtype = lax.dtype(p)
   if dtype not in (np.float32, np.float64):
     raise TypeError(
         "x.dtype={} is not supported, see docstring for supported types."
@@ -271,7 +271,7 @@ def _ndtri(p):
                          2.89247864745380683936E-6,
                          6.79019408009981274425E-9]))
 
-  dtype = lax._dtype(p).type
+  dtype = lax.dtype(p).type
   shape = np.shape(p)
 
   def _create_polynomial(var, coeffs):
@@ -391,7 +391,7 @@ def log_ndtr(x, series_order=3):
     raise ValueError("series_order must be <= 30.")
 
   x = np.asarray(x)
-  dtype = lax._dtype(x)
+  dtype = lax.dtype(x)
 
   if dtype == np.float64:
     lower_segment = _LOGNDTR_FLOAT64_LOWER
@@ -427,7 +427,7 @@ def log_ndtr(x, series_order=3):
 
 def _log_ndtr_lower(x, series_order):
   """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
-  dtype = lax._dtype(x).type
+  dtype = lax.dtype(x).type
   x_2 = lax.square(x)
   # Log of the term multiplying (1 + sum)
   log_scale = -dtype(0.5) * x_2 - lax.log(-x) - dtype(0.5 * onp.log(2. * onp.pi))
@@ -436,7 +436,7 @@ def _log_ndtr_lower(x, series_order):
 
 def _log_ndtr_asymptotic_series(x, series_order):
   """Calculates the asymptotic series used in log_ndtr."""
-  dtype = lax._dtype(x).type
+  dtype = lax.dtype(x).type
   if series_order <= 0:
     return onp.array(1, dtype)
   x_2 = lax.square(x)
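All of the `jax/scipy/special.py` changes are again renames; `lax.dtype(x).type` is used to build dtype-matched scalar constants inside the polynomial evaluations. A quick sanity sketch (the values are standard normal CDF identities, not new behavior):

```python
import jax.numpy as np
from jax.scipy.special import ndtr, ndtri

print(ndtr(np.float32(0.)))    # --> 0.5, the standard normal CDF at zero
print(ndtri(np.float32(0.5)))  # --> 0.0, its inverse
```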
diff --git a/notebooks/Common_Gotchas_in_JAX.ipynb b/notebooks/Common_Gotchas_in_JAX.ipynb
index 19ba66dcb..82d5cbeea 100644
--- a/notebooks/Common_Gotchas_in_JAX.ipynb
+++ b/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -233,7 +233,7 @@
 "\n",
 "__NB__: _Fancy Indexing_ is __not__ yet supported, but will likely be added to JAX soon.\n",
 "\n",
-"️⚠️ inside `jit`'d code and `lax_control_flow.while_loop` or `lax_control_flow.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation."
+"️⚠️ inside `jit`'d code and `lax.while_loop` or `lax.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation."
 ]
 },
 {
@@ -1166,10 +1166,10 @@
 "## Structured control flow primitives\n",
 "\n",
 "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. then you can use these 4 structured control flow primitives:\n",
-" - `lax_control_flow.cond` _will be differentiable soon_\n",
-" - `lax_control_flow.while_loop` __non-differentiable__*\n",
-" - `lax_control_flow.fori_loop` __non-differentiable__*\n",
-" - `lax_control_flow.scan` _will be differentiable soon_\n",
+" - `lax.cond` _will be differentiable soon_\n",
+" - `lax.while_loop` __non-differentiable__*\n",
+" - `lax.fori_loop` __non-differentiable__*\n",
+" - `lax.scan` _will be differentiable soon_\n",
 "\n",
 "*_these can in principle be made to be __forward__-differentiable, but this isn't on the current roadmap._"
 ]
@@ -1208,9 +1208,9 @@
 "from jax import lax\n",
 "\n",
 "operand = np.array([0.])\n",
-"lax_control_flow.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
+"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
 "# --> array([1.], dtype=float32)\n",
-"lax_control_flow.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
+"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
 "# --> array([-1.], dtype=float32)"
 ],
 "execution_count": 0,
@@ -1263,7 +1263,7 @@
 "init_val = 0\n",
 "cond_fun = lambda x: x<10\n",
 "body_fun = lambda x: x+1\n",
-"lax_control_flow.while_loop(cond_fun, body_fun, init_val)\n",
+"lax.while_loop(cond_fun, body_fun, init_val)\n",
 "# --> array(10, dtype=int32)"
 ],
 "execution_count": 0,
@@ -1316,7 +1316,7 @@
 "start = 0\n",
 "stop = 10\n",
 "body_fun = lambda i,x: x+i\n",
-"lax_control_flow.fori_loop(start, stop, body_fun, init_val)\n",
+"lax.fori_loop(start, stop, body_fun, init_val)\n",
 "# --> array(45, dtype=int32)"
 ],
 "execution_count": 0,
@@ -1354,10 +1354,10 @@
 "\textrm{if} & ❌ & ✔ \\\n",
 "\textrm{for} & ✔* & ✔\\\n",
 "\textrm{while} & ✔* & ✔\\\n",
-"\textrm{lax_control_flow.cond} & ✔ & \textrm{soon!}\\\n",
-"\textrm{lax_control_flow.while_loop} & ✔ & ❌\\\n",
-"\textrm{lax_control_flow.fori_loop} & ✔ & ❌\\\n",
-"\textrm{lax_control_flow.scan} & \textrm{soon!} & \textrm{soon!}\\\n",
+"\textrm{lax.cond} & ✔ & \textrm{soon!}\\\n",
+"\textrm{lax.while_loop} & ✔ & ❌\\\n",
+"\textrm{lax.fori_loop} & ✔ & ❌\\\n",
+"\textrm{lax.scan} & \textrm{soon!} & \textrm{soon!}\\\n",
 "\hline\n",
 "\end{array}\n",
 "$$\n",
diff --git a/tests/batching_test.py b/tests/batching_test.py
index 80d51ba64..b6eab209d 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -24,7 +24,6 @@ import jax.numpy as np
 from jax import test_util as jtu
 from jax.abstract_arrays import ShapedArray
 from jax import lax
-from jax import lax_control_flow
 from jax import lax_linalg
 from jax import random
 from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr, jacfwd, jacrev, hessian
@@ -875,7 +874,7 @@ class BatchingTest(jtu.JaxTestCase):
 
   def testWhileLoop(self):
     def fun(x):
-      return lax_control_flow.while_loop(lambda x: x < 3, lambda x: x + 2, x)
+      return lax.while_loop(lambda x: x < 3, lambda x: x + 2, x)
 
     ans = vmap(fun)(onp.array([0, 1, 2, 3]))
     expected = onp.array([4, 3, 4, 3])
@@ -888,7 +887,7 @@ class BatchingTest(jtu.JaxTestCase):
 
   def testWhileLoopCondConstsBatched(self):
     def fun(x, y):
-      return lax_control_flow.while_loop(lambda x: x < y, lambda x: x + 2, x)
+      return lax.while_loop(lambda x: x < y, lambda x: x + 2, x)
 
     ans = vmap(fun, in_axes=(None, 0))(0, onp.array([2, 3]))
     expected = onp.array([2, 4])
@@ -896,7 +895,7 @@ class BatchingTest(jtu.JaxTestCase):
 
   def testWhileLoopBodyConstsBatched(self):
     def fun(x, y):
-      return lax_control_flow.while_loop(lambda x: x < 3, lambda x: x + y, x)
+      return lax.while_loop(lambda x: x < 3, lambda x: x + y, x)
 
     ans = vmap(fun, in_axes=(None, 0))(0, onp.array([2, 3]))
     expected = onp.array([4, 3])
@@ -913,7 +912,7 @@ class BatchingTest(jtu.JaxTestCase):
       return x, y
 
     def fun(x, y):
-      return lax_control_flow.while_loop(cond_fun, body_fun, (x, y))
+      return lax.while_loop(cond_fun, body_fun, (x, y))
 
     ans = vmap(fun)(onp.array([0, 0]), onp.array([1, 2]))
     expected = (onp.array([4, 3]), onp.array([1, 2]))
@@ -927,7 +926,7 @@ class BatchingTest(jtu.JaxTestCase):
       return x, y
 
    def fun(x):
-      return lax_control_flow.fori_loop(0, 10, body_fun, (x, 0))
+      return lax.fori_loop(0, 10, body_fun, (x, 0))
 
     ans = vmap(fun)(onp.array([0, 1]))
     expected = (onp.array([10, 11]), onp.array([20, 20]))
@@ -941,7 +940,7 @@ class BatchingTest(jtu.JaxTestCase):
       key, _ = random.split(key)
       return u, key
 
-    u, _ = lax_control_flow.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
+    u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
     return u
 
   print(vmap(f)(random.split(random.PRNGKey(0), 2)))  # no crash
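The batching tests above rely on `lax.while_loop` having a batching rule, so `vmap` can push a batch axis through an entire loop. A standalone version of `testWhileLoop`, assuming this patch so the loop is imported from `jax.lax`:

```python
import numpy as onp
from jax import lax, vmap

fun = lambda x: lax.while_loop(lambda x: x < 3, lambda x: x + 2, x)
print(vmap(fun)(onp.array([0, 1, 2, 3])))
# --> [4 3 4 3]: each lane iterates until its own carry reaches 3
```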
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index 54c585484..b6785a66b 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -25,7 +25,6 @@ import numpy.random as npr
 
 from jax import api
 from jax import lax
-from jax import lax_control_flow
 from jax import test_util as jtu
 
 
@@ -43,7 +42,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       return (lax.add(pos, 1), lax.add(count, 1))
 
     def loop(init):
-      result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
+      result = lax.while_loop(loop_cond, loop_body, (init, 0))
      _, count = result
       return count
 
@@ -66,7 +65,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
        return (num, lax.add(i, 1), inner_loop(i, count))
 
      init_val = (num, 0, 0)
-      _, i, count = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
+      _, i, count = lax.while_loop(cond_fun, body_fun, init_val)
      return (i, count)
 
    def inner_loop(i, count):  # pylint: disable=missing-docstring
@@ -79,7 +78,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
        return (i, lax.add(j, 1), lax.add(count, 1))
 
      init_val = (i, 0, count)
-      _, _, count = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
+      _, _, count = lax.while_loop(cond_fun, body_fun, init_val)
      return count
 
    cloop = api.jit(outer_loop)
@@ -103,7 +102,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
      pos, count = state
      return (lax.add(pos, 1), lax.add(count, inc))
 
-    result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
+    result = lax.while_loop(loop_cond, loop_body, (init, 0))
    _, count = result
    return count
 
@@ -135,7 +134,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
      f = lambda pos, inc: (lax.add(pos, 1), lax.add(count, inc))
      return api.jit(f)(pos, inc)
 
-    result = lax_control_flow.while_loop(loop_cond, loop_body, (init, 0))
+    result = lax.while_loop(loop_cond, loop_body, (init, 0))
    _, count = result
    return count
 
@@ -172,7 +171,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
      out = onp.zeros(arr.shape, dtype=arr.dtype)
      init_val = (0, num, arr, out)
-      _, _, _, out = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
+      _, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
      return out
 
    def inner_loop(i, arr, out):  # pylint: disable=missing-docstring
@@ -189,7 +188,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
        return (i, lax.add(j, 1), arr, out)
 
      init_val = (i, 0, arr, out)
-      _, _, _, out = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
+      _, _, _, out = lax.while_loop(cond_fun, body_fun, init_val)
      return out
 
    cloop = api.jit(outer_loop)
@@ -210,7 +209,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
      return (arr, num, lax.add(i, 1), lax.add(total, arr_i))
 
    init_val = (arr, num, 0, 0.)
-    _, _, _, total = lax_control_flow.while_loop(cond_fun, body_fun, init_val)
+    _, _, _, total = lax.while_loop(cond_fun, body_fun, init_val)
    return total
 
    cfun = api.jit(sum_first_n)
@@ -226,7 +225,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
    def count(num):
      def body_fun(i, tot):
        return lax.add(tot, i)
-      return lax_control_flow.fori_loop(0, num, body_fun, 0)
+      return lax.fori_loop(0, num, body_fun, 0)
 
    cfun = api.jit(count)
 
@@ -241,7 +240,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
    def count(num):
      def body_fun(i, tot):
        return lax.add(num, lax.add(tot, i))
-      return lax_control_flow.fori_loop(0, num, body_fun, 0)
+      return lax.fori_loop(0, num, body_fun, 0)
 
    cfun = api.jit(count)
 
@@ -260,7 +259,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
      return (arr, lax.add(total, arr_i))
 
    init_val = (arr, 0.)
-    _, total = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
+    _, total = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun,
                              init_val)
    return total
 
@@ -281,7 +280,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
      return {'arr': arr, 'total': lax.add(total, arr_i)}
 
    init_val = {'arr': arr, 'total': 0.}
-    out_val = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
+    out_val = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
    return out_val['total']
 
    cfun = api.jit(sum_first_n)
@@ -301,7 +300,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
      return (arr, lax.add(total, arr_i), ())
 
    init_val = (arr, 0., ())
-    _, tot, _ = lax_control_flow.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
+    _, tot, _ = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
    return tot
 
    cfun = api.jit(sum_first_n)
@@ -326,7 +325,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
    def false_fun(x):
      y = lax.mul(2, x)
      return y, lax.mul(2, y)
-    return lax_control_flow.cond(lax.lt(x, 3), x, lambda x: (x, x), x, false_fun)
+    return lax.cond(lax.lt(x, 3), x, lambda x: (x, x), x, false_fun)
 
    self.assertEqual(fun(0), cfun(0))
    self.assertEqual(fun(0), (0, 0))
@@ -351,10 +350,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
 
    @api.jit
    def cfun(x):
-      return lax_control_flow.cond(
+      return lax.cond(
          lax.lt(x, 2),
          x, lambda x: lax.mul(2, x),
-          x, lambda x: lax_control_flow.cond(lax.lt(x, 5),
+          x, lambda x: lax.cond(lax.lt(x, 5),
                                             x, lambda x: lax.mul(3, x),
                                             4, lambda y: lax.mul(y, x)))
 
@@ -374,7 +373,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
 
    @api.jit
    def cfun(x):
-      return lax_control_flow.cond(lax.lt(x, 3), x, lambda x: 5, x, lambda x: x)
+      return lax.cond(lax.lt(x, 3), x, lambda x: 5, x, lambda x: x)
 
    self.assertEqual(fun(0), cfun(0))
    self.assertEqual(cfun(0), 5)
@@ -390,7 +389,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
 
    @api.jit
    def cfun(x):
-      return lax_control_flow.cond(lax.lt(x, 3),
+      return lax.cond(lax.lt(x, 3),
                                   x, lambda x: (1, 2., 3.),
                                   x, lambda x: (x, 2., 4.))
 
@@ -401,7 +400,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
 
  def testIssue514(self):
    # just check this doesn't crash
-    lax_control_flow.cond(True,
+    lax.cond(True,
            (0, 0), lambda x: (x[0], 0),
            (1, 1), lambda x: x)
 
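One thing to keep in mind while reading the cond tests: at this revision `lax.cond` still takes five arguments, `cond(pred, true_operand, true_fun, false_operand, false_fun)`, with each branch applied to its own operand. A bare-bones sketch of that convention:

```python
from jax import lax

out = lax.cond(True,
               2., lambda x: x + 1.,  # predicate is true: out == 3.0
               2., lambda x: x - 1.)  # false branch, not taken
print(out)
```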
diff --git a/tests/parallel_test.py b/tests/parallel_test.py
index 1235c1a70..06681150c 100644
--- a/tests/parallel_test.py
+++ b/tests/parallel_test.py
@@ -25,7 +25,6 @@ from absl.testing import parameterized
 
 import jax.numpy as np
 from jax import test_util as jtu
 from jax import lax
-from jax import lax_parallel
 from jax.api import _serial_pmap, _papply, jit, make_jaxpr
 from jax.linear_util import wrap_init
 
@@ -42,40 +41,40 @@ class SerialPmapTest(jtu.JaxTestCase):
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  def testReduceSum(self):
-    f = lambda x: lax_parallel.psum(x, 'i')
+    f = lambda x: lax.psum(x, 'i')
    ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
    expected = 4 * onp.ones(4)
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  def testReduceMax(self):
-    f = lambda x: lax_parallel.pmax(x, 'i')
+    f = lambda x: lax.pmax(x, 'i')
    ans = _serial_pmap(f, axis_name='i')(onp.arange(4))
    expected = 3 * onp.ones(4)
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  def testPsplit(self):
-    f = lambda x: lax_parallel.psplit(x, 'i', 2)
+    f = lambda x: lax.psplit(x, 'i', 2)
    arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
    ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
    expected = arg
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  def testPsplitLike(self):
-    f = lambda x, y: lax_parallel.psplit_like(x, y, 'i')
+    f = lambda x, y: lax.psplit_like(x, y, 'i')
    arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
    ans = _serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
    expected = arg
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  def testLogSoftmax(self):
-    f = lambda x: x - np.log(lax_parallel.psum(np.exp(x), 'i'))
+    f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
    x = onp.log(onp.arange(1., 10., dtype=onp.float32))
    ans = _serial_pmap(f, axis_name='i')(x)
    expected = x - onp.log(onp.sum(onp.exp(x)))
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  def testNested(self):
-    f = lambda x: lax_parallel.psum(lax_parallel.psum(x, 'i'), 'j')
+    f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
    x = onp.ones((2, 2))
    ans1 = _serial_pmap(_serial_pmap(f, 'i'), 'j')(x)
    ans2 = _serial_pmap(_serial_pmap(f, 'j'), 'i')(x)
@@ -103,7 +102,7 @@ class PapplyTest(jtu.JaxTestCase):
 
    jaxpr = make_jaxpr(pfun)(onp.ones(3))
    expected_jaxpr = make_jaxpr(
-        lambda x: lax_parallel.psum(x, axis_name))(onp.zeros((5, 3)))
+        lambda x: lax.psum(x, axis_name))(onp.zeros((5, 3)))
    assert repr(jaxpr) == repr(expected_jaxpr)
 
    arg = onp.arange(15.).reshape((5, 3))
@@ -116,7 +115,7 @@ class PapplyTest(jtu.JaxTestCase):
 
    jaxpr = make_jaxpr(pfun)(onp.ones(3))
    expected_jaxpr = make_jaxpr(
-        lambda x: lax_parallel.pmax(x, axis_name))(onp.zeros((5, 3)))
+        lambda x: lax.pmax(x, axis_name))(onp.zeros((5, 3)))
    assert repr(jaxpr) == repr(expected_jaxpr)
 
    arg = onp.arange(15.).reshape((5, 3))
@@ -135,9 +134,9 @@ class PapplyTest(jtu.JaxTestCase):
 
    def expected_spmd(p, t, f):
      return lax.select(
-          lax_parallel.psplit_like(p, t, axis_name),
+          lax.psplit_like(p, t, axis_name),
          t,
-          lax_parallel.psplit_like(f, t, axis_name))
+          lax.psplit_like(f, t, axis_name))
 
    expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
    assert repr(jaxpr) == repr(expected_jaxpr)
@@ -156,7 +155,7 @@ class PapplyTest(jtu.JaxTestCase):
 
    jaxpr = make_jaxpr(pfun)(onp.zeros(5))
    expected_jaxpr = make_jaxpr(
-        lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(onp.zeros(5))
+        lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
    assert repr(jaxpr) == repr(expected_jaxpr)
 
    ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 01f6da5d4..b3e87a253 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -27,7 +27,6 @@ import jax.numpy as np
 from jax import test_util as jtu
 from jax import lax
 from jax.api import pmap, vmap, jvp, grad, make_jaxpr, linearize, device_put
-from jax.lax_parallel import psum
 from jax.lib import xla_bridge
 from jax.util import prod
 from jax.interpreters import pxla
@@ -54,7 +53,7 @@ class PmapTest(jtu.JaxTestCase):
    return device_mesh_shape
 
  def testBasic(self):
-    f = pmap(lambda x: x - psum(x, 'i'), axis_name='i')
+    f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
 
    shape = (xla_bridge.device_count(), 4)
    x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
@@ -64,7 +63,7 @@ class PmapTest(jtu.JaxTestCase):
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  def testNestedBasic(self):
-    f = lambda x: psum(psum(x, 'i'), 'j')
+    f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
    f = pmap(pmap(f, 'i'), 'j')
 
    def sum_and_broadcast(x, axis):
@@ -145,7 +144,7 @@ class PmapTest(jtu.JaxTestCase):
 
  def testTwoArgsGrad(self):
    def f(x, y):
-      return psum(5. * np.cos(x) * np.sin(y), 'i')
+      return lax.psum(5. * np.cos(x) * np.sin(y), 'i')
    f = pmap(f, 'i')
 
    def g(x, y):
@@ -223,7 +222,7 @@ class PmapTest(jtu.JaxTestCase):
    self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)
 
  def testPsumMultiple(self):
-    f = lambda x: psum(x, ('i', 'j'))
+    f = lambda x: lax.psum(x, ('i', 'j'))
    f = pmap(pmap(f, 'i'), 'j')
 
    def sum_and_broadcast(x, axis):
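Finally, with the collectives folded into `jax.lax`, the canonical `psum` example from `testBasic` reads as below. It is a sketch assuming a runtime whose device count matches the leading axis, exactly as the test sets it up:

```python
import numpy as onp
from jax import lax, pmap
from jax.lib import xla_bridge
from jax.util import prod

f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
print(f(x))  # each shard minus the sum across the mapped axis 'i'
```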