prevent jit from treating keyword args as static

fixes #523
Matthew Johnson 2019-04-10 22:09:14 -07:00
parent d573084783
commit 9c2e1c35b1
13 changed files with 197 additions and 128 deletions
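
To illustrate the change this commit makes: previously, keyword arguments to a jitted function were folded in at wrap time and so behaved as static (compile-time constant) values, recompiling for each new value; now they are flattened as a pytree and traced like positional arguments. A minimal sketch of the new behavior (hypothetical function, mirroring test_jit_kwargs below):

import numpy as onp
from jax import jit

@jit
def f(x, z=0):
    return 100 * x + z

f(1, z=3)             # z is traced, not treated as a compile-time constant
f(1, z=onp.zeros(3))  # array-valued kwargs no longer crash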

View File

@@ -37,7 +37,8 @@ from . import core
from . import linear_util as lu
from .core import pack, eval_jaxpr
from .api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
pytree_fun_to_flatjaxtuple_fun, apply_jaxtree_fun, wraps)
pytree_fun_to_flatjaxtuple_fun, apply_jaxtree_fun, wraps,
pytree_fun_to_jaxtupletree_fun2)
from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef,
tree_map, tree_flatten, tree_unflatten, tree_structure,
tree_transpose, leaf)
@@ -69,12 +70,13 @@ def jit(fun, static_argnums=()):
fun: Function to be jitted. Should be a pure function, as side-effects may
only be executed once. Its positional arguments and return value should be
arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
Keyword arguments and positional arguments specified by `static_argnums`
can be anything at all. These are treated as static (see below).
static_argnums: A tuple of ints. Specifies which arguments to treat as
static (compile-time constant). Operations that only depend on static
arguments will be constant-folded. Calling the jitted function with
different values for these constants will trigger recompilation.
Positional arguments indicated by `static_argnums` can be anything at all.
static_argnums: A tuple of ints. Specifies which positional arguments to
treat as static (compile-time constant). Operations that only depend on
static arguments will be constant-folded. Calling the jitted function with
different values for these constants will trigger recompilation. If the
jitted function is called with fewer positional arguments than indicated
by `static_argnums` then an error is raised.
Returns:
A wrapped version of `fun`, set up for just-in-time compilation.
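
As a hedged illustration of the revised contract (made-up function): static_argnums now applies only to positional arguments, and under-applying them is an error.

from functools import partial
from jax import jit

@partial(jit, static_argnums=(1,))
def power(x, n):
    return x ** n

power(2., 3)  # compiles with n fixed at 3
power(2., 4)  # new static value, so it recompiles
power(2.)     # TypeError: static_argnums=(1,) but only 1 positional argument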
@@ -98,16 +100,22 @@ def jit(fun, static_argnums=()):
def f_jitted(*args, **kwargs):
if _jit_is_disabled or config.read('jax_disable_jit'):
return fun(*args, **kwargs)
f = lu.wrap_init(fun, kwargs)
if static_argnums and max(static_argnums) >= len(args):
msg = ("jitted function has static_argnums={} but was called with only {}"
" positional arguments")
raise TypeError(msg.format(static_argnums, len(args)))
f = lu.wrap_init(fun)
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
f, dyn_args = _argnums_partial(f, dyn_argnums, args)
jaxtupletree_args, in_trees = unzip2(map(pytree_to_jaxtupletree, dyn_args))
_check_args(jaxtupletree_args)
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
return build_tree(out_tree(), jaxtupletree_out)
jaxtuple_args, in_trees = unzip2(map(pytree_to_jaxtupletree, dyn_args))
jaxtuple_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
_check_args(jaxtuple_args)
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun2(f, kwargs_tree, in_trees)
out = xla.xla_call(jaxtree_fun, jaxtuple_kwargs, *jaxtuple_args)
return build_tree(out_tree(), out)
f_jitted.__name__ = "jit({})".format(f_jitted.__name__)
jitted_name = "jit({}, static_argnums={})"
f_jitted.__name__ = jitted_name.format(f_jitted.__name__, static_argnums)
return f_jitted
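
The wiring above relies on the kwargs dict itself being a pytree: it is converted to a single jaxtupletree and handed to xla_call as one extra leading argument. A rough sketch of the flattening step using the public tree utilities (not the internal jaxtupletree form):

from jax.tree_util import tree_flatten, tree_unflatten

kwargs = {'z': 3, 'w': 4}
leaves, treedef = tree_flatten(kwargs)            # leaves plus the dict's structure
assert tree_unflatten(treedef, leaves) == kwargs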
@@ -157,15 +165,16 @@ def xla_computation(fun, static_argnums=()):
aval = xla.abstractify(x)
return pe.PartialVal((aval, core.unit))
@wraps(fun)
def computation_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun)
jax_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
pvals = map(pv_like, jax_args)
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
return xla.build_jaxpr(jaxpr, consts, *map(xla.abstractify, jax_args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun2(wrapped, kwargs_tree, in_trees)
pvals = map(pv_like, (jax_kwargs,) + tuple(jax_args))
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
return xla.build_jaxpr(jaxpr, consts, xla.abstractify(jax_kwargs),
*map(xla.abstractify, jax_args))
return computation_maker
@@ -436,19 +445,20 @@ def pmap(fun, axis_name=None):
raise TypeError(msg.format(axis_sizes))
axis_size = axis_sizes.pop()
jaxtupletree_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
_check_args(jaxtupletree_args)
f = lu.wrap_init(fun, kwargs)
f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
jaxtupletree_out = pxla.xla_pmap(f, *jaxtupletree_args,
axis_name=axis_name, axis_size=axis_size)
return build_tree(out_tree(), jaxtupletree_out)
f = lu.wrap_init(fun)
jaxtuple_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
jaxtuple_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
_check_args(jaxtuple_args)
f, out_tree = pytree_fun_to_jaxtupletree_fun2(f, kwargs_tree, in_trees)
out = pxla.xla_pmap(f, jaxtuple_kwargs, *jaxtuple_args,
axis_name=axis_name, axis_size=axis_size)
return build_tree(out_tree(), out)
namestr = "pmap({}, axis_name={})".format
f_jitted.__name__ = namestr(f_jitted.__name__, axis_name)
return f_jitted
def serial_pmap(fun, axis_name=None, in_axes=0, out_axes=0):
def _serial_pmap(fun, axis_name=None, in_axes=0, out_axes=0):
"""Vectorizing pseudo-map for single-program multiple-data (SPMD) functions."""
axis_name = _TempAxisName() if axis_name is None else axis_name
@@ -467,7 +477,7 @@ class _TempAxisName(object):
return '<temp axis {}>'.format(hex(id(self)))
def papply(fun, axis_size, in_axes=0, out_axes=0):
def _papply(fun, axis_size, in_axes=0, out_axes=0):
"""Apply a function using parallel computation by sharding inputs."""
axis_name = parallel.newvar()
@@ -652,10 +662,10 @@ def vjp(fun, *primals, **kwargs):
def trace_to_jaxpr(traceable, py_pvals, **kwargs):
fun = lu.wrap_init(traceable)
fun = lu.wrap_init(traceable, kwargs)
pvals, in_trees = unzip2(map(tree_to_pval_tuples, py_pvals))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
jaxpr, out_pval, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
jaxpr, out_pval, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
return jaxpr, consts, out_pval, (in_trees, out_tree())
def lift_jaxpr(jaxpr, consts, io_tree, pvals, py_args):
@@ -714,11 +724,11 @@ def make_jaxpr(fun):
@wraps(fun)
def jaxpr_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun)
wrapped = lu.wrap_init(fun, kwargs)
jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
pvals = map(pv_like, jax_args)
jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals)
return jaxpr
jaxpr_maker.__name__ = "make_jaxpr({})".format(jaxpr_maker.__name__)
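
Note the asymmetry with jit: here the call-time kwargs are bound onto the wrapped function via wrap_init(fun, kwargs), so for make_jaxpr (and make_graphviz below) they act as trace-time constants rather than traced values. A small sketch (hypothetical function):

import jax.numpy as np
from jax.api import make_jaxpr

def f(x, scale=2.):
    return scale * x

print(make_jaxpr(f)(np.ones(3)))            # scale enters as the constant 2.0
print(make_jaxpr(f)(np.ones(3), scale=3.))  # retraced with the constant 3.0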
@@ -751,7 +761,7 @@ def _argnums_partial_(dyn_argnums, fixed_args, *dyn_args):
args = [None if arg is None else arg.val for arg in fixed_args]
for i, arg in zip(dyn_argnums, dyn_args):
args[i] = arg
ans = yield args
ans = yield args, {}
yield ans
def _check_args(args):
@@ -863,11 +873,11 @@ def make_graphviz(fun):
@wraps(fun)
def graphviz_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun)
wrapped = lu.wrap_init(fun, kwargs)
jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
pvals = map(pv_like, jax_args)
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
return jaxpr_to_graphviz(jaxpr, consts)
graphviz_maker.__name__ = "make_graphviz({})".format(graphviz_maker.__name__)

View File

@@ -39,11 +39,17 @@ def get_name(fun): return getattr(fun, "__name__", "<unnamed function>")
def get_module(fun): return getattr(fun, "__module__", "<unknown module>")
def get_doc(fun): return getattr(fun, "__doc__", "")
@transformation_with_aux
def pytree_fun_to_jaxtupletree_fun(args_trees, *args):
py_args = map(build_tree, args_trees, args)
ans = yield py_args, {}
yield pytree_to_jaxtupletree(ans)
@transformation_with_aux
def pytree_fun_to_jaxtupletree_fun(in_trees, *args):
py_args = map(build_tree, in_trees, args)
ans = yield py_args
def pytree_fun_to_jaxtupletree_fun2(kwargs_tree, args_trees, kwargs, *args):
py_args = map(build_tree, args_trees, args)
py_kwargs = build_tree(kwargs_tree, kwargs)
ans = yield py_args, py_kwargs
yield pytree_to_jaxtupletree(ans)
def apply_jaxtree_fun(fun, io_tree, *py_args):
@@ -62,7 +68,7 @@ pytree_to_jaxtupletree = partial(process_pytree, pack)
@transformation_with_aux
def pytree_fun_to_flatjaxtuple_fun(in_trees, *args):
py_args = map(tree_unflatten, in_trees, args)
ans = yield py_args
ans = yield py_args, {}
yield pytree_to_flatjaxtuple(ans)
def pytree_to_flatjaxtuple(pytree):
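
Every transformation generator now yields an (args, kwargs) pair back to call_wrapped instead of bare args, which is the mechanical change repeated across the files below. A toy transformation under the updated protocol (hypothetical example, assuming the linear_util generator interface shown later in this commit):

from jax import linear_util as lu

@lu.transformation
def scale_output(factor, *args, **kwargs):
    ans = yield args, kwargs  # hand (args, kwargs) through to the wrapped function
    yield factor * ans        # post-process the result on the way back out

f = scale_output(lu.wrap_init(lambda x, y=1.: x + y), 10.)
print(f.call_wrapped(2., y=3.))  # 50.0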

View File

@@ -525,7 +525,7 @@ def apply_todos(todos, x):
@lu.transformation_with_aux
def process_env_traces(primitive, level, *args):
ans = yield args
ans = yield args, {}
todo = []
while isinstance(ans, Tracer) and ans.trace.level > level:
t = ans.trace

View File

@@ -39,5 +39,5 @@ def ravel_list(*lst):
@transformation_with_aux
def ravel_fun(unravel_inputs, flat_in, **kwargs):
pytree_args = unravel_inputs(flat_in)
ans = yield pytree_args
ans = yield pytree_args, {}
yield ravel_pytree(ans)

View File

@@ -44,7 +44,7 @@ def jvp(fun, has_aux=False):
@transformation
def jvpfun(primals, tangents):
with new_master(JVPTrace) as master:
out_primal, out_tangent = yield master, primals, tangents
out_primal, out_tangent = yield (master, primals, tangents), {}
del master
out_tangent = instantiate_zeros(out_primal, out_tangent)
yield (out_primal, out_tangent)
@@ -55,7 +55,7 @@ def jvp_subtrace(master, primals, tangents):
for x in list(primals) + list(tangents):
if isinstance(x, Tracer):
assert x.trace.level < trace.level
ans = yield map(partial(JVPTracer, trace), primals, tangents)
ans = yield map(partial(JVPTracer, trace), primals, tangents), {}
out_tracer = trace.full_raise(ans)
out_primal, out_tangent = out_tracer.primal, out_tracer.tangent
yield (out_primal, out_tangent)
@@ -66,7 +66,7 @@ def jvp_subtrace_aux(master, primals, tangents):
for x in list(primals) + list(tangents):
if isinstance(x, Tracer):
assert x.trace.level < trace.level
ans, aux = yield map(partial(JVPTracer, trace), primals, tangents)
ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
out_tracer, aux_tracer = map(trace.full_raise, (ans, aux))
out_primal, out_tangent = out_tracer.primal, out_tracer.tangent
aux = aux_tracer.primal # ignore aux tangent
@@ -75,7 +75,7 @@ def jvp_subtrace_aux(master, primals, tangents):
@transformation
def pack_output(*args):
ans = yield args
ans = yield args, {}
yield pack(ans)
def linearize(traceable, *primals, **kwargs):
@@ -399,7 +399,7 @@ def instantiate_zeros(example, tangent):
@transformation_with_aux
def traceable(in_tree_def, new_primals, new_tangents):
new_tangents = build_tree(in_tree_def, new_tangents)
primal_out, tangent_out = yield new_primals, new_tangents
primal_out, tangent_out = yield (new_primals, new_tangents), {}
out_jtuple, tree_def = tree_to_jaxtuples((primal_out, tangent_out))
yield out_jtuple, tree_def
@@ -407,7 +407,7 @@ def traceable(in_tree_def, new_primals, new_tangents):
def transposed_fun(jaxpr, in_tree_def, args):
args, consts, freevar_vals, ct = args
args, ct, freevar_vals = build_tree(in_tree_def, (args, ct, freevar_vals))
freevar_cts, cotangents_out = yield jaxpr, consts, freevar_vals, args, ct
freevar_cts, cotangents_out = yield (jaxpr, consts, freevar_vals, args, ct), {}
out_jtuple, tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
yield out_jtuple, tree_def
@@ -428,7 +428,7 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
def transposed_mapped(jaxpr, in_tree_def, freevar_vals, args):
args, consts, ct = args
args, ct = build_tree(in_tree_def, (args, ct))
freevar_cts, cotangents_out = yield jaxpr, consts, freevar_vals, args, ct
freevar_cts, cotangents_out = yield (jaxpr, consts, freevar_vals, args, ct), {}
out_jtuple, tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
yield out_jtuple, tree_def

View File

@@ -51,7 +51,7 @@ def batch_transform(size, in_dims, out_dim_dst, vals):
with new_master(BatchTrace) as master:
trace = BatchTrace(master, core.cur_sublevel())
in_tracers = map(partial(BatchTracer, trace), vals, in_dims)
out_tracer = yield in_tracers
out_tracer = yield in_tracers, {}
out_tracer = trace.full_raise(out_tracer)
out_val, out_dim = out_tracer.val, out_tracer.batch_dim
del master
@@ -61,7 +61,7 @@ def batch_transform(size, in_dims, out_dim_dst, vals):
@transformation_with_aux
def batch_subtrace(master, dims, *vals):
trace = BatchTrace(master, core.cur_sublevel())
ans = yield map(partial(BatchTracer, trace), vals, dims)
ans = yield map(partial(BatchTracer, trace), vals, dims), {}
out_tracer = trace.full_raise(ans)
out_val, out_dim = out_tracer.val, out_tracer.batch_dim
yield out_val, out_dim

View File

@@ -56,7 +56,7 @@ def serial_pmap_transform(name, axes, *vals):
with new_master(SerialPmapTrace) as master:
trace = SerialPmapTrace(master, core.cur_sublevel())
in_tracers = map(partial(SerialPmapTracer, trace, name), vals, axes)
ans = yield in_tracers
ans = yield in_tracers, {}
out_tracer = trace.full_raise(ans)
out_val, out_axis = out_tracer.val, out_tracer.axis
del master, out_tracer
@@ -65,7 +65,7 @@ def serial_pmap_transform(name, axes, *vals):
@lu.transformation_with_aux
def serial_pmap_subtrace(master, name, axes, *vals):
trace = SerialPmapTrace(master, core.cur_sublevel())
ans = yield map(partial(SerialPmapTracer, trace, name), vals, axes)
ans = yield map(partial(SerialPmapTracer, trace, name), vals, axes), {}
out_tracer = trace.full_raise(ans)
out_val, out_axis = out_tracer.val, out_tracer.axis
yield out_val, out_axis
@@ -205,7 +205,7 @@ def papply_transform(name, args, axis_size, in_axes, out_axis):
with new_master(PapplyTrace) as master:
trace = PapplyTrace(master, core.cur_sublevel())
in_tracers = map(partial(PapplyTracer, trace, name, axis_size), args, in_axes)
out_tracer = yield in_tracers
out_tracer = yield in_tracers, {}
out_tracer = trace.full_raise(out_tracer)
out_tracer = ensure_axis(out_axis, out_tracer.axis, out_tracer)
out_val = out_tracer.val

View File

@@ -170,8 +170,9 @@ def partial_eval(f, trace, pvs):
@transformation_with_aux
def partial_eval_wrapper(avals, *consts, **kwargs):
jaxpr, (out_pval, consts, env) = yield (map(PartialVal, zip(avals, consts)),)
def partial_eval_wrapper(avals, *consts):
py_args = (map(PartialVal, zip(avals, consts)),)
jaxpr, (out_pval, consts, env) = yield py_args, {}
out_pv, out_const = out_pval
out = pack((out_const, pack(consts)))
yield out, (out_pv, jaxpr, env)
@@ -334,13 +335,13 @@ def abstractify(x):
return PartialVal((core.concrete_aval(x), unit))
def trace_unwrapped_to_jaxpr(fun, pvals, **kwargs):
return trace_to_jaxpr(lu.wrap_init(fun), pvals, **kwargs)
return trace_to_jaxpr(lu.wrap_init(fun, kwargs), pvals)
def trace_to_jaxpr(fun, pvals, **kwargs):
def trace_to_jaxpr(fun, pvals):
"""Traces a function, given abstract inputs, to a jaxpr."""
with new_master(JaxprTrace) as master:
fun = trace_to_subjaxpr(fun, master)
jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals, **kwargs)
jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
assert not env
del master
@@ -351,7 +352,7 @@ def trace_to_subjaxpr(master, pvals):
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
trace = JaxprTrace(master, core.cur_sublevel())
in_tracers = map(trace.new_arg, pvals)
out_tracer = yield in_tracers
out_tracer = yield in_tracers, {}
out_tracer = trace.full_raise(out_tracer)
jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracer)
out_pval = out_tracer.pval
@@ -464,7 +465,7 @@ def eval_jaxpr_raw(jaxpr, consts, freevar_vals, *args):
map(write, eqn.outvars, outvals)
return read(jaxpr.outvar)
def compiled_call_impl(fun, *args, **kwargs):
def compiled_call_impl(fun, *args):
with new_master(JaxprTrace, True) as master:
pvals = map(abstractify, args)
jaxpr, (pval, consts, env) = trace_to_subjaxpr(fun, master).call_wrapped(pvals)

View File

@@ -412,7 +412,7 @@ def xla_shape(x):
@lu.transformation_with_aux
def flatten_fun(in_trees, *flat_args):
jtuple_trees = tuple(map(partial(build_tree, iter(flat_args)), in_trees))
ans = yield jtuple_trees
ans = yield jtuple_trees, {}
aval = core.get_aval(ans)
if type(aval) is AbstractTuple:
ans_flat, out_tree = tree_flatten(ans)

View File

@@ -120,15 +120,15 @@ class WrappedFun(object):
f: the function to be transformed.
transforms: a list of `(gen, gen_args, out_store)` tuples representing
transformations to apply to `f`.
kwargs: keyword arguments to pass to `f`.
params: extra parameters to pass as keyword arguments to `f`.
"""
def __init__(self, f, transforms, kwargs):
def __init__(self, f, transforms, params):
self.f = f
self.transforms = transforms
self.kwargs = kwargs
self.params = params
def wrap(self, *transformation):
return WrappedFun(self.f, [transformation] + self.transforms, self.kwargs)
return WrappedFun(self.f, [transformation] + self.transforms, self.params)
def populate_stores(self, other):
for (_, _, self_store), (_, _, other_store) in zip(self.transforms,
@@ -136,15 +136,16 @@ class WrappedFun(object):
if self_store is not None:
self_store.store(other_store.val)
def call_wrapped(self, *args):
def call_wrapped(self, *args, **kwargs):
stack = []
for gen, gen_args, out_store in self.transforms:
gen = gen(*(gen_args + tuple(args)))
args = next(gen)
gen = gen(*(gen_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
assert type(args) in (tuple, list) and type(kwargs) is dict
stack.append((gen, out_store))
del gen
ans = self.f(*args, **self.kwargs)
ans = self.f(*args, **dict(self.params, **kwargs))
del args
while stack:
gen, out_store = stack.pop()
@@ -165,7 +166,7 @@ class WrappedFun(object):
def hashable_payload(self):
return (self.f,
tuple((gen, tuple(gen_args)) for gen, gen_args, _ in self.transforms),
tuple(sorted(self.kwargs.items())))
tuple(sorted(self.params.items())))
def __hash__(self):
return hash(self.hashable_payload())
@@ -189,9 +190,9 @@ def fun_name(f):
except:
return str(f)
def wrap_init(f, kwargs={}):
def wrap_init(f, params={}):
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
return WrappedFun(f, [], kwargs)
return WrappedFun(f, [], params)
def memoize(call, max_size=4096):
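
The rename from kwargs to params marks the new split: params are fixed at wrap time and enter hashable_payload (hence compilation cache keys), while genuine call-time keyword arguments now flow through call_wrapped on each call. A small sketch:

from jax import linear_util as lu

def f(x, y=0, scale=1):
    return scale * (x + y)

wrapped = lu.wrap_init(f, {'scale': 2})  # bound at wrap time and hashed
print(wrapped.call_wrapped(1, y=3))      # call-time kwarg; prints 8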

View File

@@ -158,7 +158,6 @@ def threefry_2x32(keypair, count):
return lax.reshape(out[:-1] if odd_size else out, count.shape)
@partial(jit, static_argnums=(1,))
def split(key, num=2):
"""Splits a PRNG key into `num` new keys by adding a leading axis.
@@ -170,11 +169,14 @@ def split(key, num=2):
Returns:
An array with shape (num, 2) and dtype uint32 representing `num` new keys.
"""
return _split(key, num)
@partial(jit, static_argnums=(1,))
def _split(key, num):
counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
return lax.reshape(threefry_2x32(key, counts), (num, 2))
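
This split, repeated for every sampler below, is forced by the jit change: a decorated def split(key, num=2) would now trip the static_argnums arity check whenever num is left to its default, and would trace num if it were passed by keyword. Keeping defaults on a thin public wrapper and jitting an all-positional private function preserves the old static behavior. The shape of the pattern (hypothetical sampler):

from functools import partial
import numpy as onp
import jax.numpy as np
from jax import jit

def ones_sample(key, shape=(), dtype=onp.float32):  # defaults resolved in Python
    return _ones_sample(key, shape, dtype)

@partial(jit, static_argnums=(1, 2))                # shape/dtype always positional
def _ones_sample(key, shape, dtype):
    del key  # a real sampler would consume the PRNG key here
    return np.ones(shape, dtype)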
@partial(jit, static_argnums=(1,))
def fold_in(key, data):
"""Folds in data to a PRNG key to form a new PRNG key.
@@ -186,6 +188,10 @@ def fold_in(key, data):
A new PRNGKey that is a deterministic function of the inputs and is
statistically safe for producing a stream of new pseudo-random values.
"""
return _fold_in(key, data)
@partial(jit, static_argnums=(1,))
def _fold_in(key, data):
key2 = lax.tie_in(key, PRNGKey(data))
return threefry_2x32(key, key2)
@@ -212,7 +218,6 @@ def _random_bits(key, bit_width, shape):
### random samplers
@partial(jit, static_argnums=(1, 2))
def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
@@ -226,6 +231,10 @@ def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
Returns:
A random array with the specified shape and dtype.
"""
return _uniform(key, shape, dtype, minval, maxval)
@partial(jit, static_argnums=(1, 2))
def _uniform(key, shape, dtype, minval, maxval):
if not onp.issubdtype(dtype, onp.floating):
raise TypeError("uniform only accepts floating point dtypes.")
@@ -253,7 +262,6 @@ def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
lax.reshape(floats * (maxval - minval) + minval, shape))
@partial(jit, static_argnums=(1, 4))
def randint(key, shape, minval, maxval, dtype=onp.int32):
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
@@ -267,6 +275,10 @@ def randint(key, shape, minval, maxval, dtype=onp.int32):
Returns:
A random array with the specified shape and dtype.
"""
return _randint(key, shape, minval, maxval, dtype)
@partial(jit, static_argnums=(1, 4))
def _randint(key, shape, minval, maxval, dtype=onp.int32):
if not onp.issubdtype(dtype, onp.integer):
raise TypeError("randint only accepts integer dtypes.")
@@ -306,7 +318,6 @@ def randint(key, shape, minval, maxval, dtype=onp.int32):
return lax.add(minval, lax.convert_element_type(random_offset, dtype))
@partial(jit, static_argnums=(2,))
def shuffle(key, x, axis=0):
"""Shuffle the elements of an array uniformly at random along an axis.
@@ -318,6 +329,10 @@ def shuffle(key, x, axis=0):
Returns:
A shuffled version of x.
"""
return _shuffle(key, x, axis)
@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis):
# On parallel architectures, Fisher-Yates is more expensive than doing
# multiple sorts. This algorithm is based on one developed and analyzed by
# tjablin@. We sort according to randomly-generated 32bit keys, but those keys
@@ -344,7 +359,6 @@ def shuffle(key, x, axis=0):
return x
@partial(jit, static_argnums=(1, 2))
def normal(key, shape, dtype=onp.float32):
"""Sample standard normal random values with given shape and float dtype.
@@ -356,13 +370,16 @@ def normal(key, shape, dtype=onp.float32):
Returns:
A random array with the specified shape and dtype.
"""
return _normal(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _normal(key, shape, dtype):
lo = onp.nextafter(onp.array(-1., dtype), 0., dtype=dtype)
hi = onp.array(1., dtype)
u = uniform(key, shape, dtype, lo, hi)
return onp.array(onp.sqrt(2), dtype) * lax.erf_inv(u)
@partial(jit, static_argnums=(2,))
def bernoulli(key, mean=onp.float32(0.5), shape=()):
"""Sample Bernoulli random values with given shape and mean.
@@ -376,6 +393,10 @@ def bernoulli(key, mean=onp.float32(0.5), shape=()):
Returns:
A random array with the specified shape and boolean dtype.
"""
return _bernoulli(key, mean, shape)
@partial(jit, static_argnums=(2,))
def _bernoulli(key, mean, shape):
shape = shape or onp.shape(mean)
if not onp.issubdtype(lax._dtype(mean), onp.float32):
mean = lax.convert_element_type(mean, onp.float32)
@@ -384,7 +405,6 @@ def bernoulli(key, mean=onp.float32(0.5), shape=()):
return lax.lt(uniform(key, shape), mean)
@partial(jit, static_argnums=(1, 2))
def cauchy(key, shape=(), dtype=onp.float32):
"""Sample Cauchy random values with given shape and float dtype.
@@ -397,12 +417,15 @@ def cauchy(key, shape=(), dtype=onp.float32):
Returns:
A random array with the specified shape and dtype.
"""
return _cauchy(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _cauchy(key, shape, dtype):
u = uniform(key, shape, dtype)
pi = _constant_like(u, onp.pi)
return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5))))
@partial(jit, static_argnums=(1, 2))
def exponential(key, shape=(), dtype=onp.float32):
"""Sample Exponential random values with given shape and float dtype.
@@ -415,12 +438,15 @@ def exponential(key, shape=(), dtype=onp.float32):
Returns:
A random array with the specified shape and dtype.
"""
return _exponential(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _exponential(key, shape, dtype):
u = uniform(key, shape, dtype)
# taking 1 - u to move the domain of log to (0, 1] instead of [0, 1)
return lax.neg(lax.log(lax.sub(_constant_like(u, 1), u)))
@partial(jit, static_argnums=(1, 2))
def laplace(key, shape=(), dtype=onp.float32):
"""Sample Laplace random values with given shape and float dtype.
@@ -433,11 +459,14 @@ def laplace(key, shape=(), dtype=onp.float32):
Returns:
A random array with the specified shape and dtype.
"""
return _laplace(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _laplace(key, shape, dtype):
u = uniform(key, shape, dtype, minval=-1., maxval=1.)
return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))
@partial(jit, static_argnums=(2, 3))
def pareto(key, b, shape=(), dtype=onp.float32):
"""Sample Pareto random values with given shape and float dtype.
@@ -452,6 +481,10 @@ def pareto(key, b, shape=(), dtype=onp.float32):
Returns:
A random array with the specified shape and dtype.
"""
return _pareto(key, b, shape, dtype)
@partial(jit, static_argnums=(2, 3))
def _pareto(key, b, shape, dtype):
b = lax.convert_element_type(b, dtype)
shape = shape or onp.shape(b)
if onp.shape(b) != shape:
@@ -460,7 +493,6 @@ def pareto(key, b, shape=(), dtype=onp.float32):
return lax.exp(lax.div(e, b))
@partial(jit, static_argnums=(1, 2))
def gumbel(key, shape=(), dtype=onp.float32):
"""Sample Gumbel random values with given shape and float dtype.
@@ -473,4 +505,8 @@ def gumbel(key, shape=(), dtype=onp.float32):
Returns:
A random array with the specified shape and dtype.
"""
return _gumbel(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
def _gumbel(key, shape, dtype):
return -np.log(-np.log(uniform(key, shape, dtype)))

View File

@@ -62,29 +62,44 @@ class APITest(jtu.JaxTestCase):
side.append(None)
return 100*x + 10*y + z
f1 = jit(f)
assert f1(1, 2, 3, flag=True) == 123
f1 = jit(f, static_argnums=(3, 4))
assert f1(1, 2, 3, True, False) == 123
assert len(side) == 1
assert f1(2, 1, 3, flag=True) == 213
assert f1(2, 1, 3, True, False) == 213
assert len(side) == 1
assert f1(2, 1, 3, flag=True, flag2=True) == 213
assert f1(2, 1, 3, True, True) == 213
assert len(side) == 2
side[:] = []
f2 = jit(f, static_argnums=[0,2])
assert f2(1, 2, 3, flag=True) == 123
f2 = jit(f, static_argnums=(0, 2, 3, 4))
assert f2(1, 2, 3, True, False) == 123
assert len(side) == 1
assert f2(1, 3, 3, flag=True) == 133
assert f2(1, 3, 3, True, False) == 133
assert len(side) == 1
assert f2(2, 2, 3, flag=True) == 223
assert f2(2, 2, 3, True, False) == 223
assert len(side) == 2
assert f2(2, 4, 3, flag=True) == 243
assert f2(2, 4, 3, True, False) == 243
assert len(side) == 2
assert f2(2, 4, 3, flag=True, flag2=True) == 243
assert f2(2, 4, 3, True, True) == 243
assert len(side) == 3
assert f2(2, 5, 3, flag=True, flag2=True) == 253
assert f2(2, 5, 3, True, True) == 253
assert len(side) == 3
def test_jit_kwargs(self):
side = []
def f(x, y, z):
side.append(None)
return 100*x + 10*y + z
f = jit(f)
assert f(1, 2, 3) == 123
assert len(side) == 1
assert f(1, 2, z=3) == 123
# assert len(side) == 1 # actually recompiles
f(1, 2, z=onp.zeros(3)) # doesn't crash
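
The commented-out assertion is expected to fail: a positional call and a keyword call flatten to different signatures, so the compilation cache misses even when the values agree. Roughly (a sketch of the idea, not the exact cache key):

from jax.tree_util import tree_flatten

_, tree1 = tree_flatten({})        # kwargs tree for f(1, 2, 3)
_, tree2 = tree_flatten({'z': 3})  # kwargs tree for f(1, 2, z=3)
assert tree1 != tree2              # different trace signatures, hence the recompile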
def test_grad_of_jit(self):
side = []

View File

@@ -26,7 +26,7 @@ import jax.numpy as np
from jax import test_util as jtu
from jax import lax
from jax import lax_parallel
from jax.api import serial_pmap, papply, jit, make_jaxpr
from jax.api import _serial_pmap, _papply, jit, make_jaxpr
from jax.linear_util import wrap_init
from jax.config import config
@@ -37,48 +37,48 @@ class SerialPmapTest(jtu.JaxTestCase):
def testConstantFunction(self):
f = lambda x: 3
ans = serial_pmap(f, axis_name='i')(onp.ones(4))
ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
expected = 3 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceSum(self):
f = lambda x: lax_parallel.psum(x, 'i')
ans = serial_pmap(f, axis_name='i')(onp.ones(4))
ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
expected = 4 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceMax(self):
f = lambda x: lax_parallel.pmax(x, 'i')
ans = serial_pmap(f, axis_name='i')(onp.arange(4))
ans = _serial_pmap(f, axis_name='i')(onp.arange(4))
expected = 3 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsplit(self):
f = lambda x: lax_parallel.psplit(x, 'i', 2)
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
ans = serial_pmap(f, axis_name='i', out_axes=2)(arg)
ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
expected = arg
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsplitLike(self):
f = lambda x, y: lax_parallel.psplit_like(x, y, 'i')
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
ans = serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
ans = _serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
expected = arg
self.assertAllClose(ans, expected, check_dtypes=False)
def testLogSoftmax(self):
f = lambda x: x - np.log(lax_parallel.psum(np.exp(x), 'i'))
x = onp.log(onp.arange(1., 10., dtype=onp.float32))
ans = serial_pmap(f, axis_name='i')(x)
ans = _serial_pmap(f, axis_name='i')(x)
expected = x - onp.log(onp.sum(onp.exp(x)))
self.assertAllClose(ans, expected, check_dtypes=False)
def testNested(self):
f = lambda x: lax_parallel.psum(lax_parallel.psum(x, 'i'), 'j')
x = onp.ones((2, 2))
ans1 = serial_pmap(serial_pmap(f, 'i'), 'j')(x)
ans2 = serial_pmap(serial_pmap(f, 'j'), 'i')(x)
ans1 = _serial_pmap(_serial_pmap(f, 'i'), 'j')(x)
ans2 = _serial_pmap(_serial_pmap(f, 'j'), 'i')(x)
expected = 4 * onp.ones((2, 2))
self.assertAllClose(ans1, expected, check_dtypes=False)
self.assertAllClose(ans2, expected, check_dtypes=False)
@@ -87,19 +87,19 @@ class SerialPmapTest(jtu.JaxTestCase):
class PapplyTest(jtu.JaxTestCase):
def testIdentity(self):
pfun, axis_name = papply(lambda x: x, 3)
pfun, axis_name = _papply(lambda x: x, 3)
ans = pfun(onp.arange(3))
expected = onp.arange(3)
self.assertAllClose(ans, expected, check_dtypes=False)
def testMap(self):
pfun, axis_name = papply(np.sin, 3)
pfun, axis_name = _papply(np.sin, 3)
ans = pfun(onp.arange(3.))
expected = onp.sin(onp.arange(3.))
self.assertAllClose(ans, expected, check_dtypes=False)
def testSum(self):
pfun, axis_name = papply(lambda x: np.sum(x, axis=0), 5)
pfun, axis_name = _papply(lambda x: np.sum(x, axis=0), 5)
jaxpr = make_jaxpr(pfun)(onp.ones(3))
expected_jaxpr = make_jaxpr(
@@ -107,12 +107,12 @@ class PapplyTest(jtu.JaxTestCase):
assert repr(jaxpr) == repr(expected_jaxpr)
arg = onp.arange(15.).reshape((5, 3))
ans = serial_pmap(pfun, axis_name)(arg)[0]
ans = _serial_pmap(pfun, axis_name)(arg)[0]
expected = onp.sum(arg, axis=0)
self.assertAllClose(ans, expected, check_dtypes=False)
def testMax(self):
pfun, axis_name = papply(lambda x: np.max(x, axis=0), 5)
pfun, axis_name = _papply(lambda x: np.max(x, axis=0), 5)
jaxpr = make_jaxpr(pfun)(onp.ones(3))
expected_jaxpr = make_jaxpr(
@@ -120,12 +120,12 @@ class PapplyTest(jtu.JaxTestCase):
assert repr(jaxpr) == repr(expected_jaxpr)
arg = onp.arange(15.).reshape((5, 3))
ans = serial_pmap(pfun, axis_name)(arg)[0]
ans = _serial_pmap(pfun, axis_name)(arg)[0]
expected = onp.max(arg, axis=0)
self.assertAllClose(ans, expected, check_dtypes=False)
def testSelect(self):
pfun, axis_name = papply(lax.select, 5,
pfun, axis_name = _papply(lax.select, 5,
in_axes=(None, 0, None))
p = onp.arange(15).reshape((5, 3)) % 4 == 1
@@ -142,7 +142,7 @@ class PapplyTest(jtu.JaxTestCase):
expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
assert repr(jaxpr) == repr(expected_jaxpr)
ans = serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
ans = _serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
expected = lax.select(p, t, f)
self.assertAllClose(ans, expected, check_dtypes=True)
@@ -152,14 +152,14 @@ class PapplyTest(jtu.JaxTestCase):
def fun(x):
return x - np.log(np.sum(np.exp(x)))
pfun, axis_name = papply(fun, 5)
pfun, axis_name = _papply(fun, 5)
jaxpr = make_jaxpr(pfun)(onp.zeros(5))
expected_jaxpr = make_jaxpr(
lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(onp.zeros(5))
assert repr(jaxpr) == repr(expected_jaxpr)
ans = serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
expected = fun(onp.arange(1., 5.))
self.assertAllClose(ans, expected, check_dtypes=False)
@@ -167,8 +167,8 @@ class PapplyTest(jtu.JaxTestCase):
x = onp.array([[1, 2, 3], [4, 5, 6]])
expected = x + x
pfun, axis_name = papply(np.add, 2)
ans = serial_pmap(pfun, axis_name)(x, x)
pfun, axis_name = _papply(np.add, 2)
ans = _serial_pmap(pfun, axis_name)(x, x)
self.assertAllClose(ans, expected, check_dtypes=True)
def testAddBroadcasting(self):
@@ -180,8 +180,8 @@ class PapplyTest(jtu.JaxTestCase):
x = onp.array([[1, 2], [3, 4]])
expected = x + 3
pfun, axis_name = papply(fun, 2)
ans = serial_pmap(pfun, axis_name)(x)
pfun, axis_name = _papply(fun, 2)
ans = _serial_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=True)
def testTranspose(self):
@@ -195,8 +195,8 @@ class PapplyTest(jtu.JaxTestCase):
]
for x in xs:
expected = x.T
pfun, axis_name = papply(fun, x.shape[0])
ans = serial_pmap(pfun, axis_name)(x)
pfun, axis_name = _papply(fun, x.shape[0])
ans = _serial_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testTransposeWithOddPermutation(self):
@@ -210,8 +210,8 @@ class PapplyTest(jtu.JaxTestCase):
]
for x in xs:
expected = np.transpose(x, (2, 0, 1))
pfun, axis_name = papply(fun, x.shape[0])
ans = serial_pmap(pfun, axis_name)(x)
pfun, axis_name = _papply(fun, x.shape[0])
ans = _serial_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testTransposeAndAddRank2(self):
@@ -222,8 +222,8 @@ class PapplyTest(jtu.JaxTestCase):
x = onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2))
expected = x + x.T
pfun, axis_name = papply(fun, 2)
ans = serial_pmap(pfun, axis_name)(x)
pfun, axis_name = _papply(fun, 2)
ans = _serial_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testTransposeAndAddRank3(self):
@@ -234,8 +234,8 @@ class PapplyTest(jtu.JaxTestCase):
x = onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2))
expected = x + x.T
pfun, axis_name = papply(fun, 2)
ans = serial_pmap(pfun, axis_name)(x)
pfun, axis_name = _papply(fun, 2)
ans = _serial_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testDot(self):
@@ -251,8 +251,8 @@ class PapplyTest(jtu.JaxTestCase):
for in_axes in in_axes_combos:
for x in xs:
expected = fun(x, x)
pfun, axis_name = papply(fun, x.shape[0], in_axes=in_axes)
ans = serial_pmap(pfun, axis_name)(x, x)
pfun, axis_name = _papply(fun, x.shape[0], in_axes=in_axes)
ans = _serial_pmap(pfun, axis_name)(x, x)
self.assertAllClose(ans, expected, check_dtypes=False)