Merge branch 'master' into multibackend

This commit is contained in:
Matthew Johnson 2019-08-25 13:30:21 -07:00 committed by GitHub
commit 0cc21c8d72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 287 additions and 51 deletions

View File

@ -9,11 +9,6 @@ jobs:
build:
runs-on: macOS-latest
strategy:
max-parallel: 4
matrix:
python-version: [2.7, 3.5, 3.6, 3.7]
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v1
@ -28,5 +23,6 @@ jobs:
- uses: actions/setup-python@v1
with:
python-version: 3.7
- run: python -m pip install --upgrade pyenv pyenv-virtualenv
- run: brew install pyenv
- run: brew install pyenv-virtualenv
- run: build/build_jaxlib_wheels_macos.sh

View File

@ -23,10 +23,10 @@ http_archive(
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "461df411fccc278244edc32496e2d846fcb96ab019ea352c51476b6edcbdcc5b",
strip_prefix = "tensorflow-9f4fc034f686d9a484f5613a7d840a4bbcfe0e27",
sha256 = "cbabcb6616a0429aac544f80fd101076fbe4d99c48f8f728cb337a1da613b1d8",
strip_prefix = "tensorflow-006e2933990258fbe3cffe1580ce52894056c999",
urls = [
"https://github.com/tensorflow/tensorflow/archive/9f4fc034f686d9a484f5613a7d840a4bbcfe0e27.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/006e2933990258fbe3cffe1580ce52894056c999.tar.gz",
],
)

View File

@ -18,7 +18,7 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
_reduce_and, _reduce_window_sum, _reduce_window_max,
_reduce_window_min, _reduce_window_prod, _float, _complex,
_input_dtype, _const, _eq_meet, _safe_mul,
_broadcasting_select)
_broadcasting_select, _check_user_dtype_supported)
from .lax_control_flow import *
from .lax_fft import *
from .lax_parallel import *

View File

@ -4418,3 +4418,15 @@ def subvals(lst, replace):
def _abstractify(x):
return raise_to_shaped(core.get_aval(x))
def _check_user_dtype_supported(dtype, fun_name=None):
if dtype is not None and onp.dtype(dtype) != xla_bridge.canonicalize_dtype(dtype):
msg = ("Explicitly requested dtype {} {} is not available, "
"and will be truncated to dtype {}. To enable more dtypes, set the "
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
"environment variable. "
"See https://github.com/google/jax#current-gotchas for more.")
fun_name = "requested in {}".format(fun_name) if fun_name else ""
truncated_dtype = xla_bridge.canonicalize_dtype(dtype).name
warnings.warn(msg.format(dtype, fun_name , truncated_dtype))

View File

@ -54,7 +54,7 @@ def _initial_style_jaxpr(fun, in_tree, in_avals):
in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True)
out_avals, _ = unzip2(out_pvals)
out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
typed_jaxpr = core.TypedJaxpr(pe.closure_convert_jaxpr(jaxpr),
(), const_avals + in_avals, out_avals)
@ -164,7 +164,7 @@ def while_loop(cond_fun, body_fun, init_val):
raise TypeError(msg.format(cond_tree))
if cond_jaxpr.out_avals != [ShapedArray((), onp.bool_)]:
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(coud_jaxpr.out_avals))
raise TypeError(msg.format(cond_jaxpr.out_avals))
if not treedef_children(in_tree) == [body_tree]:
msg = "body_fun output pytree structure must match init_val, got {} and {}."
raise TypeError(msg.format(body_tree, treedef_children(in_tree)[0]))
@ -302,18 +302,6 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):
true_nconsts=len(true_consts), false_nconsts=len(false_consts))
return tree_unflatten(out_tree, out)
def _cond_impl(pred, *args, **kwargs):
true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
true_nops = len(true_jaxpr.in_avals) - true_nconsts
true_consts, true_ops, false_consts, false_ops = split_list(
args, [true_nconsts, true_nops, false_nconsts])
if pred:
return core.jaxpr_as_fun(true_jaxpr)(*(true_consts + true_ops))
else:
return core.jaxpr_as_fun(false_jaxpr)(*(false_consts + false_ops))
def _cond_abstract_eval(*args, **kwargs):
return kwargs["true_jaxpr"].out_avals
@ -341,10 +329,47 @@ def _cond_translation_rule(c, axis_env, pred, *args, **kwargs):
return c.Conditional(pred, true_op, true_c, false_op, false_c)
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
false_nconsts):
# TODO: maybe avoid moving arg axes to front if we're promoting to select?
args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(args, dims)]
true_nops = len(true_jaxpr.in_avals) - true_nconsts
(pred,), true_consts, true_ops, false_consts, false_ops = split_list(
args, [1, true_nconsts, true_nops, false_nconsts])
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
orig_bat = [d is not batching.not_mapped for d in dims]
(pred_bat,), t_bat, tconst_bat, f_bat, fconst_bat = split_list(
orig_bat, [1, true_nconsts, len(true_ops), false_nconsts])
_, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, False)
_, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, False)
out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]
true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, out_bat)
false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, out_bat)
if pred_bat:
true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
true_out = [batching.broadcast(x, size, 0) if not b else x
for x, b in zip(true_out, out_bat)]
false_out = [batching.broadcast(x, size, 0) if not b else x
for x, b in zip(false_out, out_bat)]
return [lax.select(pred, t, f)
for t, f in zip(true_out, false_out)], [0] * len(true_out)
else:
out_dims = [0 if b else batching.not_mapped for b in out_bat]
return cond_p.bind(
*itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
cond_p = lax.Primitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(_cond_impl)
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
batching.primitive_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule

View File

@ -27,12 +27,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from distutils.util import strtobool
import collections
import itertools
import os
import re
import string
import warnings
import types
import warnings
import numpy as onp
import opt_einsum
@ -42,12 +44,21 @@ from six.moves import builtins, xrange
from jax import jit, device_put
from .. import core
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
from ..config import flags
from ..interpreters.xla import DeviceArray
from .. import lax
from ..util import partial, get_module_functions, unzip2, prod as _prod
from ..lib import xla_bridge
from ..lib import xla_client
FLAGS = flags.FLAGS
flags.DEFINE_enum(
'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'),
enum_values=['allow', 'warn', 'raise'],
help=
'Control NumPy-style automatic rank promotion broadcasting '
'("allow", "warn", or "raise").')
if six.PY3:
def removechars(s, chars):
return s.translate(str.maketrans(dict.fromkeys(chars)))
@ -158,9 +169,20 @@ def _promote_shapes(*args):
return args
else:
shapes = [shape(arg) for arg in args]
nd = len(lax.broadcast_shapes(*shapes))
return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)
if shp and len(shp) != nd else arg for arg, shp in zip(args, shapes)]
ranks = [len(shp) for shp in shapes]
if len(set(ranks)) == 1:
return args
elif FLAGS.jax_numpy_rank_promotion != "raise":
if FLAGS.jax_numpy_rank_promotion == "warn":
msg = "following NumPy automatic rank promotion behavior for {}."
warnings.warn(msg.format(' '.join(map(str, shapes))))
nd = len(lax.broadcast_shapes(*shapes))
return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)
if shp and len(shp) != nd else arg for arg, shp in zip(args, shapes)]
else:
msg = ("operands could not be broadcast together with shapes {} "
"and with the config option jax_numpy_rank_promotion='raise'.")
raise ValueError(msg.format(' '.join(map(str, shapes))))
def _promote_dtypes(*args):
"""Convenience function to apply Numpy argument dtype promotion."""
@ -1386,6 +1408,7 @@ def atleast_3d(*arys):
def array(object, dtype=None, copy=True, order="K", ndmin=0):
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")
lax._check_user_dtype_supported(dtype, "array")
if isinstance(object, ndarray):
if dtype and _dtype(object) != xla_bridge.canonicalize_dtype(dtype):
@ -1420,38 +1443,49 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
@_wraps(onp.asarray)
def asarray(a, dtype=None, order=None):
lax._check_user_dtype_supported(dtype, "asarray")
return array(a, dtype=dtype, copy=False, order=order)
@_wraps(onp.zeros_like)
def zeros_like(x, dtype=None):
lax._check_user_dtype_supported(dtype, "zeros_like")
return lax.full_like(x, 0, dtype)
@_wraps(onp.ones_like)
def ones_like(x, dtype=None):
lax._check_user_dtype_supported(dtype, "ones_like")
return lax.full_like(x, 1, dtype)
@_wraps(onp.full)
def full(shape, fill_value, dtype=None):
lax._check_user_dtype_supported(dtype, "full")
return lax.full(shape, fill_value, dtype)
@_wraps(onp.full_like)
def full_like(a, fill_value, dtype=None):
lax._check_user_dtype_supported(dtype, "full_like")
return lax.full_like(a, fill_value, dtype)
@_wraps(onp.zeros)
def zeros(shape, dtype=onp.dtype("float64")):
def zeros(shape, dtype=None):
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
lax._check_user_dtype_supported(dtype, "zeros")
dtype = onp.dtype("float64") if dtype is None else dtype
shape = (shape,) if onp.isscalar(shape) else shape
return lax.full(shape, 0, dtype)
@_wraps(onp.ones)
def ones(shape, dtype=onp.dtype("float64")):
def ones(shape, dtype=None):
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
lax._check_user_dtype_supported(dtype, "ones")
dtype = onp.dtype("float64") if dtype is None else dtype
shape = (shape,) if onp.isscalar(shape) else shape
return lax.full(shape, 1, dtype)
@ -1471,7 +1505,9 @@ empty = zeros
@_wraps(onp.eye)
def eye(N, M=None, k=None, dtype=onp.dtype("float64")):
def eye(N, M=None, k=None, dtype=None):
lax._check_user_dtype_supported(dtype, "eye")
dtype = onp.dtype("float64") if dtype is None else dtype
M = N if M is None else M
if N < 0 or M < 0:
msg = "negative dimensions are not allowed, got {} and {}"
@ -1490,11 +1526,13 @@ def eye(N, M=None, k=None, dtype=onp.dtype("float64")):
@_wraps(onp.identity)
def identity(n, dtype=None):
lax._check_user_dtype_supported(dtype, "identity")
return eye(n, dtype=dtype)
@_wraps(onp.arange)
def arange(start, stop=None, step=None, dtype=None):
lax._check_user_dtype_supported(dtype, "arange")
# If called like np.arange(N), we create a lazy lax._IotaConstant.
if stop is None and step is None:
dtype = dtype or _dtype(start)
@ -1516,6 +1554,7 @@ def _wrap_numpy_nullary_function(f):
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
lax._check_user_dtype_supported(dtype, "linspace")
try:
out = onp.linspace(start, stop, num, endpoint, retstep, dtype, axis)
if retstep:
@ -1624,6 +1663,7 @@ def repeat(a, repeats, axis=None):
@_wraps(onp.tri)
def tri(N, M=None, k=0, dtype=None):
lax._check_user_dtype_supported(dtype, "tri")
M = M if M is not None else N
dtype = dtype or float32
x = arange(N, dtype=int32)
@ -1657,6 +1697,7 @@ def triu(m, k=0):
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
if out:
raise NotImplementedError("The 'out' argument to trace is not supported.")
lax._check_user_dtype_supported(dtype, "trace")
axis1 = _canonicalize_axis(axis1, ndim(a))
axis2 = _canonicalize_axis(axis2, ndim(a))
@ -2817,6 +2858,10 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims)
def _astype(arr, dtype):
lax._check_user_dtype_supported(dtype, "astype")
return lax.convert_element_type(arr, dtype)
### track unimplemented functions
def _not_implemented(fun):
@ -2912,7 +2957,7 @@ setattr(ShapedArray, "flatten", core.aval_method(ravel))
setattr(ShapedArray, "T", core.aval_property(transpose))
setattr(ShapedArray, "real", core.aval_property(real))
setattr(ShapedArray, "imag", core.aval_property(imag))
setattr(ShapedArray, "astype", core.aval_method(lax.convert_element_type))
setattr(ShapedArray, "astype", core.aval_method(_astype))
# Forward operators, methods, and properties on DeviceArray to lax_numpy
@ -2926,7 +2971,7 @@ setattr(DeviceArray, "flatten", ravel)
setattr(DeviceArray, "T", property(transpose))
setattr(DeviceArray, "real", property(real))
setattr(DeviceArray, "imag", property(imag))
setattr(DeviceArray, "astype", lax.convert_element_type)
setattr(DeviceArray, "astype", _astype)
# Extra methods that are handy

View File

@ -14,9 +14,6 @@
"""Utilities for working with tree-like container data structures.
The code here is independent of JAX. The only dependence is on jax.util, which
itself has no JAX-specific code.
This module provides a small set of utility functions for working with tree-like
data structures, such as nested tuples, lists, and dicts. We call these
structures pytrees. They are trees in that they are defined recursively (any
@ -29,6 +26,10 @@ mapped over, rather than treated as leaves) is extensible. There is a single
module-level registry of types, and class hierarchy is ignored. By registering a
new pytree node type, that type in effect becomes transparent to the utility
functions in this file.
The primary purpose of this module is to enable the interoperability between
user defined data structures and JAX transformations (e.g. `jit`). This is not
meant to be a general purpose tree-like data structure handling library.
"""
from __future__ import absolute_import

View File

@ -159,6 +159,7 @@ class PyTreeDef {
private:
enum class Kind {
kLeaf, // An opaque leaf node
kNone, // None.
kTuple, // A tuple
kNamedTuple, // A collections.namedtuple
kList, // A list
@ -247,7 +248,9 @@ void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
Node node;
int start_num_nodes = tree->traversal_.size();
int start_num_leaves = leaves->size();
if (PyTuple_CheckExact(handle.ptr())) {
if (py::isinstance<py::none>(handle)) {
node.kind = Kind::kNone;
} else if (PyTuple_CheckExact(handle.ptr())) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.kind = Kind::kTuple;
node.arity = tuple.size();
@ -334,6 +337,7 @@ py::object PyTreeDef::Unflatten(py::iterable leaves) const {
++leaf_count;
break;
case Kind::kNone:
case Kind::kTuple:
case Kind::kNamedTuple:
case Kind::kList:
@ -368,6 +372,9 @@ py::object PyTreeDef::Unflatten(py::iterable leaves) const {
case Kind::kLeaf:
throw std::logic_error("MakeNode not implemented for leaves.");
case Kind::kNone:
return py::none();
case Kind::kTuple:
case Kind::kNamedTuple: {
py::tuple tuple(node.arity);
@ -434,6 +441,9 @@ py::list PyTreeDef::FlattenUpTo(py::handle xs) const {
--leaf;
break;
case Kind::kNone:
break;
case Kind::kTuple: {
if (!PyTuple_CheckExact(object.ptr())) {
throw std::invalid_argument(
@ -570,6 +580,7 @@ py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf,
break;
}
case Kind::kNone:
case Kind::kTuple:
case Kind::kNamedTuple:
case Kind::kList:
@ -694,6 +705,9 @@ std::string PyTreeDef::ToString() const {
case Kind::kLeaf:
agenda.push_back("*");
continue;
case Kind::kNone:
kind = "None";
break;
case Kind::kNamedTuple:
kind = "namedtuple";
break;

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.1.25"
__version__ = "0.1.26"

View File

@ -28,7 +28,7 @@ setup(
author_email='jax-dev@google.com',
packages=find_packages(exclude=["examples"]),
install_requires=[
'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py', 'opt_einsum<3',
'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py', 'opt_einsum',
'fastcache'
],
url='https://github.com/google/jax',

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import collections
from functools import partial
import unittest
import warnings
from absl.testing import absltest
import numpy as onp
@ -41,6 +42,7 @@ from jax import tree_util
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
class APITest(jtu.JaxTestCase):
@ -186,9 +188,15 @@ class APITest(jtu.JaxTestCase):
def f(x, y):
return x + y
jtu.check_raises(lambda: grad(f)(onp.zeros(3), onp.zeros(4)),
ValueError,
"Incompatible shapes for broadcasting: ((3,), (4,))")
jtu.check_raises(
lambda: f(np.zeros(3), np.zeros(4)),
TypeError,
"add got incompatible shapes for broadcasting: (3,), (4,).")
jtu.check_raises(
lambda: grad(f)(onp.zeros(3), onp.zeros(4)),
TypeError,
"add got incompatible shapes for broadcasting: (3,), (4,).")
def test_dot_mismatch(self):
def f(x, y):
@ -969,6 +977,54 @@ class APITest(jtu.JaxTestCase):
for x, y in zip(xs, ys):
self.assertAllClose(x * 2 - 3., y, check_dtypes=True)
def test_dtype_warning(self):
# cf. issue #1230
if FLAGS.jax_enable_x64:
return # test only applies when x64 is disabled
def check_warning(warn, nowarn):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
nowarn() # get rid of extra startup warning
prev_len = len(w)
nowarn()
assert len(w) == prev_len
warn()
assert len(w) > 0
msg = str(w[-1].message)
expected_prefix = "Explicitly requested dtype "
self.assertEqual(expected_prefix, msg[:len(expected_prefix)])
prev_len = len(w)
nowarn()
assert len(w) == prev_len
check_warning(lambda: np.array([1, 2, 3], dtype="float64"),
lambda: np.array([1, 2, 3], dtype="float32"),)
check_warning(lambda: np.ones(3, dtype=onp.float64),
lambda: np.ones(3))
check_warning(lambda: np.ones_like(3, dtype=onp.int64),
lambda: np.ones_like(3, dtype=onp.int32))
check_warning(lambda: np.zeros(3, dtype="int64"),
lambda: np.zeros(3, dtype="int32"))
check_warning(lambda: np.zeros_like(3, dtype="float64"),
lambda: np.zeros_like(3, dtype="float32"))
check_warning(lambda: np.full((2, 3), 1, dtype="int64"),
lambda: np.full((2, 3), 1))
check_warning(lambda: np.ones(3).astype("float64"),
lambda: np.ones(3).astype("float32"))
check_warning(lambda: np.eye(3, dtype=onp.float64),
lambda: np.eye(3))
check_warning(lambda: np.arange(3, dtype=onp.float64),
lambda: np.arange(3, dtype=onp.float32))
check_warning(lambda: np.linspace(0, 3, dtype=onp.float64),
lambda: np.linspace(0, 3, dtype=onp.float32))
check_warning(lambda: np.tri(2, dtype="float64"),
lambda: np.tri(2, dtype="float32"))
if __name__ == '__main__':
absltest.main()

View File

@ -475,6 +475,58 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(fun(4), cfun(4))
self.assertEqual(cfun(4), (4, 2., 4.))
def testCondBatched(self):
def fun(x, y, z):
pred = lax.lt(x, 3)
true_fun = lambda y: y
false_fun = lambda z: lax.neg(z)
return lax.cond(pred, y, true_fun, z, false_fun)
# these cases stay as cond
x = onp.array(2)
y = onp.array([1, 2])
z = onp.array([3, 4])
ans = api.vmap(fun, (None, 0, 0))(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
expected = onp.array([1, 2])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
x = onp.array(4)
ans = api.vmap(fun, (None, 0, 0))(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
expected = onp.array([-3, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
fun = api.jit(fun)
ans = api.vmap(fun, (None, 0, 0))(x, y, z)
expected = onp.array([-3, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
z = onp.array(5)
ans = api.vmap(fun, (None, 0, None))(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None)))(x, y, z)
expected = onp.array([-5, -5])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
# these cases become select
x = onp.array([2, 4])
ans = api.vmap(fun, (0, 0, None))(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None)))(x, y, z)
expected = onp.array([1, -5])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" in str(jaxpr)
z = onp.array([3, 4])
ans = api.vmap(fun)(x, y, z)
jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z)
expected = onp.array([1, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" in str(jaxpr)
def testIssue514(self):
# just check this doesn't crash
lax.cond(True,
@ -900,6 +952,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
python_should_be_executing = False
lax.while_loop(cond, body, 0)
def testWhileCondConstant(self):
out = lax.while_loop(lambda _: False, lambda _: (), ()) # doesn't crash
self.assertEqual(out, ())
if __name__ == '__main__':
absltest.main()

View File

@ -23,6 +23,7 @@ import itertools
import operator
import unittest
from unittest import SkipTest
import warnings
from absl.testing import absltest
from absl.testing import parameterized
@ -1829,5 +1830,35 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
lnp_fun = partial(lnp.meshgrid, indexing=indexing, sparse=sparse)
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def testDisableNumpyRankPromotionBroadcasting(self):
try:
prev_flag = FLAGS.jax_numpy_rank_promotion
FLAGS.jax_numpy_rank_promotion = "allow"
lnp.ones(2) + lnp.ones((1, 2)) # works just fine
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag
try:
prev_flag = FLAGS.jax_numpy_rank_promotion
FLAGS.jax_numpy_rank_promotion = "raise"
self.assertRaises(ValueError, lambda: lnp.ones(2) + lnp.ones((1, 2)))
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag
try:
prev_flag = FLAGS.jax_numpy_rank_promotion
FLAGS.jax_numpy_rank_promotion = "warn"
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
lnp.ones(2) + lnp.ones((1, 2))
assert len(w) > 0
msg = str(w[-1].message)
self.assertEqual(
msg,
"following NumPy automatic rank promotion behavior for (2,) (1, 2).")
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag
if __name__ == "__main__":
absltest.main()

View File

@ -57,10 +57,10 @@ PYTREES = [
((),),
(([()]),),
((1, 2),),
(((1, "foo"), ["bar", (3, (), 7)]),),
(((1, "foo"), ["bar", (3, None, 7)]),),
([3],),
([3, ATuple(foo=(3, ATuple(foo=3, bar=())), bar={"baz": 34})],),
([AnObject(3, (), [4, "foo"])],),
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
([AnObject(3, None, [4, "foo"])],),
({"a": 1, "b": 2},),
]
@ -112,19 +112,19 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual([c0, c1], tree.children())
def testFlattenUpTo(self):
_, tree = tree_util.tree_flatten([(1, 2), (), ATuple(foo=3, bar=7)])
_, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
if not hasattr(tree, "flatten_up_to"):
self.skipTest("Test requires Jaxlib >= 0.1.23")
out = tree.flatten_up_to([({
"foo": 7
}, (3, 4)), (), ATuple(foo=(11, 9), bar=())])
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), ()])
}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])
def testTreeMultimap(self):
x = ((1, 2), (3, 4, 5))
y = (([3], ()), ({"foo": "bar"}, 7, [5, 6]))
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
self.assertEqual(out, (((1, [3]), (2, ())),
self.assertEqual(out, (((1, [3]), (2, None)),
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))