Mirror of https://github.com/ROCm/jax.git, synced 2025-04-25 01:06:05 +00:00

Merge branch 'master' into multibackend

commit 0cc21c8d72

.github/workflows/build_mac_jaxlib.yml (vendored)
@@ -9,11 +9,6 @@ jobs:
   build:

     runs-on: macOS-latest
-    strategy:
-      max-parallel: 4
-      matrix:
-        python-version: [2.7, 3.5, 3.6, 3.7]
-
     steps:
     - uses: actions/checkout@v1
     - uses: actions/setup-python@v1
@@ -28,5 +23,6 @@ jobs:
     - uses: actions/setup-python@v1
       with:
         python-version: 3.7
-    - run: python -m pip install --upgrade pyenv pyenv-virtualenv
+    - run: brew install pyenv
+    - run: brew install pyenv-virtualenv
     - run: build/build_jaxlib_wheels_macos.sh
@@ -23,10 +23,10 @@ http_archive(
 # and update the sha256 with the result.
 http_archive(
     name = "org_tensorflow",
-    sha256 = "461df411fccc278244edc32496e2d846fcb96ab019ea352c51476b6edcbdcc5b",
-    strip_prefix = "tensorflow-9f4fc034f686d9a484f5613a7d840a4bbcfe0e27",
+    sha256 = "cbabcb6616a0429aac544f80fd101076fbe4d99c48f8f728cb337a1da613b1d8",
+    strip_prefix = "tensorflow-006e2933990258fbe3cffe1580ce52894056c999",
     urls = [
-        "https://github.com/tensorflow/tensorflow/archive/9f4fc034f686d9a484f5613a7d840a4bbcfe0e27.tar.gz",
+        "https://github.com/tensorflow/tensorflow/archive/006e2933990258fbe3cffe1580ce52894056c999.tar.gz",
     ],
 )
@@ -18,7 +18,7 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
                   _reduce_and, _reduce_window_sum, _reduce_window_max,
                   _reduce_window_min, _reduce_window_prod, _float, _complex,
                   _input_dtype, _const, _eq_meet, _safe_mul,
-                  _broadcasting_select)
+                  _broadcasting_select, _check_user_dtype_supported)
 from .lax_control_flow import *
 from .lax_fft import *
 from .lax_parallel import *
@@ -4418,3 +4418,15 @@ def subvals(lst, replace):

 def _abstractify(x):
   return raise_to_shaped(core.get_aval(x))
+
+
+def _check_user_dtype_supported(dtype, fun_name=None):
+  if dtype is not None and onp.dtype(dtype) != xla_bridge.canonicalize_dtype(dtype):
+    msg = ("Explicitly requested dtype {} {} is not available, "
+           "and will be truncated to dtype {}. To enable more dtypes, set the "
+           "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
+           "environment variable. "
+           "See https://github.com/google/jax#current-gotchas for more.")
+    fun_name = "requested in {}".format(fun_name) if fun_name else ""
+    truncated_dtype = xla_bridge.canonicalize_dtype(dtype).name
+    warnings.warn(msg.format(dtype, fun_name, truncated_dtype))
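As a rough usage sketch (assuming the default jax_enable_x64=False configuration and the 2019-era "import jax.numpy as np" convention used throughout this diff), the helper turns a silently ignored 64-bit request into a visible warning:

import warnings
import jax.numpy as np

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    x = np.zeros(3, dtype="float64")  # 64-bit dtypes unavailable without x64
print(x.dtype)        # float32: the requested dtype was truncated
print(w[-1].message)  # starts with "Explicitly requested dtype float64 ..."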
@@ -54,7 +54,7 @@ def _initial_style_jaxpr(fun, in_tree, in_avals):
   in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
   fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True)
-  out_avals, _ = unzip2(out_pvals)
+  out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
   const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
   typed_jaxpr = core.TypedJaxpr(pe.closure_convert_jaxpr(jaxpr),
                                 (), const_avals + in_avals, out_avals)
@@ -164,7 +164,7 @@ def while_loop(cond_fun, body_fun, init_val):
     raise TypeError(msg.format(cond_tree))
   if cond_jaxpr.out_avals != [ShapedArray((), onp.bool_)]:
     msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
-    raise TypeError(msg.format(coud_jaxpr.out_avals))
+    raise TypeError(msg.format(cond_jaxpr.out_avals))
   if not treedef_children(in_tree) == [body_tree]:
     msg = "body_fun output pytree structure must match init_val, got {} and {}."
     raise TypeError(msg.format(body_tree, treedef_children(in_tree)[0]))
@@ -302,18 +302,6 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):
       true_nconsts=len(true_consts), false_nconsts=len(false_consts))
   return tree_unflatten(out_tree, out)

-def _cond_impl(pred, *args, **kwargs):
-  true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
-      kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
-  true_nops = len(true_jaxpr.in_avals) - true_nconsts
-  true_consts, true_ops, false_consts, false_ops = split_list(
-      args, [true_nconsts, true_nops, false_nconsts])
-
-  if pred:
-    return core.jaxpr_as_fun(true_jaxpr)(*(true_consts + true_ops))
-  else:
-    return core.jaxpr_as_fun(false_jaxpr)(*(false_consts + false_ops))
-
 def _cond_abstract_eval(*args, **kwargs):
   return kwargs["true_jaxpr"].out_avals
@@ -341,10 +329,47 @@ def _cond_translation_rule(c, axis_env, pred, *args, **kwargs):

   return c.Conditional(pred, true_op, true_c, false_op, false_c)

+def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
+                        false_nconsts):
+  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
+  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
+          else x for x, d in zip(args, dims)]
+  true_nops = len(true_jaxpr.in_avals) - true_nconsts
+  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
+      args, [1, true_nconsts, true_nops, false_nconsts])
+  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
+  orig_bat = [d is not batching.not_mapped for d in dims]
+  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
+      orig_bat, [1, true_nconsts, len(true_ops), false_nconsts])
+
+  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, False)
+  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, False)
+  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]
+
+  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, out_bat)
+  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, out_bat)
+
+  if pred_bat:
+    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
+    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
+    true_out = [batching.broadcast(x, size, 0) if not b else x
+                for x, b in zip(true_out, out_bat)]
+    false_out = [batching.broadcast(x, size, 0) if not b else x
+                 for x, b in zip(false_out, out_bat)]
+    return [lax.select(pred, t, f)
+            for t, f in zip(true_out, false_out)], [0] * len(true_out)
+  else:
+    out_dims = [0 if b else batching.not_mapped for b in out_bat]
+    return cond_p.bind(
+        *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
+        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
+        true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
+
 cond_p = lax.Primitive('cond')
 cond_p.multiple_results = True
-cond_p.def_impl(_cond_impl)
+cond_p.def_impl(partial(xla.apply_primitive, cond_p))
 cond_p.def_abstract_eval(_cond_abstract_eval)
+batching.primitive_batchers[cond_p] = _cond_batching_rule
 xla.initial_style_translations[cond_p] = _cond_translation_rule
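A hedged sketch of the behavior this batching rule enables (the function and values here are illustrative; the real coverage is testCondBatched further down): when only the operands are batched under vmap, cond survives as a conditional, but once the predicate itself is batched, both branches run and the outputs are merged elementwise with lax.select.

import numpy as onp
from jax import vmap, lax

def f(x, y, z):
  return lax.cond(lax.lt(x, 3), y, lambda y: y, z, lambda z: lax.neg(z))

# Unbatched predicate: the jaxpr keeps a cond; no select appears.
print(vmap(f, (None, 0, 0))(onp.array(2), onp.array([1, 2]), onp.array([3, 4])))
# [1 2]

# Batched predicate: both branches are evaluated and combined with select.
print(vmap(f, (0, 0, None))(onp.array([2, 4]), onp.array([1, 2]), onp.array(5)))
# [ 1 -5]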
@@ -27,12 +27,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 from distutils.util import strtobool
 import collections
 import itertools
+import os
 import re
 import string
+import types
 import warnings

 import numpy as onp
 import opt_einsum
@@ -42,12 +44,21 @@ from six.moves import builtins, xrange
 from jax import jit, device_put
 from .. import core
 from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
+from ..config import flags
 from ..interpreters.xla import DeviceArray
 from .. import lax
 from ..util import partial, get_module_functions, unzip2, prod as _prod
 from ..lib import xla_bridge
 from ..lib import xla_client

+FLAGS = flags.FLAGS
+flags.DEFINE_enum(
+    'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'),
+    enum_values=['allow', 'warn', 'raise'],
+    help=
+    'Control NumPy-style automatic rank promotion broadcasting '
+    '("allow", "warn", or "raise").')
+
 if six.PY3:
   def removechars(s, chars):
     return s.translate(str.maketrans(dict.fromkeys(chars)))
@@ -158,9 +169,20 @@ def _promote_shapes(*args):
     return args
   else:
     shapes = [shape(arg) for arg in args]
-    nd = len(lax.broadcast_shapes(*shapes))
-    return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)
-            if shp and len(shp) != nd else arg for arg, shp in zip(args, shapes)]
+    ranks = [len(shp) for shp in shapes]
+    if len(set(ranks)) == 1:
+      return args
+    elif FLAGS.jax_numpy_rank_promotion != "raise":
+      if FLAGS.jax_numpy_rank_promotion == "warn":
+        msg = "following NumPy automatic rank promotion behavior for {}."
+        warnings.warn(msg.format(' '.join(map(str, shapes))))
+      nd = len(lax.broadcast_shapes(*shapes))
+      return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)
+              if shp and len(shp) != nd else arg for arg, shp in zip(args, shapes)]
+    else:
+      msg = ("operands could not be broadcast together with shapes {} "
+             "and with the config option jax_numpy_rank_promotion='raise'.")
+      raise ValueError(msg.format(' '.join(map(str, shapes))))

 def _promote_dtypes(*args):
   """Convenience function to apply Numpy argument dtype promotion."""
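A minimal sketch of the three flag settings (mutating FLAGS directly, as the test added at the bottom of this diff does; in a script the setting could instead come from the JAX_NUMPY_RANK_PROMOTION environment variable):

import jax.numpy as np
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

FLAGS.jax_numpy_rank_promotion = "allow"  # default: silent rank promotion
np.ones(2) + np.ones((1, 2))              # shape (1, 2), no complaint

FLAGS.jax_numpy_rank_promotion = "warn"   # promote, but emit a warning
np.ones(2) + np.ones((1, 2))              # UserWarning mentioning the shapes

FLAGS.jax_numpy_rank_promotion = "raise"  # refuse to promote ranks
# np.ones(2) + np.ones((1, 2))            # would raise ValueError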
@@ -1386,6 +1408,7 @@ def atleast_3d(*arys):
 def array(object, dtype=None, copy=True, order="K", ndmin=0):
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
+  lax._check_user_dtype_supported(dtype, "array")

   if isinstance(object, ndarray):
     if dtype and _dtype(object) != xla_bridge.canonicalize_dtype(dtype):
@@ -1420,38 +1443,49 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):

 @_wraps(onp.asarray)
 def asarray(a, dtype=None, order=None):
+  lax._check_user_dtype_supported(dtype, "asarray")
   return array(a, dtype=dtype, copy=False, order=order)


 @_wraps(onp.zeros_like)
 def zeros_like(x, dtype=None):
+  lax._check_user_dtype_supported(dtype, "zeros_like")
   return lax.full_like(x, 0, dtype)


 @_wraps(onp.ones_like)
 def ones_like(x, dtype=None):
+  lax._check_user_dtype_supported(dtype, "ones_like")
   return lax.full_like(x, 1, dtype)


 @_wraps(onp.full)
 def full(shape, fill_value, dtype=None):
+  lax._check_user_dtype_supported(dtype, "full")
   return lax.full(shape, fill_value, dtype)


 @_wraps(onp.full_like)
 def full_like(a, fill_value, dtype=None):
+  lax._check_user_dtype_supported(dtype, "full_like")
   return lax.full_like(a, fill_value, dtype)


 @_wraps(onp.zeros)
-def zeros(shape, dtype=onp.dtype("float64")):
+def zeros(shape, dtype=None):
+  if isinstance(shape, types.GeneratorType):
+    raise TypeError("expected sequence object with len >= 0 or a single integer")
+  lax._check_user_dtype_supported(dtype, "zeros")
+  dtype = onp.dtype("float64") if dtype is None else dtype
   shape = (shape,) if onp.isscalar(shape) else shape
   return lax.full(shape, 0, dtype)

 @_wraps(onp.ones)
-def ones(shape, dtype=onp.dtype("float64")):
+def ones(shape, dtype=None):
+  if isinstance(shape, types.GeneratorType):
+    raise TypeError("expected sequence object with len >= 0 or a single integer")
+  lax._check_user_dtype_supported(dtype, "ones")
+  dtype = onp.dtype("float64") if dtype is None else dtype
   shape = (shape,) if onp.isscalar(shape) else shape
   return lax.full(shape, 1, dtype)
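Two user-visible consequences of the zeros/ones rewrite above, sketched under the same x64-disabled assumption: the dtype default becomes a None sentinel, so only an explicit 64-bit request triggers the truncation warning, and a generator passed as a shape now fails loudly instead of producing nonsense.

import jax.numpy as np

np.zeros(3)                    # float32, no warning: no dtype was requested
np.zeros(3, dtype="float64")   # warns, result truncated to float32
np.zeros(i for i in range(3))  # TypeError: expected sequence object ...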
@@ -1471,7 +1505,9 @@ empty = zeros


 @_wraps(onp.eye)
-def eye(N, M=None, k=None, dtype=onp.dtype("float64")):
+def eye(N, M=None, k=None, dtype=None):
+  lax._check_user_dtype_supported(dtype, "eye")
+  dtype = onp.dtype("float64") if dtype is None else dtype
   M = N if M is None else M
   if N < 0 or M < 0:
     msg = "negative dimensions are not allowed, got {} and {}"
@@ -1490,11 +1526,13 @@ def eye(N, M=None, k=None, dtype=onp.dtype("float64")):

 @_wraps(onp.identity)
 def identity(n, dtype=None):
+  lax._check_user_dtype_supported(dtype, "identity")
   return eye(n, dtype=dtype)


 @_wraps(onp.arange)
 def arange(start, stop=None, step=None, dtype=None):
+  lax._check_user_dtype_supported(dtype, "arange")
   # If called like np.arange(N), we create a lazy lax._IotaConstant.
   if stop is None and step is None:
     dtype = dtype or _dtype(start)
@@ -1516,6 +1554,7 @@ def _wrap_numpy_nullary_function(f):

 def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
              axis=0):
+  lax._check_user_dtype_supported(dtype, "linspace")
   try:
     out = onp.linspace(start, stop, num, endpoint, retstep, dtype, axis)
     if retstep:
@@ -1624,6 +1663,7 @@ def repeat(a, repeats, axis=None):

 @_wraps(onp.tri)
 def tri(N, M=None, k=0, dtype=None):
+  lax._check_user_dtype_supported(dtype, "tri")
   M = M if M is not None else N
   dtype = dtype or float32
   x = arange(N, dtype=int32)
@@ -1657,6 +1697,7 @@ def triu(m, k=0):
 def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
   if out:
     raise NotImplementedError("The 'out' argument to trace is not supported.")
+  lax._check_user_dtype_supported(dtype, "trace")

   axis1 = _canonicalize_axis(axis1, ndim(a))
   axis2 = _canonicalize_axis(axis2, ndim(a))
@@ -2817,6 +2858,10 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
   return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                   keepdims=keepdims)

+def _astype(arr, dtype):
+  lax._check_user_dtype_supported(dtype, "astype")
+  return lax.convert_element_type(arr, dtype)
+
 ### track unimplemented functions

 def _not_implemented(fun):
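The _astype wrapper exists so that the method form gets the same dtype check as the constructors; the two setattr hunks below rebind ShapedArray.astype and DeviceArray.astype to it. A brief sketch, again assuming x64 is disabled:

import jax.numpy as np

x = np.ones(3)           # float32 under the default configuration
y = x.astype("float64")  # now warns and truncates back to float32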
@@ -2912,7 +2957,7 @@ setattr(ShapedArray, "flatten", core.aval_method(ravel))
 setattr(ShapedArray, "T", core.aval_property(transpose))
 setattr(ShapedArray, "real", core.aval_property(real))
 setattr(ShapedArray, "imag", core.aval_property(imag))
-setattr(ShapedArray, "astype", core.aval_method(lax.convert_element_type))
+setattr(ShapedArray, "astype", core.aval_method(_astype))


 # Forward operators, methods, and properties on DeviceArray to lax_numpy
@@ -2926,7 +2971,7 @@ setattr(DeviceArray, "flatten", ravel)
 setattr(DeviceArray, "T", property(transpose))
 setattr(DeviceArray, "real", property(real))
 setattr(DeviceArray, "imag", property(imag))
-setattr(DeviceArray, "astype", lax.convert_element_type)
+setattr(DeviceArray, "astype", _astype)


 # Extra methods that are handy
@@ -14,9 +14,6 @@

 """Utilities for working with tree-like container data structures.

-The code here is independent of JAX. The only dependence is on jax.util, which
-itself has no JAX-specific code.
-
 This module provides a small set of utility functions for working with tree-like
 data structures, such as nested tuples, lists, and dicts. We call these
 structures pytrees. They are trees in that they are defined recursively (any
@@ -29,6 +26,10 @@ mapped over, rather than treated as leaves) is extensible. There is a single
 module-level registry of types, and class hierarchy is ignored. By registering a
 new pytree node type, that type in effect becomes transparent to the utility
 functions in this file.
+
+The primary purpose of this module is to enable the interoperability between
+user defined data structures and JAX transformations (e.g. `jit`). This is not
+meant to be a general purpose tree-like data structure handling library.
 """

 from __future__ import absolute_import
@@ -159,6 +159,7 @@ class PyTreeDef {
  private:
   enum class Kind {
     kLeaf,        // An opaque leaf node
+    kNone,        // None.
     kTuple,       // A tuple
     kNamedTuple,  // A collections.namedtuple
     kList,        // A list
@@ -247,7 +248,9 @@ void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
   Node node;
   int start_num_nodes = tree->traversal_.size();
   int start_num_leaves = leaves->size();
-  if (PyTuple_CheckExact(handle.ptr())) {
+  if (py::isinstance<py::none>(handle)) {
+    node.kind = Kind::kNone;
+  } else if (PyTuple_CheckExact(handle.ptr())) {
     py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
     node.kind = Kind::kTuple;
     node.arity = tuple.size();
@@ -334,6 +337,7 @@ py::object PyTreeDef::Unflatten(py::iterable leaves) const {
         ++leaf_count;
         break;

+      case Kind::kNone:
       case Kind::kTuple:
       case Kind::kNamedTuple:
       case Kind::kList:
@@ -368,6 +372,9 @@ py::object PyTreeDef::Unflatten(py::iterable leaves) const {
     case Kind::kLeaf:
       throw std::logic_error("MakeNode not implemented for leaves.");

+    case Kind::kNone:
+      return py::none();
+
     case Kind::kTuple:
     case Kind::kNamedTuple: {
       py::tuple tuple(node.arity);
@@ -434,6 +441,9 @@ py::list PyTreeDef::FlattenUpTo(py::handle xs) const {
         --leaf;
         break;

+      case Kind::kNone:
+        break;
+
       case Kind::kTuple: {
         if (!PyTuple_CheckExact(object.ptr())) {
           throw std::invalid_argument(
@@ -570,6 +580,7 @@ py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf,
         break;
       }

+      case Kind::kNone:
       case Kind::kTuple:
       case Kind::kNamedTuple:
       case Kind::kList:
@@ -694,6 +705,9 @@ std::string PyTreeDef::ToString() const {
       case Kind::kLeaf:
         agenda.push_back("*");
         continue;
+      case Kind::kNone:
+        kind = "None";
+        break;
       case Kind::kNamedTuple:
         kind = "namedtuple";
         break;
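Taken together, the PyTreeDef hunks above give None its own node kind instead of treating it as a leaf or an empty tuple. A hedged sketch of the user-visible effect, mirroring the updated tree_util tests at the end of this diff:

from jax import tree_util

leaves, treedef = tree_util.tree_flatten([(1, 2), None, {"a": 3}])
print(leaves)   # [1, 2, 3] -- the None contributes no leaves
print(treedef)  # the treedef records None explicitly
print(tree_util.tree_unflatten(treedef, leaves))  # [(1, 2), None, {'a': 3}]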
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.1.25"
+__version__ = "0.1.26"
setup.py
@@ -28,7 +28,7 @@ setup(
     author_email='jax-dev@google.com',
     packages=find_packages(exclude=["examples"]),
     install_requires=[
-        'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py', 'opt_einsum<3',
+        'numpy>=1.12', 'six', 'protobuf>=3.6.0', 'absl-py', 'opt_einsum',
         'fastcache'
     ],
     url='https://github.com/google/jax',
@@ -19,6 +19,7 @@ from __future__ import print_function
 import collections
 from functools import partial
 import unittest
+import warnings

 from absl.testing import absltest
 import numpy as onp
@@ -41,6 +42,7 @@ from jax import tree_util

 from jax.config import config
 config.parse_flags_with_absl()
+FLAGS = config.FLAGS

 class APITest(jtu.JaxTestCase):
@@ -186,9 +188,15 @@ class APITest(jtu.JaxTestCase):
     def f(x, y):
       return x + y

-    jtu.check_raises(lambda: grad(f)(onp.zeros(3), onp.zeros(4)),
-                     ValueError,
-                     "Incompatible shapes for broadcasting: ((3,), (4,))")
+    jtu.check_raises(
+        lambda: f(np.zeros(3), np.zeros(4)),
+        TypeError,
+        "add got incompatible shapes for broadcasting: (3,), (4,).")
+
+    jtu.check_raises(
+        lambda: grad(f)(onp.zeros(3), onp.zeros(4)),
+        TypeError,
+        "add got incompatible shapes for broadcasting: (3,), (4,).")

   def test_dot_mismatch(self):
     def f(x, y):
@@ -969,6 +977,54 @@ class APITest(jtu.JaxTestCase):
     for x, y in zip(xs, ys):
       self.assertAllClose(x * 2 - 3., y, check_dtypes=True)

+  def test_dtype_warning(self):
+    # cf. issue #1230
+    if FLAGS.jax_enable_x64:
+      return  # test only applies when x64 is disabled
+
+    def check_warning(warn, nowarn):
+      with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+
+        nowarn()  # get rid of extra startup warning
+
+        prev_len = len(w)
+        nowarn()
+        assert len(w) == prev_len
+
+        warn()
+        assert len(w) > 0
+        msg = str(w[-1].message)
+        expected_prefix = "Explicitly requested dtype "
+        self.assertEqual(expected_prefix, msg[:len(expected_prefix)])
+
+        prev_len = len(w)
+        nowarn()
+        assert len(w) == prev_len
+
+    check_warning(lambda: np.array([1, 2, 3], dtype="float64"),
+                  lambda: np.array([1, 2, 3], dtype="float32"))
+    check_warning(lambda: np.ones(3, dtype=onp.float64),
+                  lambda: np.ones(3))
+    check_warning(lambda: np.ones_like(3, dtype=onp.int64),
+                  lambda: np.ones_like(3, dtype=onp.int32))
+    check_warning(lambda: np.zeros(3, dtype="int64"),
+                  lambda: np.zeros(3, dtype="int32"))
+    check_warning(lambda: np.zeros_like(3, dtype="float64"),
+                  lambda: np.zeros_like(3, dtype="float32"))
+    check_warning(lambda: np.full((2, 3), 1, dtype="int64"),
+                  lambda: np.full((2, 3), 1))
+    check_warning(lambda: np.ones(3).astype("float64"),
+                  lambda: np.ones(3).astype("float32"))
+    check_warning(lambda: np.eye(3, dtype=onp.float64),
+                  lambda: np.eye(3))
+    check_warning(lambda: np.arange(3, dtype=onp.float64),
+                  lambda: np.arange(3, dtype=onp.float32))
+    check_warning(lambda: np.linspace(0, 3, dtype=onp.float64),
+                  lambda: np.linspace(0, 3, dtype=onp.float32))
+    check_warning(lambda: np.tri(2, dtype="float64"),
+                  lambda: np.tri(2, dtype="float32"))
+

 if __name__ == '__main__':
   absltest.main()
@@ -475,6 +475,58 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     self.assertEqual(fun(4), cfun(4))
     self.assertEqual(cfun(4), (4, 2., 4.))

+  def testCondBatched(self):
+    def fun(x, y, z):
+      pred = lax.lt(x, 3)
+      true_fun = lambda y: y
+      false_fun = lambda z: lax.neg(z)
+      return lax.cond(pred, y, true_fun, z, false_fun)
+
+    # these cases stay as cond
+    x = onp.array(2)
+    y = onp.array([1, 2])
+    z = onp.array([3, 4])
+    ans = api.vmap(fun, (None, 0, 0))(x, y, z)
+    jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
+    expected = onp.array([1, 2])
+    self.assertAllClose(ans, expected, check_dtypes=False)
+    assert "select" not in str(jaxpr)
+
+    x = onp.array(4)
+    ans = api.vmap(fun, (None, 0, 0))(x, y, z)
+    jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0)))(x, y, z)
+    expected = onp.array([-3, -4])
+    self.assertAllClose(ans, expected, check_dtypes=False)
+    assert "select" not in str(jaxpr)
+
+    fun = api.jit(fun)
+    ans = api.vmap(fun, (None, 0, 0))(x, y, z)
+    expected = onp.array([-3, -4])
+    self.assertAllClose(ans, expected, check_dtypes=False)
+
+    z = onp.array(5)
+    ans = api.vmap(fun, (None, 0, None))(x, y, z)
+    jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None)))(x, y, z)
+    expected = onp.array([-5, -5])
+    self.assertAllClose(ans, expected, check_dtypes=False)
+    assert "select" not in str(jaxpr)
+
+    # these cases become select
+    x = onp.array([2, 4])
+    ans = api.vmap(fun, (0, 0, None))(x, y, z)
+    jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None)))(x, y, z)
+    expected = onp.array([1, -5])
+    self.assertAllClose(ans, expected, check_dtypes=False)
+    assert "select" in str(jaxpr)
+
+    z = onp.array([3, 4])
+    ans = api.vmap(fun)(x, y, z)
+    jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z)
+    expected = onp.array([1, -4])
+    self.assertAllClose(ans, expected, check_dtypes=False)
+    assert "select" in str(jaxpr)
+
   def testIssue514(self):
     # just check this doesn't crash
     lax.cond(True,
@@ -900,6 +952,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       python_should_be_executing = False
      lax.while_loop(cond, body, 0)

+  def testWhileCondConstant(self):
+    out = lax.while_loop(lambda _: False, lambda _: (), ())  # doesn't crash
+    self.assertEqual(out, ())
+

 if __name__ == '__main__':
   absltest.main()
@@ -23,6 +23,7 @@ import itertools
 import operator
 import unittest
 from unittest import SkipTest
+import warnings

 from absl.testing import absltest
 from absl.testing import parameterized
@@ -1829,5 +1830,35 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     lnp_fun = partial(lnp.meshgrid, indexing=indexing, sparse=sparse)
     self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

+  def testDisableNumpyRankPromotionBroadcasting(self):
+    try:
+      prev_flag = FLAGS.jax_numpy_rank_promotion
+      FLAGS.jax_numpy_rank_promotion = "allow"
+      lnp.ones(2) + lnp.ones((1, 2))  # works just fine
+    finally:
+      FLAGS.jax_numpy_rank_promotion = prev_flag
+
+    try:
+      prev_flag = FLAGS.jax_numpy_rank_promotion
+      FLAGS.jax_numpy_rank_promotion = "raise"
+      self.assertRaises(ValueError, lambda: lnp.ones(2) + lnp.ones((1, 2)))
+    finally:
+      FLAGS.jax_numpy_rank_promotion = prev_flag
+
+    try:
+      prev_flag = FLAGS.jax_numpy_rank_promotion
+      FLAGS.jax_numpy_rank_promotion = "warn"
+      with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        lnp.ones(2) + lnp.ones((1, 2))
+        assert len(w) > 0
+        msg = str(w[-1].message)
+        self.assertEqual(
+            msg,
+            "following NumPy automatic rank promotion behavior for (2,) (1, 2).")
+    finally:
+      FLAGS.jax_numpy_rank_promotion = prev_flag
+

 if __name__ == "__main__":
   absltest.main()
@@ -57,10 +57,10 @@ PYTREES = [
     ((),),
     (([()]),),
     ((1, 2),),
-    (((1, "foo"), ["bar", (3, (), 7)]),),
+    (((1, "foo"), ["bar", (3, None, 7)]),),
     ([3],),
-    ([3, ATuple(foo=(3, ATuple(foo=3, bar=())), bar={"baz": 34})],),
-    ([AnObject(3, (), [4, "foo"])],),
+    ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
+    ([AnObject(3, None, [4, "foo"])],),
     ({"a": 1, "b": 2},),
 ]
@@ -112,19 +112,19 @@ class TreeTest(jtu.JaxTestCase):
     self.assertEqual([c0, c1], tree.children())

   def testFlattenUpTo(self):
-    _, tree = tree_util.tree_flatten([(1, 2), (), ATuple(foo=3, bar=7)])
+    _, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
     if not hasattr(tree, "flatten_up_to"):
       self.skipTest("Test requires Jaxlib >= 0.1.23")
     out = tree.flatten_up_to([({
         "foo": 7
-    }, (3, 4)), (), ATuple(foo=(11, 9), bar=())])
-    self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), ()])
+    }, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
+    self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])

   def testTreeMultimap(self):
     x = ((1, 2), (3, 4, 5))
-    y = (([3], ()), ({"foo": "bar"}, 7, [5, 6]))
+    y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
     out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
-    self.assertEqual(out, (((1, [3]), (2, ())),
+    self.assertEqual(out, (((1, [3]), (2, None)),
                            ((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))