mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
source sync
PiperOrigin-RevId: 222451919
This commit is contained in:
parent
fe4edf2839
commit
e180f08113
@ -12,4 +12,4 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from api import *
|
||||
from jax.api import *
|
||||
|
@ -15,6 +15,7 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import numpy as onp
|
||||
import six
|
||||
|
||||
from . import core
|
||||
from . import ad_util
|
||||
@ -52,7 +53,8 @@ class UnshapedArray(core.AbstractValue):
|
||||
_bool = _nonzero = concretization_function_error(bool)
|
||||
_float = concretization_function_error(float)
|
||||
_int = concretization_function_error(int)
|
||||
_long = concretization_function_error(long)
|
||||
if six.PY2:
|
||||
_long = concretization_function_error(long)
|
||||
_complex = concretization_function_error(complex)
|
||||
_hex = concretization_function_error(hex)
|
||||
_oct = concretization_function_error(oct)
|
||||
@ -94,7 +96,7 @@ class ShapedArray(UnshapedArray):
|
||||
elif self.dtype == other.dtype:
|
||||
return UnshapedArray(self.dtype)
|
||||
else:
|
||||
raise TypeError, other
|
||||
raise TypeError(other)
|
||||
|
||||
def str_short(self):
|
||||
dtypestr = onp.dtype(self.dtype).name
|
||||
@ -137,7 +139,7 @@ class ConcreteArray(ShapedArray):
|
||||
elif self.dtype == other.dtype:
|
||||
return UnshapedArray(self.dtype)
|
||||
else:
|
||||
raise TypeError, other
|
||||
raise TypeError(other)
|
||||
|
||||
def str_short(self):
|
||||
return str(self.val)
|
||||
|
@ -17,6 +17,9 @@ from __future__ import absolute_import
|
||||
from .core import JaxTuple, lattice_join
|
||||
from .interpreters.partial_eval import Primitive
|
||||
from .tree_util import register_pytree_node
|
||||
from .util import safe_map
|
||||
|
||||
map = safe_map
|
||||
|
||||
jaxval_adders = {}
|
||||
|
||||
|
@ -17,8 +17,9 @@ from __future__ import absolute_import
|
||||
from .core import pack
|
||||
from .tree_util import build_tree, process_pytree
|
||||
from .linear_util import transformation_with_aux
|
||||
from .util import unzip2, partial
|
||||
from .util import safe_map, unzip2, partial
|
||||
|
||||
map = safe_map
|
||||
|
||||
@transformation_with_aux
|
||||
def flatten_fun(in_trees, *args, **kwargs):
|
||||
|
@ -18,6 +18,7 @@ from operator import attrgetter
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple, Counter, defaultdict
|
||||
from weakref import ref
|
||||
import six
|
||||
import types
|
||||
|
||||
from . import linear_util as lu
|
||||
@ -248,7 +249,6 @@ class Tracer(object):
|
||||
def __oct__(self): return self.aval._oct(self)
|
||||
|
||||
|
||||
|
||||
def __getattr__(self, name):
|
||||
# if the aval property raises an AttributeError, gets caught here
|
||||
assert skip_checks or name != "aval"
|
||||
@ -263,7 +263,10 @@ class Tracer(object):
|
||||
if t is aval_property:
|
||||
return attr.fget(self)
|
||||
elif t is aval_method:
|
||||
return types.MethodType(attr.fun, self, None)
|
||||
if six.PY3:
|
||||
return types.MethodType(attr.fun, self)
|
||||
else:
|
||||
return types.MethodType(attr.fun, self, None)
|
||||
else:
|
||||
return attr
|
||||
|
||||
@ -345,7 +348,7 @@ def new_master(trace_type, bottom=False):
|
||||
t = ref(master)
|
||||
del master
|
||||
if t() is not None:
|
||||
print trace_stack
|
||||
print(trace_stack)
|
||||
raise Exception('Leaked trace {}'.format(t()))
|
||||
|
||||
|
||||
|
@ -182,6 +182,7 @@ class LapaxMatrix(object):
|
||||
__sub__ = _make_infix_op(lax.sub)
|
||||
__mul__ = _make_infix_op(lax.batch_matmul)
|
||||
__div__ = _make_infix_op(lax.div)
|
||||
__truediv__ = _make_infix_op(lax.div)
|
||||
T = property(_make_infix_op(_matrix_transpose))
|
||||
|
||||
|
||||
|
@ -185,4 +185,4 @@ def make_schedule(scalar_or_schedule_fun):
|
||||
elif np.ndim(scalar_or_schedule_fun) == 0:
|
||||
return constant(scalar_or_schedule_fun)
|
||||
else:
|
||||
raise TypeError, type(scalar_or_schedule_fun)
|
||||
raise TypeError(type(scalar_or_schedule_fun))
|
||||
|
@ -27,6 +27,7 @@ import operator as op
|
||||
|
||||
import numpy as onp
|
||||
import numpy.random as npr
|
||||
from six.moves import reduce
|
||||
|
||||
from jax import lax
|
||||
from jax import random
|
||||
|
@ -20,11 +20,14 @@ from .. import core as core
|
||||
from ..core import Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
|
||||
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, zero, Zero)
|
||||
from ..util import unzip2, unzip3, safe_zip, partial
|
||||
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
|
||||
from ..tree_util import process_pytree, build_tree, register_pytree_node
|
||||
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init
|
||||
|
||||
from six.moves import builtins, reduce
|
||||
|
||||
zip = safe_zip
|
||||
map = safe_map
|
||||
|
||||
def jvp(fun):
|
||||
return jvpfun(jvp_subtrace(fun))
|
||||
@ -102,13 +105,18 @@ def backward_pass(jaxpr, consts, freevar_vals, cotangent_in):
|
||||
for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs])
|
||||
cts_out, ct_free_vars_out = get_primitive_transpose(eqn.primitive)(
|
||||
eqn.params, subjaxprs, sub_consts, sub_freevar_vals, invals, ct_in)
|
||||
# TODO(dougalm): support cases != 1
|
||||
assert(len(eqn.bound_subjaxprs) == 1)
|
||||
_, _, bound_vars = eqn.bound_subjaxprs[0]
|
||||
map(write_cotangent, bound_vars, ct_free_vars_out)
|
||||
else:
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(ct_in, *invals, **eqn.params)
|
||||
|
||||
if cts_out is zero:
|
||||
cts_out = [zero for _ in eqn.invars]
|
||||
map(write_cotangent, eqn.invars, cts_out)
|
||||
# TODO(phawkins,dougalm): eqn.invars and cts_out can have different lengths
|
||||
for var, ct in builtins.zip(eqn.invars, cts_out):
|
||||
write_cotangent(var, ct)
|
||||
|
||||
cotangents_out = map(read_cotangent, jaxpr.invars)
|
||||
freevar_cts = map(read_cotangent, jaxpr.freevars)
|
||||
@ -195,7 +203,7 @@ class JVPTrace(Trace):
|
||||
elif xt is zero and yt is zero:
|
||||
return xt, yt
|
||||
else:
|
||||
raise TypeError, (xt, yt)
|
||||
raise TypeError((xt, yt))
|
||||
|
||||
def pack(self, tracers):
|
||||
primals = pack(t.primal for t in tracers)
|
||||
@ -344,7 +352,10 @@ def transposed_fun(jaxpr, in_tree_def, args):
|
||||
out_jtuple, tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
|
||||
yield out_jtuple, tree_def
|
||||
|
||||
def call_transpose(primitive, params, (jaxpr,), (consts,), (freevar_vals,), args, ct):
|
||||
def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
|
||||
jaxpr, = jaxpr
|
||||
consts, = consts
|
||||
freevar_vals, = freevar_vals
|
||||
assert isinstance(jaxpr, core.Jaxpr)
|
||||
assert all(a is None for a in args), "TODO(dougalm): handle non-tangent primal args"
|
||||
(ct, freevar_vals), in_tree_def = tree_to_jaxtuples((ct, freevar_vals))
|
||||
|
@ -18,6 +18,8 @@ import itertools as it
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from six.moves import reduce
|
||||
|
||||
from .. import core
|
||||
from ..core import Trace, Tracer, new_master, pack, AbstractTuple, JaxTuple
|
||||
from ..abstract_arrays import ShapedArray, make_shaped_array, array_types
|
||||
@ -83,7 +85,7 @@ class BatchTracer(Tracer):
|
||||
elif t is int:
|
||||
batch_dims = [self.batch_dim] * len(self.val)
|
||||
else:
|
||||
raise TypeError, t
|
||||
raise TypeError(t)
|
||||
return map(partial(BatchTracer, self.trace), self.val, batch_dims)
|
||||
|
||||
def full_lower(self):
|
||||
@ -150,7 +152,7 @@ def raise_to_shaped(aval):
|
||||
elif isinstance(aval, ShapedArray):
|
||||
return ShapedArray(aval.shape, aval.dtype)
|
||||
else:
|
||||
raise TypeError, type(aval)
|
||||
raise TypeError(type(aval))
|
||||
|
||||
def remove_batch_dim_from_aval(bdim, aval):
|
||||
t = type(aval)
|
||||
@ -167,7 +169,7 @@ def remove_batch_dim_from_aval(bdim, aval):
|
||||
unbatched_shape = tuple(onp.delete(aval.shape, bdim))
|
||||
return ShapedArray(unbatched_shape, aval.dtype)
|
||||
else:
|
||||
raise TypeError, t
|
||||
raise TypeError(t)
|
||||
|
||||
pytype_aval_mappings = {}
|
||||
|
||||
@ -259,7 +261,7 @@ def dimsize(dim, x):
|
||||
elif dim is None:
|
||||
return set()
|
||||
else:
|
||||
raise TypeError, type(dim)
|
||||
raise TypeError(type(dim))
|
||||
|
||||
def moveaxis(sz, dst, src, x):
|
||||
aval = get_aval(x)
|
||||
@ -284,7 +286,7 @@ def moveaxis(sz, dst, src, x):
|
||||
perm.insert(dst, src)
|
||||
return x.transpose(perm)
|
||||
else:
|
||||
raise TypeError, type(aval)
|
||||
raise TypeError(type(aval))
|
||||
|
||||
def broadcast(x, sz):
|
||||
try:
|
||||
|
@ -25,6 +25,8 @@ from ..core import (Trace, Tracer, new_master, Jaxpr, JaxprEqn, get_aval, pack,
|
||||
AbstractValue, AbstractTuple, unit, unitvar, Primitive,
|
||||
call_p)
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
class JaxprTrace(Trace):
|
||||
def pure(self, val):
|
||||
@ -212,7 +214,7 @@ def as_abstract_val(pv):
|
||||
elif isinstance(pv, JaxprTracerTuple):
|
||||
return AbstractTuple(map(as_abstract_val, pv))
|
||||
elif pv is None:
|
||||
raise TypeError, "{} is not abstract".format(pv)
|
||||
raise TypeError("{} is not abstract".format(pv))
|
||||
|
||||
|
||||
def partial_val_aval(pv, const):
|
||||
@ -283,7 +285,7 @@ def eqn_tracer_to_var(var, outvars, eqn):
|
||||
def tracers_to_jaxpr(in_tracers, out_tracer):
|
||||
newvar = gensym('')
|
||||
t_to_var = defaultdict(newvar)
|
||||
var = lambda t: t_to_var[t]
|
||||
var = lambda t: t_to_var[id(t)]
|
||||
sorted_tracers = toposort(out_tracer)
|
||||
invars = map(var, in_tracers)
|
||||
eqns = []
|
||||
@ -308,11 +310,11 @@ def tracers_to_jaxpr(in_tracers, out_tracer):
|
||||
destructuring_vars[key] = outvars
|
||||
else:
|
||||
outvars = destructuring_vars[key]
|
||||
t_to_var[t] = outvars[i]
|
||||
t_to_var[id(t)] = outvars[i]
|
||||
elif recipe is unit:
|
||||
t_to_var[t] = unitvar
|
||||
t_to_var[id(t)] = unitvar
|
||||
else:
|
||||
raise TypeError, recipe
|
||||
raise TypeError(recipe)
|
||||
|
||||
env_vars, env_vals = unzip2(env.items())
|
||||
const_vars, const_vals = unzip2(consts.items())
|
||||
@ -323,7 +325,7 @@ def tracers_to_jaxpr(in_tracers, out_tracer):
|
||||
|
||||
def gensym(suffix):
  """Return a zero-argument factory that mints fresh `Var` objects.

  Every call to the returned function constructs `Var(count, suffix)` using the
  next value of a counter private to that factory, so separate factories number
  their variables independently.

  Args:
    suffix: passed through unchanged as the second argument to `Var`.

  Returns:
    A callable taking no arguments and returning a new `Var`.
  """
  counter = it.count()
  # Use the builtin next(): it works on Python 2 and 3, whereas the iterator
  # method `counter.next()` exists only on Python 2.
  return lambda: Var(next(counter), suffix)
|
||||
|
||||
class Var(object):
|
||||
def __init__(self, count, suffix):
|
||||
@ -334,7 +336,7 @@ class Var(object):
|
||||
rem = self.count
|
||||
s = ''
|
||||
while True:
|
||||
rem, i = rem / 26, rem % 26
|
||||
rem, i = rem // 26, rem % 26
|
||||
s = chr(97 + i % 26) + s
|
||||
if not rem:
|
||||
break
|
||||
|
@ -18,13 +18,15 @@ from collections import namedtuple
|
||||
import itertools as it
|
||||
import numpy as onp
|
||||
import operator as op
|
||||
import six
|
||||
from six.moves import xrange
|
||||
|
||||
from absl import flags
|
||||
from .. import core
|
||||
from .. import ad_util
|
||||
from ..abstract_arrays import ConcreteArray, ShapedArray, make_shaped_array, array_types
|
||||
from ..core import AbstractTuple, JaxTuple, pack, valid_jaxtype
|
||||
from ..util import partial, partialmethod, memoize, unzip2, concatenate
|
||||
from ..util import partial, partialmethod, memoize, unzip2, concatenate, safe_map
|
||||
from ..linear_util import transformation_with_aux, memoize as linear_memoize
|
||||
from ..lib import xla_bridge as xb
|
||||
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
|
||||
@ -32,6 +34,7 @@ from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool('jax_device_values', True, 'Enable device-persistent values.')
|
||||
|
||||
map = safe_map
|
||||
|
||||
def apply_primitive(prim, *args, **kwargs):
|
||||
abstract_args = map(abstractify, args)
|
||||
@ -123,7 +126,8 @@ translations = {}
|
||||
|
||||
translations[core.pack_p] = lambda c, *xs: c.Tuple(*xs)
|
||||
translations[ad_util.add_jaxvals_p] = lambda c, x, y: c.Add(x, y)
|
||||
translations[core.call_p] = lambda c, (subc, a1), *a2: c.Call(subc, a1 + a2)
|
||||
translations[core.call_p] = lambda c, subc_a1, *a2: c.Call(subc_a1[0],
|
||||
subc_a1[1] + a2)
|
||||
translations[core.identity_p] = lambda c, x: x
|
||||
|
||||
|
||||
@ -242,7 +246,8 @@ class DeviceArray(DeviceValue):
|
||||
__bool__ = __nonzero__ = partialmethod(forward_to_value, bool)
|
||||
__float__ = partialmethod(forward_to_value, float)
|
||||
__int__ = partialmethod(forward_to_value, int)
|
||||
__long__ = partialmethod(forward_to_value, long)
|
||||
if six.PY2:
|
||||
__long__ = partialmethod(forward_to_value, long)
|
||||
__complex__ = partialmethod(forward_to_value, complex)
|
||||
__hex__ = partialmethod(forward_to_value, hex)
|
||||
__oct__ = partialmethod(forward_to_value, oct)
|
||||
@ -253,6 +258,14 @@ class DeviceArray(DeviceValue):
|
||||
# clobbered when jax.numpy is imported, but useful in tests
|
||||
def __eq__(self, other): return self._value == other
|
||||
|
||||
def __hash__(self):
|
||||
# TODO(mattjj): this is not semantically correct because it is possible
|
||||
# __eq__ is true for values with unequal __hash__ values. However, the
|
||||
# main use case at the moment is memoization for which false negatives are
|
||||
# fine.
|
||||
return id(self)
|
||||
|
||||
|
||||
core.pytype_aval_mappings[DeviceArray] = ConcreteArray
|
||||
pytype_aval_mappings[DeviceArray] = make_shaped_array
|
||||
canonicalize_dtype_handlers[DeviceArray] = identity
|
||||
@ -267,7 +280,7 @@ def xla_shape(x):
|
||||
if type(x) in (core.AbstractTuple, core.JaxTuple):
|
||||
return xb.Shape.tuple_shape(tuple(map(xla_shape, x)))
|
||||
else:
|
||||
raise TypeError, type(x)
|
||||
raise TypeError(type(x))
|
||||
|
||||
|
||||
# For callable XLA Computations (as opposed to, e.g., Computations used in the
|
||||
@ -298,7 +311,7 @@ def tree_flatten(maybe_tree):
|
||||
elif core.skip_checks or valid_jaxtype(maybe_tree):
|
||||
return [maybe_tree], leaf
|
||||
else:
|
||||
raise TypeError, type(maybe_tree)
|
||||
raise TypeError(type(maybe_tree))
|
||||
|
||||
JTupleTreeDef = namedtuple("JTupleTreeDef", ["child_specs"])
|
||||
|
||||
@ -313,7 +326,7 @@ def build_tree(xs, tree_spec):
|
||||
elif type(tree_spec) is JTupleTreeDef:
|
||||
return pack(map(partial(build_tree, xs), tree_spec.child_specs))
|
||||
else:
|
||||
raise TypeError, type(tree_spec)
|
||||
raise TypeError(type(tree_spec))
|
||||
|
||||
|
||||
def device_put(x):
|
||||
@ -364,4 +377,5 @@ xla_call = partial(core.call_bind, xla_call_p)
|
||||
xla_call_p.def_custom_bind(xla_call)
|
||||
xla_call_p.def_impl(xla_call_impl)
|
||||
|
||||
translations[xla_call_p] = lambda c, (subc, a1), *a2: c.Call(subc, a1 + a2)
|
||||
translations[xla_call_p] = lambda c, subc_a1, *a2: c.Call(subc_a1[0],
|
||||
subc_a1[1] + a2)
|
||||
|
46
jax/lax.py
46
jax/lax.py
@ -15,12 +15,13 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import __builtin__
|
||||
|
||||
import collections
|
||||
from .util import partial
|
||||
import itertools
|
||||
import operator
|
||||
import six
|
||||
from six.moves import builtins, xrange
|
||||
import string
|
||||
|
||||
import numpy as onp
|
||||
@ -40,9 +41,14 @@ from .util import curry, safe_zip, unzip2
|
||||
from .tree_util import build_tree
|
||||
from .lib import xla_bridge
|
||||
|
||||
# Private aliases for the Python builtins, reached through six.moves.builtins
# so the same spelling works on Python 2 and Python 3.
_max = builtins.max
# Previously aliased to `max`, which made every `_min(...)` call return the
# largest element instead of the smallest.
_min = builtins.min


if six.PY3:
  def maketrans(s1, s2):
    """Build a translation table mapping each char of `s1` to that of `s2`.

    Python 3 replacement for Python 2's `string.maketrans(s1, s2)`.
    """
    return s1.maketrans(s1, s2)
else:
  maketrans = string.maketrans
|
||||
|
||||
### traceables
|
||||
|
||||
@ -383,7 +389,7 @@ def _while_loop(cond_fun, body_fun, init_val):
|
||||
params = OpaqueParam((abs_out, cond_jaxpr, cond_consts, body_jaxpr, body_consts))
|
||||
out_flat = while_p.bind(init_val_flat, opaque_params=params)
|
||||
if out_tree() != in_tree:
|
||||
raise TypeError, "body_fun input and output must have identical structure"
|
||||
raise TypeError("body_fun input and output must have identical structure")
|
||||
return build_tree(out_tree(), out_flat)
|
||||
|
||||
class OpaqueParam(object):
|
||||
@ -500,8 +506,9 @@ def fori_loop(lower, upper, body_fun, init_val):
|
||||
# state: (upper limit, index, loop value)
|
||||
# The `lt` and `add` functions are added to the namespace programmatically.
|
||||
_, _, result = _while_loop(
|
||||
lambda (upper, i, _): lt(i, upper),
|
||||
lambda (upper, i, x): (upper, add(i, 1), body_fun(i, x)),
|
||||
lambda upper_i_x: lt(upper_i_x[1], upper_i_x[0]),
|
||||
lambda upper_i_x: (upper_i_x[0], add(upper_i_x[1], 1),
|
||||
body_fun(upper_i_x[1], upper_i_x[2])),
|
||||
(upper, lower, init_val))
|
||||
return result
|
||||
|
||||
@ -519,7 +526,7 @@ def foreach_loop(sequence, body_fun, init_val):
|
||||
"""
|
||||
_, result = fori_loop(
|
||||
0, len(sequence),
|
||||
lambda i, (seq, val): body_fun(seq[i], val),
|
||||
lambda i, seq_val: body_fun(seq_val[0][i], seq_val[1]),
|
||||
(sequence, init_val))
|
||||
return result
|
||||
|
||||
@ -708,7 +715,7 @@ standard_binop = partial(binop, _input_dtype)
|
||||
def _brcast(x, *others):
|
||||
# used in jvprules to make binop broadcasting explicit for transposability.
|
||||
# requires shape info during jvp tracing, which isn't strictly necessary.
|
||||
shapes = filter(None, map(onp.shape, (x,) + others))
|
||||
shapes = list(filter(None, map(onp.shape, (x,) + others)))
|
||||
shape = tuple(shapes and onp.max(shapes, axis=0))
|
||||
if onp.shape(x) != shape:
|
||||
return _brcast_to(x, shape)
|
||||
@ -1017,8 +1024,8 @@ def conv_general_dilated_transpose_rhs(
|
||||
lhs_sdims, rhs_sdims, out_sdims = _get_sdims(dimension_numbers)
|
||||
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
||||
trans_dimension_numbers = (_charswap("C", "N", lhs_spec),
|
||||
out_spec.translate(string.maketrans("NC", "IO")),
|
||||
rhs_spec.translate(string.maketrans("IO", "NC")))
|
||||
out_spec.translate(maketrans("NC", "IO")),
|
||||
rhs_spec.translate(maketrans("IO", "NC")))
|
||||
|
||||
padding = _conv_general_vjp_rhs_padding(
|
||||
onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims),
|
||||
@ -1543,9 +1550,8 @@ def slice_shape_rule(operand, start_indices, limit_indices, strides,
|
||||
msg = "slice strides must be positive, got {}"
|
||||
raise TypeError(msg.format(strides))
|
||||
|
||||
result_shape = onp.divide(onp.add(onp.subtract(limit_indices, start_indices),
|
||||
strides) - 1,
|
||||
strides)
|
||||
result_shape = onp.floor_divide(
|
||||
onp.add(onp.subtract(limit_indices, start_indices), strides) - 1, strides)
|
||||
return tuple(result_shape)
|
||||
|
||||
def slice_translation_rule(c, operand, start_indices, limit_indices, strides,
|
||||
@ -1911,7 +1917,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
|
||||
padding):
|
||||
pads = padtype_to_pads(operand_shape, window_dimensions, window_strides, padding)
|
||||
operand_padded = onp.add(operand_shape, onp.add(*zip(*pads)))
|
||||
t = onp.divide(onp.subtract(operand_padded, window_dimensions), window_strides) + 1
|
||||
t = onp.floor_divide(
|
||||
onp.subtract(operand_padded, window_dimensions), window_strides) + 1
|
||||
return tuple(t)
|
||||
|
||||
|
||||
@ -2127,7 +2134,7 @@ def _check_same_dtypes(name, ignore_fp_precision, *dtypes):
|
||||
"""Check that dtypes agree, possibly ignoring float precision."""
|
||||
# the `ignore_fp_precision` flag exists because the XLA shape inference logic
|
||||
# allows mixed floating point precision, but the HLO verifier often rejects it
|
||||
dtypes = map(onp.dtype, dtypes) # canonicalize
|
||||
dtypes = list(map(onp.dtype, dtypes)) # canonicalize
|
||||
if ignore_fp_precision:
|
||||
dtypes = [
|
||||
onp.floating if onp.issubdtype(dtype, onp.floating)
|
||||
@ -2172,7 +2179,8 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
|
||||
raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))
|
||||
|
||||
lhs_padded = onp.add(lhs_shape[2:], onp.add(*zip(*pads)))
|
||||
out_space = onp.divide(onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
|
||||
out_space = onp.floor_divide(
|
||||
onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
|
||||
out_space = onp.maximum(0, out_space)
|
||||
out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
|
||||
return tuple(out_shape)
|
||||
@ -2268,7 +2276,7 @@ def remaining(original, *removed_lists):
|
||||
|
||||
|
||||
def _charswap(a, b, s):
  """Return `s` with the characters of `a` and `b` exchanged.

  The translation table maps `a + b` onto `b + a`, so each character of `a` is
  replaced by the character at the same position in `b` and vice versa.

  Args:
    a: string of characters to swap out (the call site visible in this file
      passes a single character).
    b: string of replacement characters, same length as `a`.
    s: the string to translate.

  Returns:
    The translated copy of `s`.
  """
  # Go through the module-level `maketrans` shim rather than
  # `string.maketrans`, which does not exist on Python 3.
  return s.translate(maketrans(a + b, b + a))
|
||||
|
||||
|
||||
def _get_sdims(dimension_numbers):
|
||||
@ -2339,13 +2347,13 @@ def _eq_meet(a, b):
|
||||
|
||||
def maybe_tracer_tuple_to_abstract_tuple(tup):
  """Convert a (possibly nested) tracer tuple into a `core.AbstractTuple`.

  Args:
    tup: a `pe.JaxprTracerTuple`, a `core.AbstractValue`, or None.

  Returns:
    * for a `pe.JaxprTracerTuple`: a `core.AbstractTuple` whose elements are
      the recursively converted members;
    * for a `core.AbstractValue`: `tup` itself, unchanged;
    * for None: an empty `core.AbstractTuple`.

  Raises:
    TypeError: if `tup` is none of the above; the offending value is the
      exception argument.
  """
  if isinstance(tup, pe.JaxprTracerTuple):
    # Materialise the elements: on Python 3 a bare `map` is a one-shot
    # iterator, not a sequence.
    elements = [maybe_tracer_tuple_to_abstract_tuple(elt) for elt in tup]
    return core.AbstractTuple(elements)
  elif isinstance(tup, core.AbstractValue):
    return tup
  elif tup is None:
    return core.AbstractTuple(())  # TODO(dougalm): check this
  else:
    # Call syntax instead of the Python-2-only `raise TypeError, tup`.
    raise TypeError(tup)
|
||||
|
||||
|
||||
def subvals(lst, replace):
|
||||
|
@ -15,7 +15,6 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import __builtin__
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
@ -24,10 +23,12 @@ import numpy as onp
|
||||
import opt_einsum
|
||||
import scipy.special
|
||||
|
||||
_slice = __builtin__.slice
|
||||
_max = __builtin__.max
|
||||
_min = __builtin__.min
|
||||
_map = __builtin__.map
|
||||
from six.moves import builtins
|
||||
|
||||
_slice = builtins.slice
|
||||
_max = builtins.max
|
||||
_min = builtins.min
|
||||
_map = builtins.map
|
||||
|
||||
neg = onp.negative
|
||||
sign = onp.sign
|
||||
@ -87,13 +88,13 @@ sub = onp.subtract
|
||||
mul = onp.multiply
|
||||
|
||||
def div(lhs, rhs):
  """Divide `lhs` by `rhs`, rounding toward zero when `lhs` is integer-typed.

  Args:
    lhs: dividend (array-like).
    rhs: divisor (array-like).

  Returns:
    For an integer `lhs` dtype, the quotient truncated toward zero; otherwise
    the result of `onp.divide(lhs, rhs)`.
  """
  if onp.issubdtype(onp.result_type(lhs), onp.integer):
    # Use floor_divide explicitly: with true division in effect, onp.divide
    # would return floats for integer operands.
    quotient = onp.floor_divide(lhs, rhs)
    # floor_divide rounds toward negative infinity. Where the operands have
    # opposite signs and the division is inexact, step up by one so the result
    # rounds toward zero instead.
    select = onp.logical_and(onp.sign(lhs) != onp.sign(rhs),
                             onp.remainder(lhs, rhs) != 0)
    return onp.where(select, quotient + 1, quotient)
  else:
    return onp.divide(lhs, rhs)
|
||||
|
||||
def rem(lhs, rhs):
|
||||
return onp.sign(lhs) * onp.remainder(onp.abs(lhs), onp.abs(rhs))
|
||||
@ -150,14 +151,14 @@ dot = onp.dot
|
||||
|
||||
def dot_general(lhs, rhs, dimension_numbers):
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
new_id = itertools.count().next
|
||||
lhs_axis_ids = [new_id() for _ in lhs.shape]
|
||||
rhs_axis_ids = [new_id() for _ in rhs.shape]
|
||||
new_id = itertools.count()
|
||||
lhs_axis_ids = [next(new_id) for _ in lhs.shape]
|
||||
rhs_axis_ids = [next(new_id) for _ in rhs.shape]
|
||||
lhs_out_axis_ids = lhs_axis_ids[:]
|
||||
rhs_out_axis_ids = rhs_axis_ids[:]
|
||||
|
||||
for lhs_axis, rhs_axis in zip(lhs_contracting, rhs_contracting):
|
||||
shared_id = new_id()
|
||||
shared_id = next(new_id)
|
||||
lhs_axis_ids[lhs_axis] = shared_id
|
||||
rhs_axis_ids[rhs_axis] = shared_id
|
||||
lhs_out_axis_ids[lhs_axis] = None
|
||||
@ -165,7 +166,7 @@ def dot_general(lhs, rhs, dimension_numbers):
|
||||
|
||||
batch_ids = []
|
||||
for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
|
||||
shared_id = new_id()
|
||||
shared_id = next(new_id)
|
||||
lhs_axis_ids[lhs_axis] = shared_id
|
||||
rhs_axis_ids[rhs_axis] = shared_id
|
||||
lhs_out_axis_ids[lhs_axis] = None
|
||||
@ -215,7 +216,7 @@ select = onp.where
|
||||
def slice(operand, start_indices, limit_indices, strides=None):  # pylint: disable=redefined-builtin
  """Take a strided slice of `operand` along every leading axis listed.

  Args:
    operand: array to slice.
    start_indices: per-axis start positions.
    limit_indices: per-axis stop positions (exclusive, as for builtin slices).
    strides: per-axis step sizes; defaults to 1 for every axis in
      `start_indices`.

  Returns:
    `operand` indexed by one builtin slice object per axis.
  """
  if strides is None:
    strides = onp.ones(len(start_indices)).astype(int)
  # Index with a tuple of slice objects: on Python 3 `_map` is lazy, and a
  # non-tuple sequence is not a valid multi-axis index.
  slices = tuple(_map(_slice, start_indices, limit_indices, strides))
  return operand[slices]
|
||||
|
||||
def dynamic_slice(operand, start_indices, slice_sizes):
|
||||
@ -227,7 +228,7 @@ def dynamic_slice(operand, start_indices, slice_sizes):
|
||||
return out
|
||||
|
||||
def dynamic_update_slice(operand, update, start_indices):
  """Return a copy of `operand` with `update` written at `start_indices`.

  Args:
    operand: array to copy; it is not modified.
    update: array written into the copy; it must expose `.shape`.
    start_indices: per-axis position at which `update` begins.

  Returns:
    A new array equal to `operand` except in the window
    `[start, start + update.shape)` on each axis, which holds `update`.
  """
  # Build the window as a tuple of slices; a lazy Python 3 `map` object is not
  # a valid index.
  slices = tuple(_map(_slice, start_indices, onp.add(start_indices, update.shape)))
  updated_operand = onp.copy(operand)
  updated_operand[slices] = update
  return updated_operand
|
||||
@ -295,8 +296,8 @@ def _conv_view(lhs, rhs_shape, window_strides, pads, pad_value):
|
||||
out_strides = onp.multiply(window_strides, lhs.strides[2:])
|
||||
view_strides = lhs.strides[:1] + tuple(out_strides) + lhs.strides[1:]
|
||||
|
||||
out_shape = onp.divide(onp.subtract(in_shape, filter_shape),
|
||||
window_strides) + 1
|
||||
out_shape = onp.floor_divide(
|
||||
onp.subtract(in_shape, filter_shape), window_strides) + 1
|
||||
view_shape = lhs.shape[:1] + tuple(out_shape) + rhs_shape[1:]
|
||||
|
||||
view = onp.lib.stride_tricks.as_strided(lhs, view_shape, view_strides)
|
||||
|
@ -112,10 +112,10 @@ def get_xla_client():
|
||||
xla_client.initialize_platform_name(FLAGS.jax_platform_name)
|
||||
else:
|
||||
try:
|
||||
xla_client.initialize_platform_name('CUDA')
|
||||
xla_client.initialize_platform_name(b'CUDA')
|
||||
except RuntimeError:
|
||||
warnings.warn('No GPU found, falling back to CPU.')
|
||||
xla_client.initialize_platform_name('Host')
|
||||
xla_client.initialize_platform_name(b'Host')
|
||||
return xla_client
|
||||
|
||||
|
||||
|
@ -42,6 +42,8 @@ class Store(object):
|
||||
def __nonzero__(self):
|
||||
return hasattr(self, '_val')
|
||||
|
||||
__bool__ = __nonzero__
|
||||
|
||||
|
||||
@curry
|
||||
def staged(f, *init_args):
|
||||
@ -74,7 +76,7 @@ class WrappedFun(object):
|
||||
stack = []
|
||||
for gen, gen_args, out_store in self.transforms:
|
||||
gen = gen(*(gen_args + tuple(args)))
|
||||
args = gen.next()
|
||||
args = next(gen)
|
||||
stack.append((gen, out_store))
|
||||
|
||||
del gen
|
||||
|
@ -15,7 +15,8 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import __builtin__
|
||||
|
||||
from six.moves import builtins
|
||||
|
||||
import six
|
||||
import numpy as onp
|
||||
@ -38,11 +39,11 @@ import jax.lax as lax
|
||||
|
||||
|
||||
# We replace some builtin names to follow Numpy's API, so we capture here.
|
||||
_all = __builtin__.all
|
||||
_any = __builtin__.any
|
||||
_max = __builtin__.max
|
||||
_min = __builtin__.min
|
||||
_sum = __builtin__.sum
|
||||
_all = builtins.all
|
||||
_any = builtins.any
|
||||
_max = builtins.max
|
||||
_min = builtins.min
|
||||
_sum = builtins.sum
|
||||
|
||||
# We need some numpy scalars
|
||||
# TODO(mattjj): handle constants in an indirected, less explicit way?
|
||||
@ -985,7 +986,7 @@ def _canonicalize_tuple_index(arr, idx):
|
||||
if ellipsis_index is not None:
|
||||
if next(ellipses, None) is not None:
|
||||
msg = "Multiple ellipses (...) not supported: {}."
|
||||
raise IndexError(msg.format(map(type, idx)))
|
||||
raise IndexError(msg.format(list(map(type, idx))))
|
||||
colons = (slice(None),) * (arr.ndim - len_without_none)
|
||||
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
|
||||
elif len_without_none < arr.ndim:
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from six.moves import reduce
|
||||
|
||||
|
||||
class PrettyPrint(object):
|
||||
"""Crude Hughes-inspired pretty printer."""
|
||||
|
@ -58,7 +58,7 @@ class PRNGKey(object):
|
||||
def from_keypair(cls, keypair):
|
||||
"""Internal method to create a PRNGKey instance from a raw key pair."""
|
||||
new = cls.__new__(cls)
|
||||
new.keypair = keypair
|
||||
new.keypair = tuple(keypair)
|
||||
return new
|
||||
|
||||
|
||||
@ -83,7 +83,7 @@ def _make_rotate_left(dtype):
|
||||
|
||||
def _bit_stats(bits):
|
||||
"""This is a debugging function to compute the statistics of bit fields."""
|
||||
return onp.array([map(int, onp.binary_repr(x, 64)) for x in bits]).mean(0)
|
||||
return onp.array([list(map(int, onp.binary_repr(x, 64))) for x in bits]).mean(0)
|
||||
|
||||
|
||||
### hash function and split
|
||||
|
@ -15,6 +15,7 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import functools
|
||||
import re
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
@ -277,6 +278,13 @@ def check_raises(thunk, err_type, msg):
|
||||
except err_type as e:
|
||||
assert str(e) == msg, "{}\n\n{}\n".format(e, msg)
|
||||
|
||||
def check_raises_regexp(thunk, err_type, pattern):
  """Assert that `thunk()` raises `err_type` with a message matching `pattern`.

  Args:
    thunk: zero-argument callable that is expected to raise.
    err_type: exception class (or tuple of classes) to catch.
    pattern: regular expression applied with `re.match`, so it is anchored at
      the start of `str(exception)`.

  Raises:
    AssertionError: if `thunk` returns without raising, or if the exception
      message does not match `pattern`. Exceptions not covered by `err_type`
      propagate unchanged.
  """
  try:
    thunk()
  except err_type as e:
    assert re.match(pattern, str(e)), "{}\n\n{}\n".format(e, pattern)
  else:
    # Fail outside the `try`: if this assertion lived inside it, an `err_type`
    # of AssertionError (or a base class such as Exception) would catch it and
    # a permissive pattern would let a non-raising thunk pass silently.
    assert False, "expected {} to be raised".format(err_type)
|
||||
|
||||
|
||||
class JaxTestCase(parameterized.TestCase):
|
||||
"""Base class for JAX tests including numerical checks and boilerplate."""
|
||||
|
@ -16,8 +16,11 @@ from __future__ import absolute_import
|
||||
|
||||
from collections import namedtuple
|
||||
import itertools as it
|
||||
from six.moves import reduce
|
||||
|
||||
from .util import unzip2, concatenate, partial
|
||||
from .util import unzip2, concatenate, partial, safe_map
|
||||
|
||||
map = safe_map
|
||||
|
||||
|
||||
def tree_map(f, tree):
|
||||
|
22
jax/util.py
22
jax/util.py
@ -24,16 +24,16 @@ allow_memoize_hash_failures = False
|
||||
def safe_zip(*args):
  """Zip equal-length sequences into a list, asserting the lengths agree.

  Args:
    *args: one or more sized sequences (each must support `len`).

  Returns:
    A list of tuples, equal to `list(zip(*args))`.

  Raises:
    AssertionError: if any argument's length differs from the first one's;
      the message lists every argument's length.
  """
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  # Return a real list (not the lazy Python 3 zip object) so callers can index
  # the result and iterate over it more than once.
  return list(zip(*args))
|
||||
|
||||
|
||||
def safe_map(f, *args):
  """Map `f` over equal-length iterables, returning a list.

  Each argument is first copied into a list, so one-shot iterators are
  accepted and their lengths can be compared.

  Args:
    f: callable taking one positional argument per iterable in `args`.
    *args: one or more iterables of equal length.

  Returns:
    A list holding `f` applied to the elements at each position.

  Raises:
    AssertionError: if any iterable's length differs from the first one's;
      the message lists every length.
  """
  # Materialise every argument: a lazy Python 3 `map` object can be neither
  # subscripted nor measured with len().
  args = [list(arg) for arg in args]
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(map(f, *args))
|
||||
|
||||
|
||||
def unzip2(xys):
|
||||
@ -82,10 +82,10 @@ def toposort(end_node):
|
||||
stack = [end_node]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if node in child_counts:
|
||||
child_counts[node] += 1
|
||||
if id(node) in child_counts:
|
||||
child_counts[id(node)] += 1
|
||||
else:
|
||||
child_counts[node] = 1
|
||||
child_counts[id(node)] = 1
|
||||
stack.extend(node.parents)
|
||||
|
||||
sorted_nodes = []
|
||||
@ -94,16 +94,16 @@ def toposort(end_node):
|
||||
node = childless_nodes.pop()
|
||||
sorted_nodes.append(node)
|
||||
for parent in node.parents:
|
||||
if child_counts[parent] == 1:
|
||||
if child_counts[id(parent)] == 1:
|
||||
childless_nodes.append(parent)
|
||||
else:
|
||||
child_counts[parent] -= 1
|
||||
child_counts[id(parent)] -= 1
|
||||
|
||||
return sorted_nodes[::-1]
|
||||
|
||||
|
||||
def split_merge(predicate, xs):
|
||||
sides = map(predicate, xs)
|
||||
sides = list(map(predicate, xs))
|
||||
lhs = [x for x, s in zip(xs, sides) if s]
|
||||
rhs = [x for x, s in zip(xs, sides) if not s]
|
||||
def merge(new_lhs, new_rhs):
|
||||
|
@ -16,6 +16,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
import numpy as onp
|
||||
from absl.testing import absltest
|
||||
from jax import test_util as jtu
|
||||
@ -102,11 +104,11 @@ class APITest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
jtu.check_raises(lambda: grad(f)("foo"), TypeError,
|
||||
"Argument 'foo' of type <type 'str'> is not a valid JAX type")
|
||||
jtu.check_raises_regexp(lambda: grad(f)("foo"), TypeError,
|
||||
"Argument 'foo' of type <.*'str'> is not a valid JAX type")
|
||||
|
||||
jtu.check_raises(lambda: jit(f)("foo"), TypeError,
|
||||
"Argument 'foo' of type <type 'str'> is not a valid JAX type")
|
||||
jtu.check_raises_regexp(lambda: jit(f)("foo"), TypeError,
|
||||
"Argument 'foo' of type <.*'str'> is not a valid JAX type")
|
||||
|
||||
# TODO(dougalm): enable when we remove 'None' from pytree nodes
|
||||
# def test_bad_output(self):
|
||||
@ -176,12 +178,18 @@ class APITest(jtu.JaxTestCase):
|
||||
return x
|
||||
|
||||
assert jit(f, static_argnums=(1,))(0, 5) == 10
|
||||
jtu.check_raises(lambda: jit(f)(0, 5), TypeError, concretization_err_msg(int))
|
||||
jtu.check_raises_regexp(
|
||||
lambda: jit(f)(0, 5), TypeError,
|
||||
"('JaxprTracer' object cannot be interpreted as an integer"
|
||||
"|Abstract value passed to function.*)")
|
||||
|
||||
def test_casts(self):
|
||||
for castfun in [float, int, long, complex, hex, oct]:
|
||||
for castfun in [float, complex, hex, oct] + list(six.integer_types):
|
||||
f = lambda x: castfun(x)
|
||||
jtu.check_raises(lambda: jit(f)(0), TypeError, concretization_err_msg(castfun))
|
||||
jtu.check_raises_regexp(
|
||||
lambda: jit(f)(0), TypeError,
|
||||
"('JaxprTracer' object cannot be interpreted as an integer"
|
||||
"|Abstract value passed to function.*)")
|
||||
|
||||
def test_unimplemented_interpreter_rules(self):
|
||||
foo_p = Primitive('foo')
|
||||
|
@ -65,19 +65,23 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
# We test the hash by comparing to known values provided in the test code of
|
||||
# the original reference implementation of Threefry. For the values, see
|
||||
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
|
||||
expected = ("0x6b200159L", "0x99ba4efeL")
|
||||
def result_to_hex(result):
|
||||
return tuple([hex(x.copy()).rstrip("L") for x in result])
|
||||
|
||||
expected = ("0x6b200159", "0x99ba4efe")
|
||||
result = random.threefry_2x32(onp.uint32([0, 0]), onp.uint32([0, 0]))
|
||||
self.assertEqual(expected, tuple(map(hex, result)))
|
||||
|
||||
expected = ("0x1cb996fcL", "0xbb002be7L")
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
expected = ("0x1cb996fc", "0xbb002be7")
|
||||
result = random.threefry_2x32(onp.uint32([-1, -1]), onp.uint32([-1, -1]))
|
||||
self.assertEqual(expected, tuple(map(hex, result)))
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
expected = ("0xc4923a9cL", "0x483df7a0L")
|
||||
expected = ("0xc4923a9c", "0x483df7a0")
|
||||
result = random.threefry_2x32(
|
||||
onp.uint32([0x13198a2e, 0x03707344]),
|
||||
onp.uint32([0x243f6a88, 0x85a308d3]))
|
||||
self.assertEqual(expected, tuple(map(hex, result)))
|
||||
self.assertEqual(expected, result_to_hex(result))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
|
||||
|
Loading…
x
Reference in New Issue
Block a user