1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

source sync

PiperOrigin-RevId: 222451919
This commit is contained in:
Peter Hawkins 2018-11-21 13:20:44 -08:00 committed by Roy Frostig
parent fe4edf2839
commit e180f08113
24 changed files with 181 additions and 104 deletions

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from api import *
from jax.api import *

@ -15,6 +15,7 @@
from __future__ import absolute_import
import numpy as onp
import six
from . import core
from . import ad_util
@ -52,7 +53,8 @@ class UnshapedArray(core.AbstractValue):
_bool = _nonzero = concretization_function_error(bool)
_float = concretization_function_error(float)
_int = concretization_function_error(int)
_long = concretization_function_error(long)
if six.PY2:
_long = concretization_function_error(long)
_complex = concretization_function_error(complex)
_hex = concretization_function_error(hex)
_oct = concretization_function_error(oct)
@ -94,7 +96,7 @@ class ShapedArray(UnshapedArray):
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:
raise TypeError, other
raise TypeError(other)
def str_short(self):
dtypestr = onp.dtype(self.dtype).name
@ -137,7 +139,7 @@ class ConcreteArray(ShapedArray):
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:
raise TypeError, other
raise TypeError(other)
def str_short(self):
return str(self.val)

@ -17,6 +17,9 @@ from __future__ import absolute_import
from .core import JaxTuple, lattice_join
from .interpreters.partial_eval import Primitive
from .tree_util import register_pytree_node
from .util import safe_map
map = safe_map
jaxval_adders = {}

@ -17,8 +17,9 @@ from __future__ import absolute_import
from .core import pack
from .tree_util import build_tree, process_pytree
from .linear_util import transformation_with_aux
from .util import unzip2, partial
from .util import safe_map, unzip2, partial
map = safe_map
@transformation_with_aux
def flatten_fun(in_trees, *args, **kwargs):

@ -18,6 +18,7 @@ from operator import attrgetter
from contextlib import contextmanager
from collections import namedtuple, Counter, defaultdict
from weakref import ref
import six
import types
from . import linear_util as lu
@ -248,7 +249,6 @@ class Tracer(object):
def __oct__(self): return self.aval._oct(self)
def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
assert skip_checks or name != "aval"
@ -263,7 +263,10 @@ class Tracer(object):
if t is aval_property:
return attr.fget(self)
elif t is aval_method:
return types.MethodType(attr.fun, self, None)
if six.PY3:
return types.MethodType(attr.fun, self)
else:
return types.MethodType(attr.fun, self, None)
else:
return attr
@ -345,7 +348,7 @@ def new_master(trace_type, bottom=False):
t = ref(master)
del master
if t() is not None:
print trace_stack
print(trace_stack)
raise Exception('Leaked trace {}'.format(t()))

@ -182,6 +182,7 @@ class LapaxMatrix(object):
__sub__ = _make_infix_op(lax.sub)
__mul__ = _make_infix_op(lax.batch_matmul)
__div__ = _make_infix_op(lax.div)
__truediv__ = _make_infix_op(lax.div)
T = property(_make_infix_op(_matrix_transpose))

@ -185,4 +185,4 @@ def make_schedule(scalar_or_schedule_fun):
elif np.ndim(scalar_or_schedule_fun) == 0:
return constant(scalar_or_schedule_fun)
else:
raise TypeError, type(scalar_or_schedule_fun)
raise TypeError(type(scalar_or_schedule_fun))

@ -27,6 +27,7 @@ import operator as op
import numpy as onp
import numpy.random as npr
from six.moves import reduce
from jax import lax
from jax import random

@ -20,11 +20,14 @@ from .. import core as core
from ..core import Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, zero, Zero)
from ..util import unzip2, unzip3, safe_zip, partial
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
from ..tree_util import process_pytree, build_tree, register_pytree_node
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init
from six.moves import builtins, reduce
zip = safe_zip
map = safe_map
def jvp(fun):
return jvpfun(jvp_subtrace(fun))
@ -102,13 +105,18 @@ def backward_pass(jaxpr, consts, freevar_vals, cotangent_in):
for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs])
cts_out, ct_free_vars_out = get_primitive_transpose(eqn.primitive)(
eqn.params, subjaxprs, sub_consts, sub_freevar_vals, invals, ct_in)
# TODO(dougalm): support cases != 1
assert(len(eqn.bound_subjaxprs) == 1)
_, _, bound_vars = eqn.bound_subjaxprs[0]
map(write_cotangent, bound_vars, ct_free_vars_out)
else:
cts_out = get_primitive_transpose(eqn.primitive)(ct_in, *invals, **eqn.params)
if cts_out is zero:
cts_out = [zero for _ in eqn.invars]
map(write_cotangent, eqn.invars, cts_out)
# TODO(phawkins,dougalm): eqn.invars and cts_out can have different lengths
for var, ct in builtins.zip(eqn.invars, cts_out):
write_cotangent(var, ct)
cotangents_out = map(read_cotangent, jaxpr.invars)
freevar_cts = map(read_cotangent, jaxpr.freevars)
@ -195,7 +203,7 @@ class JVPTrace(Trace):
elif xt is zero and yt is zero:
return xt, yt
else:
raise TypeError, (xt, yt)
raise TypeError((xt, yt))
def pack(self, tracers):
primals = pack(t.primal for t in tracers)
@ -344,7 +352,10 @@ def transposed_fun(jaxpr, in_tree_def, args):
out_jtuple, tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
yield out_jtuple, tree_def
def call_transpose(primitive, params, (jaxpr,), (consts,), (freevar_vals,), args, ct):
def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
jaxpr, = jaxpr
consts, = consts
freevar_vals, = freevar_vals
assert isinstance(jaxpr, core.Jaxpr)
assert all(a is None for a in args), "TODO(dougalm): handle non-tangent primal args"
(ct, freevar_vals), in_tree_def = tree_to_jaxtuples((ct, freevar_vals))

@ -18,6 +18,8 @@ import itertools as it
import numpy as onp
from six.moves import reduce
from .. import core
from ..core import Trace, Tracer, new_master, pack, AbstractTuple, JaxTuple
from ..abstract_arrays import ShapedArray, make_shaped_array, array_types
@ -83,7 +85,7 @@ class BatchTracer(Tracer):
elif t is int:
batch_dims = [self.batch_dim] * len(self.val)
else:
raise TypeError, t
raise TypeError(t)
return map(partial(BatchTracer, self.trace), self.val, batch_dims)
def full_lower(self):
@ -150,7 +152,7 @@ def raise_to_shaped(aval):
elif isinstance(aval, ShapedArray):
return ShapedArray(aval.shape, aval.dtype)
else:
raise TypeError, type(aval)
raise TypeError(type(aval))
def remove_batch_dim_from_aval(bdim, aval):
t = type(aval)
@ -167,7 +169,7 @@ def remove_batch_dim_from_aval(bdim, aval):
unbatched_shape = tuple(onp.delete(aval.shape, bdim))
return ShapedArray(unbatched_shape, aval.dtype)
else:
raise TypeError, t
raise TypeError(t)
pytype_aval_mappings = {}
@ -259,7 +261,7 @@ def dimsize(dim, x):
elif dim is None:
return set()
else:
raise TypeError, type(dim)
raise TypeError(type(dim))
def moveaxis(sz, dst, src, x):
aval = get_aval(x)
@ -284,7 +286,7 @@ def moveaxis(sz, dst, src, x):
perm.insert(dst, src)
return x.transpose(perm)
else:
raise TypeError, type(aval)
raise TypeError(type(aval))
def broadcast(x, sz):
try:

@ -25,6 +25,8 @@ from ..core import (Trace, Tracer, new_master, Jaxpr, JaxprEqn, get_aval, pack,
AbstractValue, AbstractTuple, unit, unitvar, Primitive,
call_p)
map = safe_map
zip = safe_zip
class JaxprTrace(Trace):
def pure(self, val):
@ -212,7 +214,7 @@ def as_abstract_val(pv):
elif isinstance(pv, JaxprTracerTuple):
return AbstractTuple(map(as_abstract_val, pv))
elif pv is None:
raise TypeError, "{} is not abstract".format(pv)
raise TypeError("{} is not abstract".format(pv))
def partial_val_aval(pv, const):
@ -283,7 +285,7 @@ def eqn_tracer_to_var(var, outvars, eqn):
def tracers_to_jaxpr(in_tracers, out_tracer):
newvar = gensym('')
t_to_var = defaultdict(newvar)
var = lambda t: t_to_var[t]
var = lambda t: t_to_var[id(t)]
sorted_tracers = toposort(out_tracer)
invars = map(var, in_tracers)
eqns = []
@ -308,11 +310,11 @@ def tracers_to_jaxpr(in_tracers, out_tracer):
destructuring_vars[key] = outvars
else:
outvars = destructuring_vars[key]
t_to_var[t] = outvars[i]
t_to_var[id(t)] = outvars[i]
elif recipe is unit:
t_to_var[t] = unitvar
t_to_var[id(t)] = unitvar
else:
raise TypeError, recipe
raise TypeError(recipe)
env_vars, env_vals = unzip2(env.items())
const_vars, const_vals = unzip2(consts.items())
@ -323,7 +325,7 @@ def tracers_to_jaxpr(in_tracers, out_tracer):
def gensym(suffix):
counter = it.count()
return lambda: Var(counter.next(), suffix)
return lambda: Var(next(counter), suffix)
class Var(object):
def __init__(self, count, suffix):
@ -334,7 +336,7 @@ class Var(object):
rem = self.count
s = ''
while True:
rem, i = rem / 26, rem % 26
rem, i = rem // 26, rem % 26
s = chr(97 + i % 26) + s
if not rem:
break

@ -18,13 +18,15 @@ from collections import namedtuple
import itertools as it
import numpy as onp
import operator as op
import six
from six.moves import xrange
from absl import flags
from .. import core
from .. import ad_util
from ..abstract_arrays import ConcreteArray, ShapedArray, make_shaped_array, array_types
from ..core import AbstractTuple, JaxTuple, pack, valid_jaxtype
from ..util import partial, partialmethod, memoize, unzip2, concatenate
from ..util import partial, partialmethod, memoize, unzip2, concatenate, safe_map
from ..linear_util import transformation_with_aux, memoize as linear_memoize
from ..lib import xla_bridge as xb
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
@ -32,6 +34,7 @@ from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_device_values', True, 'Enable device-persistent values.')
map = safe_map
def apply_primitive(prim, *args, **kwargs):
abstract_args = map(abstractify, args)
@ -123,7 +126,8 @@ translations = {}
translations[core.pack_p] = lambda c, *xs: c.Tuple(*xs)
translations[ad_util.add_jaxvals_p] = lambda c, x, y: c.Add(x, y)
translations[core.call_p] = lambda c, (subc, a1), *a2: c.Call(subc, a1 + a2)
translations[core.call_p] = lambda c, subc_a1, *a2: c.Call(subc_a1[0],
subc_a1[1] + a2)
translations[core.identity_p] = lambda c, x: x
@ -242,7 +246,8 @@ class DeviceArray(DeviceValue):
__bool__ = __nonzero__ = partialmethod(forward_to_value, bool)
__float__ = partialmethod(forward_to_value, float)
__int__ = partialmethod(forward_to_value, int)
__long__ = partialmethod(forward_to_value, long)
if six.PY2:
__long__ = partialmethod(forward_to_value, long)
__complex__ = partialmethod(forward_to_value, complex)
__hex__ = partialmethod(forward_to_value, hex)
__oct__ = partialmethod(forward_to_value, oct)
@ -253,6 +258,14 @@ class DeviceArray(DeviceValue):
# clobbered when jax.numpy is imported, but useful in tests
def __eq__(self, other): return self._value == other
def __hash__(self):
# TODO(mattjj): this is not semantically correct because it is possible
# __eq__ is true for values with unequal __hash__ values. However, the
# main use case at the moment is memoization for which false negatives are
# fine.
return id(self)
core.pytype_aval_mappings[DeviceArray] = ConcreteArray
pytype_aval_mappings[DeviceArray] = make_shaped_array
canonicalize_dtype_handlers[DeviceArray] = identity
@ -267,7 +280,7 @@ def xla_shape(x):
if type(x) in (core.AbstractTuple, core.JaxTuple):
return xb.Shape.tuple_shape(tuple(map(xla_shape, x)))
else:
raise TypeError, type(x)
raise TypeError(type(x))
# For callable XLA Computations (as opposed to, e.g., Computations used in the
@ -298,7 +311,7 @@ def tree_flatten(maybe_tree):
elif core.skip_checks or valid_jaxtype(maybe_tree):
return [maybe_tree], leaf
else:
raise TypeError, type(maybe_tree)
raise TypeError(type(maybe_tree))
JTupleTreeDef = namedtuple("JTupleTreeDef", ["child_specs"])
@ -313,7 +326,7 @@ def build_tree(xs, tree_spec):
elif type(tree_spec) is JTupleTreeDef:
return pack(map(partial(build_tree, xs), tree_spec.child_specs))
else:
raise TypeError, type(tree_spec)
raise TypeError(type(tree_spec))
def device_put(x):
@ -364,4 +377,5 @@ xla_call = partial(core.call_bind, xla_call_p)
xla_call_p.def_custom_bind(xla_call)
xla_call_p.def_impl(xla_call_impl)
translations[xla_call_p] = lambda c, (subc, a1), *a2: c.Call(subc, a1 + a2)
translations[xla_call_p] = lambda c, subc_a1, *a2: c.Call(subc_a1[0],
subc_a1[1] + a2)

@ -15,12 +15,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import __builtin__
import collections
from .util import partial
import itertools
import operator
import six
from six.moves import builtins, xrange
import string
import numpy as onp
@ -40,9 +41,14 @@ from .util import curry, safe_zip, unzip2
from .tree_util import build_tree
from .lib import xla_bridge
_max = __builtin__.max
_min = __builtin__.max
_max = builtins.max
_min = builtins.max
if six.PY3:
def maketrans(s1, s2):
return s1.maketrans(s1, s2)
else:
maketrans = string.maketrans
### traceables
@ -383,7 +389,7 @@ def _while_loop(cond_fun, body_fun, init_val):
params = OpaqueParam((abs_out, cond_jaxpr, cond_consts, body_jaxpr, body_consts))
out_flat = while_p.bind(init_val_flat, opaque_params=params)
if out_tree() != in_tree:
raise TypeError, "body_fun input and output must have identical structure"
raise TypeError("body_fun input and output must have identical structure")
return build_tree(out_tree(), out_flat)
class OpaqueParam(object):
@ -500,8 +506,9 @@ def fori_loop(lower, upper, body_fun, init_val):
# state: (upper limit, index, loop value)
# The `lt` and `add` functions are added to the namespace programmatically.
_, _, result = _while_loop(
lambda (upper, i, _): lt(i, upper),
lambda (upper, i, x): (upper, add(i, 1), body_fun(i, x)),
lambda upper_i_x: lt(upper_i_x[1], upper_i_x[0]),
lambda upper_i_x: (upper_i_x[0], add(upper_i_x[1], 1),
body_fun(upper_i_x[1], upper_i_x[2])),
(upper, lower, init_val))
return result
@ -519,7 +526,7 @@ def foreach_loop(sequence, body_fun, init_val):
"""
_, result = fori_loop(
0, len(sequence),
lambda i, (seq, val): body_fun(seq[i], val),
lambda i, seq_val: body_fun(seq_val[0][i], seq_val[1]),
(sequence, init_val))
return result
@ -708,7 +715,7 @@ standard_binop = partial(binop, _input_dtype)
def _brcast(x, *others):
# used in jvprules to make binop broadcasting explicit for transposability.
# requires shape info during jvp tracing, which isn't strictly necessary.
shapes = filter(None, map(onp.shape, (x,) + others))
shapes = list(filter(None, map(onp.shape, (x,) + others)))
shape = tuple(shapes and onp.max(shapes, axis=0))
if onp.shape(x) != shape:
return _brcast_to(x, shape)
@ -1017,8 +1024,8 @@ def conv_general_dilated_transpose_rhs(
lhs_sdims, rhs_sdims, out_sdims = _get_sdims(dimension_numbers)
lhs_spec, rhs_spec, out_spec = dimension_numbers
trans_dimension_numbers = (_charswap("C", "N", lhs_spec),
out_spec.translate(string.maketrans("NC", "IO")),
rhs_spec.translate(string.maketrans("IO", "NC")))
out_spec.translate(maketrans("NC", "IO")),
rhs_spec.translate(maketrans("IO", "NC")))
padding = _conv_general_vjp_rhs_padding(
onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims),
@ -1543,9 +1550,8 @@ def slice_shape_rule(operand, start_indices, limit_indices, strides,
msg = "slice strides must be positive, got {}"
raise TypeError(msg.format(strides))
result_shape = onp.divide(onp.add(onp.subtract(limit_indices, start_indices),
strides) - 1,
strides)
result_shape = onp.floor_divide(
onp.add(onp.subtract(limit_indices, start_indices), strides) - 1, strides)
return tuple(result_shape)
def slice_translation_rule(c, operand, start_indices, limit_indices, strides,
@ -1911,7 +1917,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
padding):
pads = padtype_to_pads(operand_shape, window_dimensions, window_strides, padding)
operand_padded = onp.add(operand_shape, onp.add(*zip(*pads)))
t = onp.divide(onp.subtract(operand_padded, window_dimensions), window_strides) + 1
t = onp.floor_divide(
onp.subtract(operand_padded, window_dimensions), window_strides) + 1
return tuple(t)
@ -2127,7 +2134,7 @@ def _check_same_dtypes(name, ignore_fp_precision, *dtypes):
"""Check that dtypes agree, possibly ignoring float precision."""
# the `ignore_fp_precision` flag exists because the XLA shape inference logic
# allows mixed floating point precision, but the HLO verifier often rejects it
dtypes = map(onp.dtype, dtypes) # canonicalize
dtypes = list(map(onp.dtype, dtypes)) # canonicalize
if ignore_fp_precision:
dtypes = [
onp.floating if onp.issubdtype(dtype, onp.floating)
@ -2172,7 +2179,8 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))
lhs_padded = onp.add(lhs_shape[2:], onp.add(*zip(*pads)))
out_space = onp.divide(onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
out_space = onp.floor_divide(
onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
out_space = onp.maximum(0, out_space)
out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
return tuple(out_shape)
@ -2268,7 +2276,7 @@ def remaining(original, *removed_lists):
def _charswap(a, b, s):
return s.translate(string.maketrans(a+b, b+a))
return s.translate(maketrans(a + b, b + a))
def _get_sdims(dimension_numbers):
@ -2339,13 +2347,13 @@ def _eq_meet(a, b):
def maybe_tracer_tuple_to_abstract_tuple(tup):
if isinstance(tup, pe.JaxprTracerTuple):
return core.AbstractTuple(map(maybe_tracer_tuple_to_abstract_tuple, tup))
return core.AbstractTuple(list(map(maybe_tracer_tuple_to_abstract_tuple, tup)))
elif isinstance(tup, core.AbstractValue):
return tup
elif tup is None:
return core.AbstractTuple(()) # TODO(dougalm): check this
else:
raise TypeError, tup
raise TypeError(tup)
def subvals(lst, replace):

@ -15,7 +15,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import __builtin__
import collections
import itertools
@ -24,10 +23,12 @@ import numpy as onp
import opt_einsum
import scipy.special
_slice = __builtin__.slice
_max = __builtin__.max
_min = __builtin__.min
_map = __builtin__.map
from six.moves import builtins
_slice = builtins.slice
_max = builtins.max
_min = builtins.min
_map = builtins.map
neg = onp.negative
sign = onp.sign
@ -87,13 +88,13 @@ sub = onp.subtract
mul = onp.multiply
def div(lhs, rhs):
quotient = onp.divide(lhs, rhs)
if onp.issubdtype(onp.result_type(lhs), onp.integer):
quotient = onp.floor_divide(lhs, rhs)
select = onp.logical_and(onp.sign(lhs) != onp.sign(rhs),
onp.remainder(lhs, rhs) != 0)
return onp.where(select, quotient + 1, quotient)
else:
return quotient
return onp.divide(lhs, rhs)
def rem(lhs, rhs):
return onp.sign(lhs) * onp.remainder(onp.abs(lhs), onp.abs(rhs))
@ -150,14 +151,14 @@ dot = onp.dot
def dot_general(lhs, rhs, dimension_numbers):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
new_id = itertools.count().next
lhs_axis_ids = [new_id() for _ in lhs.shape]
rhs_axis_ids = [new_id() for _ in rhs.shape]
new_id = itertools.count()
lhs_axis_ids = [next(new_id) for _ in lhs.shape]
rhs_axis_ids = [next(new_id) for _ in rhs.shape]
lhs_out_axis_ids = lhs_axis_ids[:]
rhs_out_axis_ids = rhs_axis_ids[:]
for lhs_axis, rhs_axis in zip(lhs_contracting, rhs_contracting):
shared_id = new_id()
shared_id = next(new_id)
lhs_axis_ids[lhs_axis] = shared_id
rhs_axis_ids[rhs_axis] = shared_id
lhs_out_axis_ids[lhs_axis] = None
@ -165,7 +166,7 @@ def dot_general(lhs, rhs, dimension_numbers):
batch_ids = []
for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
shared_id = new_id()
shared_id = next(new_id)
lhs_axis_ids[lhs_axis] = shared_id
rhs_axis_ids[rhs_axis] = shared_id
lhs_out_axis_ids[lhs_axis] = None
@ -215,7 +216,7 @@ select = onp.where
def slice(operand, start_indices, limit_indices, strides=None): # pylint: disable=redefined-builtin
if strides is None:
strides = onp.ones(len(start_indices)).astype(int)
slices = _map(_slice, start_indices, limit_indices, strides)
slices = tuple(_map(_slice, start_indices, limit_indices, strides))
return operand[slices]
def dynamic_slice(operand, start_indices, slice_sizes):
@ -227,7 +228,7 @@ def dynamic_slice(operand, start_indices, slice_sizes):
return out
def dynamic_update_slice(operand, update, start_indices):
slices = _map(_slice, start_indices, onp.add(start_indices, update.shape))
slices = tuple(_map(_slice, start_indices, onp.add(start_indices, update.shape)))
updated_operand = onp.copy(operand)
updated_operand[slices] = update
return updated_operand
@ -295,8 +296,8 @@ def _conv_view(lhs, rhs_shape, window_strides, pads, pad_value):
out_strides = onp.multiply(window_strides, lhs.strides[2:])
view_strides = lhs.strides[:1] + tuple(out_strides) + lhs.strides[1:]
out_shape = onp.divide(onp.subtract(in_shape, filter_shape),
window_strides) + 1
out_shape = onp.floor_divide(
onp.subtract(in_shape, filter_shape), window_strides) + 1
view_shape = lhs.shape[:1] + tuple(out_shape) + rhs_shape[1:]
view = onp.lib.stride_tricks.as_strided(lhs, view_shape, view_strides)

@ -112,10 +112,10 @@ def get_xla_client():
xla_client.initialize_platform_name(FLAGS.jax_platform_name)
else:
try:
xla_client.initialize_platform_name('CUDA')
xla_client.initialize_platform_name(b'CUDA')
except RuntimeError:
warnings.warn('No GPU found, falling back to CPU.')
xla_client.initialize_platform_name('Host')
xla_client.initialize_platform_name(b'Host')
return xla_client

@ -42,6 +42,8 @@ class Store(object):
def __nonzero__(self):
return hasattr(self, '_val')
__bool__ = __nonzero__
@curry
def staged(f, *init_args):
@ -74,7 +76,7 @@ class WrappedFun(object):
stack = []
for gen, gen_args, out_store in self.transforms:
gen = gen(*(gen_args + tuple(args)))
args = gen.next()
args = next(gen)
stack.append((gen, out_store))
del gen

@ -15,7 +15,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import __builtin__
from six.moves import builtins
import six
import numpy as onp
@ -38,11 +39,11 @@ import jax.lax as lax
# We replace some builtin names to follow Numpy's API, so we capture here.
_all = __builtin__.all
_any = __builtin__.any
_max = __builtin__.max
_min = __builtin__.min
_sum = __builtin__.sum
_all = builtins.all
_any = builtins.any
_max = builtins.max
_min = builtins.min
_sum = builtins.sum
# We need some numpy scalars
# TODO(mattjj): handle constants in an indirected, less explicit way?
@ -985,7 +986,7 @@ def _canonicalize_tuple_index(arr, idx):
if ellipsis_index is not None:
if next(ellipses, None) is not None:
msg = "Multiple ellipses (...) not supported: {}."
raise IndexError(msg.format(map(type, idx)))
raise IndexError(msg.format(list(map(type, idx))))
colons = (slice(None),) * (arr.ndim - len_without_none)
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
elif len_without_none < arr.ndim:

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from six.moves import reduce
class PrettyPrint(object):
"""Crude Hughes-inspired pretty printer."""

@ -58,7 +58,7 @@ class PRNGKey(object):
def from_keypair(cls, keypair):
"""Internal method to create a PRNGKey instance from a raw key pair."""
new = cls.__new__(cls)
new.keypair = keypair
new.keypair = tuple(keypair)
return new
@ -83,7 +83,7 @@ def _make_rotate_left(dtype):
def _bit_stats(bits):
"""This is a debugging function to compute the statistics of bit fields."""
return onp.array([map(int, onp.binary_repr(x, 64)) for x in bits]).mean(0)
return onp.array([list(map(int, onp.binary_repr(x, 64))) for x in bits]).mean(0)
### hash function and split

@ -15,6 +15,7 @@
from __future__ import absolute_import
import functools
import re
from absl import flags
from absl.testing import absltest
@ -277,6 +278,13 @@ def check_raises(thunk, err_type, msg):
except err_type as e:
assert str(e) == msg, "{}\n\n{}\n".format(e, msg)
def check_raises_regexp(thunk, err_type, pattern):
try:
thunk()
assert False
except err_type as e:
assert re.match(pattern, str(e)), "{}\n\n{}\n".format(e, pattern)
class JaxTestCase(parameterized.TestCase):
"""Base class for JAX tests including numerical checks and boilerplate."""

@ -16,8 +16,11 @@ from __future__ import absolute_import
from collections import namedtuple
import itertools as it
from six.moves import reduce
from .util import unzip2, concatenate, partial
from .util import unzip2, concatenate, partial, safe_map
map = safe_map
def tree_map(f, tree):

@ -24,16 +24,16 @@ allow_memoize_hash_failures = False
def safe_zip(*args):
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, 'length mismatch: {}'.format(map(len, args))
return zip(*args)
assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
return list(zip(*args))
def safe_map(f, *args):
args = map(list, args)
args = list(map(list, args))
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, 'length mismatch: {}'.format(map(len, args))
return map(f, *args)
assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
return list(map(f, *args))
def unzip2(xys):
@ -82,10 +82,10 @@ def toposort(end_node):
stack = [end_node]
while stack:
node = stack.pop()
if node in child_counts:
child_counts[node] += 1
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[node] = 1
child_counts[id(node)] = 1
stack.extend(node.parents)
sorted_nodes = []
@ -94,16 +94,16 @@ def toposort(end_node):
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in node.parents:
if child_counts[parent] == 1:
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
else:
child_counts[parent] -= 1
child_counts[id(parent)] -= 1
return sorted_nodes[::-1]
def split_merge(predicate, xs):
sides = map(predicate, xs)
sides = list(map(predicate, xs))
lhs = [x for x, s in zip(xs, sides) if s]
rhs = [x for x, s in zip(xs, sides) if not s]
def merge(new_lhs, new_rhs):

@ -16,6 +16,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import numpy as onp
from absl.testing import absltest
from jax import test_util as jtu
@ -102,11 +104,11 @@ class APITest(jtu.JaxTestCase):
def f(x):
return x
jtu.check_raises(lambda: grad(f)("foo"), TypeError,
"Argument 'foo' of type <type 'str'> is not a valid JAX type")
jtu.check_raises_regexp(lambda: grad(f)("foo"), TypeError,
"Argument 'foo' of type <.*'str'> is not a valid JAX type")
jtu.check_raises(lambda: jit(f)("foo"), TypeError,
"Argument 'foo' of type <type 'str'> is not a valid JAX type")
jtu.check_raises_regexp(lambda: jit(f)("foo"), TypeError,
"Argument 'foo' of type <.*'str'> is not a valid JAX type")
# TODO(dougalm): enable when we remove 'None' from pytree nodes
# def test_bad_output(self):
@ -176,12 +178,18 @@ class APITest(jtu.JaxTestCase):
return x
assert jit(f, static_argnums=(1,))(0, 5) == 10
jtu.check_raises(lambda: jit(f)(0, 5), TypeError, concretization_err_msg(int))
jtu.check_raises_regexp(
lambda: jit(f)(0, 5), TypeError,
"('JaxprTracer' object cannot be interpreted as an integer"
"|Abstract value passed to function.*)")
def test_casts(self):
for castfun in [float, int, long, complex, hex, oct]:
for castfun in [float, complex, hex, oct] + list(six.integer_types):
f = lambda x: castfun(x)
jtu.check_raises(lambda: jit(f)(0), TypeError, concretization_err_msg(castfun))
jtu.check_raises_regexp(
lambda: jit(f)(0), TypeError,
"('JaxprTracer' object cannot be interpreted as an integer"
"|Abstract value passed to function.*)")
def test_unimplemented_interpreter_rules(self):
foo_p = Primitive('foo')

@ -65,19 +65,23 @@ class LaxRandomTest(jtu.JaxTestCase):
# We test the hash by comparing to known values provided in the test code of
# the original reference implementation of Threefry. For the values, see
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
expected = ("0x6b200159L", "0x99ba4efeL")
def result_to_hex(result):
return tuple([hex(x.copy()).rstrip("L") for x in result])
expected = ("0x6b200159", "0x99ba4efe")
result = random.threefry_2x32(onp.uint32([0, 0]), onp.uint32([0, 0]))
self.assertEqual(expected, tuple(map(hex, result)))
expected = ("0x1cb996fcL", "0xbb002be7L")
self.assertEqual(expected, result_to_hex(result))
expected = ("0x1cb996fc", "0xbb002be7")
result = random.threefry_2x32(onp.uint32([-1, -1]), onp.uint32([-1, -1]))
self.assertEqual(expected, tuple(map(hex, result)))
self.assertEqual(expected, result_to_hex(result))
expected = ("0xc4923a9cL", "0x483df7a0L")
expected = ("0xc4923a9c", "0x483df7a0")
result = random.threefry_2x32(
onp.uint32([0x13198a2e, 0x03707344]),
onp.uint32([0x243f6a88, 0x85a308d3]))
self.assertEqual(expected, tuple(map(hex, result)))
self.assertEqual(expected, result_to_hex(result))
@parameterized.named_parameters(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}