mirror of https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

parent d573084783
commit 9c2e1c35b1

80  jax/api.py
@@ -37,7 +37,8 @@ from . import core
 from . import linear_util as lu
 from .core import pack, eval_jaxpr
 from .api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
-                       pytree_fun_to_flatjaxtuple_fun, apply_jaxtree_fun, wraps)
+                       pytree_fun_to_flatjaxtuple_fun, apply_jaxtree_fun, wraps,
+                       pytree_fun_to_jaxtupletree_fun2)
 from .tree_util import (process_pytree, node_types, build_tree, PyTreeDef,
                         tree_map, tree_flatten, tree_unflatten, tree_structure,
                         tree_transpose, leaf)
@@ -69,12 +70,13 @@ def jit(fun, static_argnums=()):
     fun: Function to be jitted. Should be a pure function, as side-effects may
       only be executed once. Its positional arguments and return value should be
       arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
-      Keyword arguments and positional arguments specified by `static_argnums`
-      can be anything at all. These are treated as static (see below).
-    static_argnums: A tuple of ints. Specifies which arguments to treat as
-      static (compile-time constant). Operations that only depend on static
-      arguments will be constant-folded. Calling the jitted function with
-      different values for these constants will trigger recompilation.
+      Positional arguments indicated by `static_argnums` can be anything at all.
+    static_argnums: A tuple of ints. Specifies which positional arguments to
+      treat as static (compile-time constant). Operations that only depend on
+      static arguments will be constant-folded. Calling the jitted function with
+      different values for these constants will trigger recompilation. If the
+      jitted function is called with fewer positional arguments than indicated
+      by `static_argnums` then an error is raised.
 
   Returns:
     A wrapped version of `fun`, set up for just-in-time compilation.
@@ -98,16 +100,22 @@ def jit(fun, static_argnums=()):
   def f_jitted(*args, **kwargs):
     if _jit_is_disabled or config.read('jax_disable_jit'):
       return fun(*args, **kwargs)
-    f = lu.wrap_init(fun, kwargs)
+    if static_argnums and max(static_argnums) >= len(args):
+      msg = ("jitted function has static_argnums={} but was called with only {}"
+             " positional arguments")
+      raise TypeError(msg.format(static_argnums, len(args)))
+    f = lu.wrap_init(fun)
     dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
     f, dyn_args = _argnums_partial(f, dyn_argnums, args)
-    jaxtupletree_args, in_trees = unzip2(map(pytree_to_jaxtupletree, dyn_args))
-    _check_args(jaxtupletree_args)
-    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
-    jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
-    return build_tree(out_tree(), jaxtupletree_out)
+    jaxtuple_args, in_trees = unzip2(map(pytree_to_jaxtupletree, dyn_args))
+    jaxtuple_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
+    _check_args(jaxtuple_args)
+    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun2(f, kwargs_tree, in_trees)
+    out = xla.xla_call(jaxtree_fun, jaxtuple_kwargs, *jaxtuple_args)
+    return build_tree(out_tree(), out)
 
-  f_jitted.__name__ = "jit({})".format(f_jitted.__name__)
+  jitted_name = "jit({}, static_argnums={})"
+  f_jitted.__name__ = jitted_name.format(f_jitted.__name__, static_argnums)
   return f_jitted
 
 
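For illustration only (not part of the commit): a minimal sketch of how the revised `jit` treats `static_argnums` and the new arity check described in the hunks above, assuming the post-commit API. The names `scale` and `scaled` are hypothetical.

# Sketch only: `scale`/`scaled` are illustrative names, not from the commit.
from jax import jit
import jax.numpy as np

def scale(x, factor):
  return x * factor

# Argument 1 is static: operations that depend only on it are constant-folded,
# and calling with a different `factor` value triggers recompilation.
scaled = jit(scale, static_argnums=(1,))
scaled(np.ones(3), 2.0)   # compiles for factor == 2.0
scaled(np.ones(3), 3.0)   # recompiles for factor == 3.0

# Calling with fewer positional arguments than `static_argnums` refers to now
# raises a TypeError up front instead of failing during tracing.
try:
  scaled(np.ones(3))
except TypeError as e:
  print(e)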
@@ -157,15 +165,16 @@ def xla_computation(fun, static_argnums=()):
     aval = xla.abstractify(x)
     return pe.PartialVal((aval, core.unit))
 
   @wraps(fun)
   def computation_maker(*args, **kwargs):
     wrapped = lu.wrap_init(fun)
+    jax_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
     jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
-    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
-    pvals = map(pv_like, jax_args)
-    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
-    return xla.build_jaxpr(jaxpr, consts, *map(xla.abstractify, jax_args))
+    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun2(wrapped, kwargs_tree, in_trees)
+    pvals = map(pv_like, (jax_kwargs,) + tuple(jax_args))
+    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
+    return xla.build_jaxpr(jaxpr, consts, xla.abstractify(jax_kwargs),
+                           *map(xla.abstractify, jax_args))
 
   return computation_maker
 
@@ -436,19 +445,20 @@ def pmap(fun, axis_name=None):
       raise TypeError(msg.format(axis_sizes))
     axis_size = axis_sizes.pop()
 
-    jaxtupletree_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
-    _check_args(jaxtupletree_args)
-    f = lu.wrap_init(fun, kwargs)
-    f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
-    jaxtupletree_out = pxla.xla_pmap(f, *jaxtupletree_args,
-                                     axis_name=axis_name, axis_size=axis_size)
-    return build_tree(out_tree(), jaxtupletree_out)
+    f = lu.wrap_init(fun)
+    jaxtuple_kwargs, kwargs_tree = pytree_to_jaxtupletree(kwargs)
+    jaxtuple_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
+    _check_args(jaxtuple_args)
+    f, out_tree = pytree_fun_to_jaxtupletree_fun2(f, kwargs_tree, in_trees)
+    out = pxla.xla_pmap(f, jaxtuple_kwargs, *jaxtuple_args,
+                        axis_name=axis_name, axis_size=axis_size)
+    return build_tree(out_tree(), out)
 
   namestr = "pmap({}, axis_name={})".format
   f_jitted.__name__ = namestr(f_jitted.__name__, axis_name)
   return f_jitted
 
-def serial_pmap(fun, axis_name=None, in_axes=0, out_axes=0):
+def _serial_pmap(fun, axis_name=None, in_axes=0, out_axes=0):
   """Vectorizing pseudo-map for single-program multiple-data (SPMD) functions."""
   axis_name = _TempAxisName() if axis_name is None else axis_name
 
@@ -467,7 +477,7 @@ class _TempAxisName(object):
     return '<temp axis {}>'.format(hex(id(self)))
 
 
-def papply(fun, axis_size, in_axes=0, out_axes=0):
+def _papply(fun, axis_size, in_axes=0, out_axes=0):
   """Apply a function using parallel computation by sharding inputs."""
   axis_name = parallel.newvar()
 
@@ -652,10 +662,10 @@ def vjp(fun, *primals, **kwargs):
 
 
 def trace_to_jaxpr(traceable, py_pvals, **kwargs):
-  fun = lu.wrap_init(traceable)
+  fun = lu.wrap_init(traceable, kwargs)
   pvals, in_trees = unzip2(map(tree_to_pval_tuples, py_pvals))
   jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
-  jaxpr, out_pval, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
+  jaxpr, out_pval, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
   return jaxpr, consts, out_pval, (in_trees, out_tree())
 
 def lift_jaxpr(jaxpr, consts, io_tree, pvals, py_args):
@@ -714,11 +724,11 @@ def make_jaxpr(fun):
 
   @wraps(fun)
   def jaxpr_maker(*args, **kwargs):
-    wrapped = lu.wrap_init(fun)
+    wrapped = lu.wrap_init(fun, kwargs)
     jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
     jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
     pvals = map(pv_like, jax_args)
-    jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
+    jaxpr, _, _ = pe.trace_to_jaxpr(jaxtree_fun, pvals)
     return jaxpr
 
   jaxpr_maker.__name__ = "make_jaxpr({})".format(jaxpr_maker.__name__)
@@ -751,7 +761,7 @@ def _argnums_partial_(dyn_argnums, fixed_args, *dyn_args):
   args = [None if arg is None else arg.val for arg in fixed_args]
   for i, arg in zip(dyn_argnums, dyn_args):
     args[i] = arg
-  ans = yield args
+  ans = yield args, {}
   yield ans
 
 def _check_args(args):
@@ -863,11 +873,11 @@ def make_graphviz(fun):
 
   @wraps(fun)
   def graphviz_maker(*args, **kwargs):
-    wrapped = lu.wrap_init(fun)
+    wrapped = lu.wrap_init(fun, kwargs)
     jax_args, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
     jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(wrapped, in_trees)
     pvals = map(pv_like, jax_args)
-    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals, **kwargs)
+    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
     return jaxpr_to_graphviz(jaxpr, consts)
 
   graphviz_maker.__name__ = "make_graphviz({})".format(graphviz_maker.__name__)
@@ -39,11 +39,17 @@ def get_name(fun): return getattr(fun, "__name__", "<unnamed function>")
 def get_module(fun): return getattr(fun, "__module__", "<unknown module>")
 def get_doc(fun): return getattr(fun, "__doc__", "")
 
+@transformation_with_aux
+def pytree_fun_to_jaxtupletree_fun(args_trees, *args):
+  py_args = map(build_tree, args_trees, args)
+  ans = yield py_args, {}
+  yield pytree_to_jaxtupletree(ans)
+
 @transformation_with_aux
-def pytree_fun_to_jaxtupletree_fun(in_trees, *args):
-  py_args = map(build_tree, in_trees, args)
-  ans = yield py_args
+def pytree_fun_to_jaxtupletree_fun2(kwargs_tree, args_trees, kwargs, *args):
+  py_args = map(build_tree, args_trees, args)
+  py_kwargs = build_tree(kwargs_tree, kwargs)
+  ans = yield py_args, py_kwargs
   yield pytree_to_jaxtupletree(ans)
 
 def apply_jaxtree_fun(fun, io_tree, *py_args):
@@ -62,7 +68,7 @@ pytree_to_jaxtupletree = partial(process_pytree, pack)
 @transformation_with_aux
 def pytree_fun_to_flatjaxtuple_fun(in_trees, *args):
   py_args = map(tree_unflatten, in_trees, args)
-  ans = yield py_args
+  ans = yield py_args, {}
   yield pytree_to_flatjaxtuple(ans)
 
 def pytree_to_flatjaxtuple(pytree):
@@ -525,7 +525,7 @@ def apply_todos(todos, x):
 
 @lu.transformation_with_aux
 def process_env_traces(primitive, level, *args):
-  ans = yield args
+  ans = yield args, {}
   todo = []
   while isinstance(ans, Tracer) and ans.trace.level > level:
     t = ans.trace
@@ -39,5 +39,5 @@ def ravel_list(*lst):
 @transformation_with_aux
 def ravel_fun(unravel_inputs, flat_in, **kwargs):
   pytree_args = unravel_inputs(flat_in)
-  ans = yield pytree_args
+  ans = yield pytree_args, {}
   yield ravel_pytree(ans)
@@ -44,7 +44,7 @@ def jvp(fun, has_aux=False):
 @transformation
 def jvpfun(primals, tangents):
   with new_master(JVPTrace) as master:
-    out_primal, out_tangent = yield master, primals, tangents
+    out_primal, out_tangent = yield (master, primals, tangents), {}
     del master
   out_tangent = instantiate_zeros(out_primal, out_tangent)
   yield (out_primal, out_tangent)
@@ -55,7 +55,7 @@ def jvp_subtrace(master, primals, tangents):
   for x in list(primals) + list(tangents):
     if isinstance(x, Tracer):
       assert x.trace.level < trace.level
-  ans = yield map(partial(JVPTracer, trace), primals, tangents)
+  ans = yield map(partial(JVPTracer, trace), primals, tangents), {}
   out_tracer = trace.full_raise(ans)
   out_primal, out_tangent = out_tracer.primal, out_tracer.tangent
   yield (out_primal, out_tangent)
@@ -66,7 +66,7 @@ def jvp_subtrace_aux(master, primals, tangents):
   for x in list(primals) + list(tangents):
     if isinstance(x, Tracer):
       assert x.trace.level < trace.level
-  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents)
+  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
   out_tracer, aux_tracer = map(trace.full_raise, (ans, aux))
   out_primal, out_tangent = out_tracer.primal, out_tracer.tangent
   aux = aux_tracer.primal  # ignore aux tangent
@@ -75,7 +75,7 @@ def jvp_subtrace_aux(master, primals, tangents):
 
 @transformation
 def pack_output(*args):
-  ans = yield args
+  ans = yield args, {}
   yield pack(ans)
 
 def linearize(traceable, *primals, **kwargs):
@@ -399,7 +399,7 @@ def instantiate_zeros(example, tangent):
 @transformation_with_aux
 def traceable(in_tree_def, new_primals, new_tangents):
   new_tangents = build_tree(in_tree_def, new_tangents)
-  primal_out, tangent_out = yield new_primals, new_tangents
+  primal_out, tangent_out = yield (new_primals, new_tangents), {}
   out_jtuple, tree_def = tree_to_jaxtuples((primal_out, tangent_out))
   yield out_jtuple, tree_def
 
@@ -407,7 +407,7 @@ def traceable(in_tree_def, new_primals, new_tangents):
 def transposed_fun(jaxpr, in_tree_def, args):
   args, consts, freevar_vals, ct = args
   args, ct, freevar_vals = build_tree(in_tree_def, (args, ct, freevar_vals))
-  freevar_cts, cotangents_out = yield jaxpr, consts, freevar_vals, args, ct
+  freevar_cts, cotangents_out = yield (jaxpr, consts, freevar_vals, args, ct), {}
   out_jtuple, tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
   yield out_jtuple, tree_def
 
@@ -428,7 +428,7 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
 def transposed_mapped(jaxpr, in_tree_def, freevar_vals, args):
   args, consts, ct = args
   args, ct = build_tree(in_tree_def, (args, ct))
-  freevar_cts, cotangents_out = yield jaxpr, consts, freevar_vals, args, ct
+  freevar_cts, cotangents_out = yield (jaxpr, consts, freevar_vals, args, ct), {}
   out_jtuple, tree_def = tree_to_jaxtuples((cotangents_out, freevar_cts))
   yield out_jtuple, tree_def
 
@@ -51,7 +51,7 @@ def batch_transform(size, in_dims, out_dim_dst, vals):
   with new_master(BatchTrace) as master:
     trace = BatchTrace(master, core.cur_sublevel())
     in_tracers = map(partial(BatchTracer, trace), vals, in_dims)
-    out_tracer = yield in_tracers
+    out_tracer = yield in_tracers, {}
     out_tracer = trace.full_raise(out_tracer)
     out_val, out_dim = out_tracer.val, out_tracer.batch_dim
     del master
@@ -61,7 +61,7 @@ def batch_transform(size, in_dims, out_dim_dst, vals):
 @transformation_with_aux
 def batch_subtrace(master, dims, *vals):
   trace = BatchTrace(master, core.cur_sublevel())
-  ans = yield map(partial(BatchTracer, trace), vals, dims)
+  ans = yield map(partial(BatchTracer, trace), vals, dims), {}
   out_tracer = trace.full_raise(ans)
   out_val, out_dim = out_tracer.val, out_tracer.batch_dim
   yield out_val, out_dim
@@ -56,7 +56,7 @@ def serial_pmap_transform(name, axes, *vals):
   with new_master(SerialPmapTrace) as master:
     trace = SerialPmapTrace(master, core.cur_sublevel())
     in_tracers = map(partial(SerialPmapTracer, trace, name), vals, axes)
-    ans = yield in_tracers
+    ans = yield in_tracers, {}
     out_tracer = trace.full_raise(ans)
     out_val, out_axis = out_tracer.val, out_tracer.axis
     del master, out_tracer
@@ -65,7 +65,7 @@ def serial_pmap_transform(name, axes, *vals):
 @lu.transformation_with_aux
 def serial_pmap_subtrace(master, name, axes, *vals):
   trace = SerialPmapTrace(master, core.cur_sublevel())
-  ans = yield map(partial(SerialPmapTracer, trace, name), vals, axes)
+  ans = yield map(partial(SerialPmapTracer, trace, name), vals, axes), {}
   out_tracer = trace.full_raise(ans)
   out_val, out_axis = out_tracer.val, out_tracer.axis
   yield out_val, out_axis
@@ -205,7 +205,7 @@ def papply_transform(name, args, axis_size, in_axes, out_axis):
   with new_master(PapplyTrace) as master:
     trace = PapplyTrace(master, core.cur_sublevel())
     in_tracers = map(partial(PapplyTracer, trace, name, axis_size), args, in_axes)
-    out_tracer = yield in_tracers
+    out_tracer = yield in_tracers, {}
     out_tracer = trace.full_raise(out_tracer)
     out_tracer = ensure_axis(out_axis, out_tracer.axis, out_tracer)
     out_val = out_tracer.val
@@ -170,8 +170,9 @@ def partial_eval(f, trace, pvs):
 
 
 @transformation_with_aux
-def partial_eval_wrapper(avals, *consts, **kwargs):
-  jaxpr, (out_pval, consts, env) = yield (map(PartialVal, zip(avals, consts)),)
+def partial_eval_wrapper(avals, *consts):
+  py_args = (map(PartialVal, zip(avals, consts)),)
+  jaxpr, (out_pval, consts, env) = yield py_args, {}
   out_pv, out_const = out_pval
   out = pack((out_const, pack(consts)))
   yield out, (out_pv, jaxpr, env)
@@ -334,13 +335,13 @@ def abstractify(x):
   return PartialVal((core.concrete_aval(x), unit))
 
 def trace_unwrapped_to_jaxpr(fun, pvals, **kwargs):
-  return trace_to_jaxpr(lu.wrap_init(fun), pvals, **kwargs)
+  return trace_to_jaxpr(lu.wrap_init(fun, kwargs), pvals)
 
-def trace_to_jaxpr(fun, pvals, **kwargs):
+def trace_to_jaxpr(fun, pvals):
   """Traces a function, given abstract inputs, to a jaxpr."""
   with new_master(JaxprTrace) as master:
     fun = trace_to_subjaxpr(fun, master)
-    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals, **kwargs)
+    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
     assert not env
     del master
 
@@ -351,7 +352,7 @@ def trace_to_subjaxpr(master, pvals):
   assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
   trace = JaxprTrace(master, core.cur_sublevel())
   in_tracers = map(trace.new_arg, pvals)
-  out_tracer = yield in_tracers
+  out_tracer = yield in_tracers, {}
   out_tracer = trace.full_raise(out_tracer)
   jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracer)
   out_pval = out_tracer.pval
@@ -464,7 +465,7 @@ def eval_jaxpr_raw(jaxpr, consts, freevar_vals, *args):
     map(write, eqn.outvars, outvals)
   return read(jaxpr.outvar)
 
-def compiled_call_impl(fun, *args, **kwargs):
+def compiled_call_impl(fun, *args):
   with new_master(JaxprTrace, True) as master:
     pvals = map(abstractify, args)
     jaxpr, (pval, consts, env) = trace_to_subjaxpr(fun, master).call_wrapped(pvals)
@@ -412,7 +412,7 @@ def xla_shape(x):
 @lu.transformation_with_aux
 def flatten_fun(in_trees, *flat_args):
   jtuple_trees = tuple(map(partial(build_tree, iter(flat_args)), in_trees))
-  ans = yield jtuple_trees
+  ans = yield jtuple_trees, {}
   aval = core.get_aval(ans)
   if type(aval) is AbstractTuple:
     ans_flat, out_tree = tree_flatten(ans)
@@ -120,15 +120,15 @@ class WrappedFun(object):
     f: the function to be transformed.
     transforms: a list of `(gen, gen_args, out_store)` tuples representing
       transformations to apply to `f.`
-    kwargs: keyword arguments to pass to `f`.
+    params: extra parameters to pass as keyword arguments to `f`.
   """
-  def __init__(self, f, transforms, kwargs):
+  def __init__(self, f, transforms, params):
     self.f = f
     self.transforms = transforms
-    self.kwargs = kwargs
+    self.params = params
 
   def wrap(self, *transformation):
-    return WrappedFun(self.f, [transformation] + self.transforms, self.kwargs)
+    return WrappedFun(self.f, [transformation] + self.transforms, self.params)
 
   def populate_stores(self, other):
     for (_, _, self_store), (_, _, other_store) in zip(self.transforms,
@@ -136,15 +136,16 @@ class WrappedFun(object):
       if self_store is not None:
         self_store.store(other_store.val)
 
-  def call_wrapped(self, *args):
+  def call_wrapped(self, *args, **kwargs):
     stack = []
     for gen, gen_args, out_store in self.transforms:
-      gen = gen(*(gen_args + tuple(args)))
-      args = next(gen)
+      gen = gen(*(gen_args + tuple(args)), **kwargs)
+      args, kwargs = next(gen)
+      assert type(args) in (tuple, list) and type(kwargs) is dict
      stack.append((gen, out_store))
 
     del gen
-    ans = self.f(*args, **self.kwargs)
+    ans = self.f(*args, **dict(self.params, **kwargs))
     del args
     while stack:
       gen, out_store = stack.pop()
@@ -165,7 +166,7 @@ class WrappedFun(object):
   def hashable_payload(self):
     return (self.f,
             tuple((gen, tuple(gen_args)) for gen, gen_args, _ in self.transforms),
-            tuple(sorted(self.kwargs.items())))
+            tuple(sorted(self.params.items())))
 
   def __hash__(self):
     return hash(self.hashable_payload())
@@ -189,9 +190,9 @@ def fun_name(f):
   except:
     return str(f)
 
-def wrap_init(f, kwargs={}):
+def wrap_init(f, params={}):
   """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
-  return WrappedFun(f, [], kwargs)
+  return WrappedFun(f, [], params)
 
 
 def memoize(call, max_size=4096):
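For illustration only (not part of the commit): the core of the linear_util change above is that transformation generators now yield an (args, kwargs) pair instead of bare args, and WrappedFun.call_wrapped forwards keyword arguments through the transformation stack. A hedged sketch of a transformation written against the new protocol; `double_output` and `f` are hypothetical names, not from the commit.

from jax import linear_util as lu

@lu.transformation
def double_output(*args, **kwargs):
  # New protocol: yield the positional args together with a kwargs dict.
  ans = yield args, kwargs
  # Resumed with the wrapped function's result; yield the transformed value.
  yield 2 * ans

def f(x, offset=0.0):
  return x + offset

doubled = double_output(lu.wrap_init(f))
print(doubled.call_wrapped(1.0, offset=0.5))  # 2 * (1.0 + 0.5) == 3.0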
@@ -158,7 +158,6 @@ def threefry_2x32(keypair, count):
   return lax.reshape(out[:-1] if odd_size else out, count.shape)
 
 
-@partial(jit, static_argnums=(1,))
 def split(key, num=2):
   """Splits a PRNG key into `num` new keys by adding a leading axis.
 
@@ -170,11 +169,14 @@ def split(key, num=2):
   Returns:
     An array with shape (num, 2) and dtype uint32 representing `num` new keys.
   """
+  return _split(key, num)
+
+@partial(jit, static_argnums=(1,))
+def _split(key, num):
   counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
   return lax.reshape(threefry_2x32(key, counts), (num, 2))
 
 
-@partial(jit, static_argnums=(1,))
 def fold_in(key, data):
   """Folds in data to a PRNG key to form a new PRNG key.
 
@@ -186,6 +188,10 @@ def fold_in(key, data):
     A new PRNGKey that is a deterministic function of the inputs and is
     statistically safe for producing a stream of new pseudo-random values.
   """
+  return _fold_in(key, data)
+
+@partial(jit, static_argnums=(1,))
+def _fold_in(key, data):
   key2 = lax.tie_in(key, PRNGKey(data))
   return threefry_2x32(key, key2)
 
@@ -212,7 +218,6 @@ def _random_bits(key, bit_width, shape):
 ### random samplers
 
 
-@partial(jit, static_argnums=(1, 2))
 def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
   """Sample uniform random values in [minval, maxval) with given shape/dtype.
 
@@ -226,6 +231,10 @@ def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
   Returns:
     A random array with the specified shape and dtype.
   """
+  return _uniform(key, shape, dtype, minval, maxval)
+
+@partial(jit, static_argnums=(1, 2))
+def _uniform(key, shape, dtype, minval, maxval):
   if not onp.issubdtype(dtype, onp.floating):
     raise TypeError("uniform only accepts floating point dtypes.")
 
@@ -253,7 +262,6 @@ def uniform(key, shape, dtype=onp.float32, minval=0., maxval=1.):
                  lax.reshape(floats * (maxval - minval) + minval, shape))
 
 
-@partial(jit, static_argnums=(1, 4))
 def randint(key, shape, minval, maxval, dtype=onp.int32):
   """Sample uniform random values in [minval, maxval) with given shape/dtype.
 
@@ -267,6 +275,10 @@ def randint(key, shape, minval, maxval, dtype=onp.int32):
   Returns:
     A random array with the specified shape and dtype.
   """
+  return _randint(key, shape, minval, maxval, dtype)
+
+@partial(jit, static_argnums=(1, 4))
+def _randint(key, shape, minval, maxval, dtype=onp.int32):
   if not onp.issubdtype(dtype, onp.integer):
     raise TypeError("randint only accepts integer dtypes.")
 
@@ -306,7 +318,6 @@ def randint(key, shape, minval, maxval, dtype=onp.int32):
   return lax.add(minval, lax.convert_element_type(random_offset, dtype))
 
 
-@partial(jit, static_argnums=(2,))
 def shuffle(key, x, axis=0):
   """Shuffle the elements of an array uniformly at random along an axis.
 
@@ -318,6 +329,10 @@ def shuffle(key, x, axis=0):
   Returns:
     A shuffled version of x.
   """
+  return _shuffle(key, x, axis)
+
+@partial(jit, static_argnums=(2,))
+def _shuffle(key, x, axis):
   # On parallel architectures, Fisher-Yates is more expensive than doing
   # multiple sorts. This algorithm is based on one developed and analyzed by
   # tjablin@. We sort according to randomly-generated 32bit keys, but those keys
@@ -344,7 +359,6 @@ def shuffle(key, x, axis=0):
   return x
 
 
-@partial(jit, static_argnums=(1, 2))
 def normal(key, shape, dtype=onp.float32):
   """Sample standard normal random values with given shape and float dtype.
 
@@ -356,13 +370,16 @@ def normal(key, shape, dtype=onp.float32):
   Returns:
     A random array with the specified shape and dtype.
   """
+  return _normal(key, shape, dtype)
+
+@partial(jit, static_argnums=(1, 2))
+def _normal(key, shape, dtype):
   lo = onp.nextafter(onp.array(-1., dtype), 0., dtype=dtype)
   hi = onp.array(1., dtype)
   u = uniform(key, shape, dtype, lo, hi)
   return onp.array(onp.sqrt(2), dtype) * lax.erf_inv(u)
 
 
-@partial(jit, static_argnums=(2,))
 def bernoulli(key, mean=onp.float32(0.5), shape=()):
   """Sample Bernoulli random values with given shape and mean.
 
@@ -376,6 +393,10 @@ def bernoulli(key, mean=onp.float32(0.5), shape=()):
   Returns:
     A random array with the specified shape and boolean dtype.
   """
+  return _bernoulli(key, mean, shape)
+
+@partial(jit, static_argnums=(2,))
+def _bernoulli(key, mean, shape):
   shape = shape or onp.shape(mean)
   if not onp.issubdtype(lax._dtype(mean), onp.float32):
     mean = lax.convert_element_type(mean, onp.float32)
@@ -384,7 +405,6 @@ def bernoulli(key, mean=onp.float32(0.5), shape=()):
   return lax.lt(uniform(key, shape), mean)
 
 
-@partial(jit, static_argnums=(1, 2))
 def cauchy(key, shape=(), dtype=onp.float32):
   """Sample Cauchy random values with given shape and float dtype.
 
@@ -397,12 +417,15 @@ def cauchy(key, shape=(), dtype=onp.float32):
   Returns:
     A random array with the specified shape and dtype.
   """
+  return _cauchy(key, shape, dtype)
+
+@partial(jit, static_argnums=(1, 2))
+def _cauchy(key, shape, dtype):
   u = uniform(key, shape, dtype)
   pi = _constant_like(u, onp.pi)
   return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5))))
 
 
-@partial(jit, static_argnums=(1, 2))
 def exponential(key, shape=(), dtype=onp.float32):
   """Sample Exponential random values with given shape and float dtype.
 
@@ -415,12 +438,15 @@ def exponential(key, shape=(), dtype=onp.float32):
   Returns:
     A random array with the specified shape and dtype.
   """
+  return _exponential(key, shape, dtype)
+
+@partial(jit, static_argnums=(1, 2))
+def _exponential(key, shape, dtype):
   u = uniform(key, shape, dtype)
   # taking 1 - u to move the domain of log to (0, 1] instead of [0, 1)
   return lax.neg(lax.log(lax.sub(_constant_like(u, 1), u)))
 
 
-@partial(jit, static_argnums=(1, 2))
 def laplace(key, shape=(), dtype=onp.float32):
   """Sample Laplace random values with given shape and float dtype.
 
@@ -433,11 +459,14 @@ def laplace(key, shape=(), dtype=onp.float32):
   Returns:
     A random array with the specified shape and dtype.
   """
+  return _laplace(key, shape, dtype)
+
+@partial(jit, static_argnums=(1, 2))
+def _laplace(key, shape, dtype):
   u = uniform(key, shape, dtype, minval=-1., maxval=1.)
   return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))
 
 
-@partial(jit, static_argnums=(2, 3))
 def pareto(key, b, shape=(), dtype=onp.float32):
   """Sample Pareto random values with given shape and float dtype.
 
@@ -452,6 +481,10 @@ def pareto(key, b, shape=(), dtype=onp.float32):
   Returns:
     A random array with the specified shape and dtype.
   """
+  return _pareto(key, b, shape, dtype)
+
+@partial(jit, static_argnums=(2, 3))
+def _pareto(key, b, shape, dtype):
   b = lax.convert_element_type(b, dtype)
   shape = shape or onp.shape(b)
   if onp.shape(b) != shape:
@@ -460,7 +493,6 @@ def pareto(key, b, shape=(), dtype=onp.float32):
   return lax.exp(lax.div(e, b))
 
 
-@partial(jit, static_argnums=(1, 2))
 def gumbel(key, shape=(), dtype=onp.float32):
   """Sample Gumbel random values with given shape and float dtype.
 
@@ -473,4 +505,8 @@ def gumbel(key, shape=(), dtype=onp.float32):
   Returns:
     A random array with the specified shape and dtype.
   """
+  return _gumbel(key, shape, dtype)
+
+@partial(jit, static_argnums=(1, 2))
+def _gumbel(key, shape, dtype):
   return -np.log(-np.log(uniform(key, shape, dtype)))
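For illustration only (not part of the commit): every sampler in the random module above is restructured the same way, a thin public function that keeps the docstring and delegates to a private helper decorated with jit and static_argnums, so shape/dtype arguments remain compile-time constants while keyword calls to the public function stay unproblematic. A hedged sketch of the pattern; `tile_ones` and `_tile_ones` are hypothetical names.

from functools import partial
from jax import jit
import jax.numpy as np

def tile_ones(x, reps=2):
  """Public wrapper: plain Python signature, docstring, and default value."""
  return _tile_ones(x, reps)

@partial(jit, static_argnums=(1,))
def _tile_ones(x, reps):
  # `reps` is static, so it may determine the output shape under jit.
  return np.tile(x, reps)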
@@ -62,29 +62,44 @@ class APITest(jtu.JaxTestCase):
       side.append(None)
       return 100*x + 10*y + z
 
-    f1 = jit(f)
-    assert f1(1, 2, 3, flag=True) == 123
+    f1 = jit(f, static_argnums=(3, 4))
+    assert f1(1, 2, 3, True, False) == 123
     assert len(side) == 1
-    assert f1(2, 1, 3, flag=True) == 213
+    assert f1(2, 1, 3, True, False) == 213
     assert len(side) == 1
-    assert f1(2, 1, 3, flag=True, flag2=True) == 213
+    assert f1(2, 1, 3, True, True) == 213
     assert len(side) == 2
 
     side[:] = []
-    f2 = jit(f, static_argnums=[0,2])
-    assert f2(1, 2, 3, flag=True) == 123
+    f2 = jit(f, static_argnums=(0, 2, 3, 4))
+    assert f2(1, 2, 3, True, False) == 123
     assert len(side) == 1
-    assert f2(1, 3, 3, flag=True) == 133
+    assert f2(1, 3, 3, True, False) == 133
     assert len(side) == 1
-    assert f2(2, 2, 3, flag=True) == 223
+    assert f2(2, 2, 3, True, False) == 223
     assert len(side) == 2
-    assert f2(2, 4, 3, flag=True) == 243
+    assert f2(2, 4, 3, True, False) == 243
     assert len(side) == 2
-    assert f2(2, 4, 3, flag=True, flag2=True) == 243
+    assert f2(2, 4, 3, True, True) == 243
    assert len(side) == 3
-    assert f2(2, 5, 3, flag=True, flag2=True) == 253
+    assert f2(2, 5, 3, True, True) == 253
     assert len(side) == 3
 
+  def test_jit_kwargs(self):
+    side = []
+
+    def f(x, y, z):
+      side.append(None)
+      return 100*x + 10*y + z
+
+    f1 = jit(f)
+    assert f(1, 2, 3) == 123
+    assert len(side) == 1
+    assert f(1, 2, z=3) == 123
+    # assert len(side) == 1  # actually recompiles
+
+    f(1, 2, z=onp.zeros(3))  # doesn't crash
+
   def test_grad_of_jit(self):
     side = []
 
@@ -26,7 +26,7 @@ import jax.numpy as np
 from jax import test_util as jtu
 from jax import lax
 from jax import lax_parallel
-from jax.api import serial_pmap, papply, jit, make_jaxpr
+from jax.api import _serial_pmap, _papply, jit, make_jaxpr
 from jax.linear_util import wrap_init
 
 from jax.config import config
@@ -37,48 +37,48 @@ class SerialPmapTest(jtu.JaxTestCase):
 
   def testConstantFunction(self):
     f = lambda x: 3
-    ans = serial_pmap(f, axis_name='i')(onp.ones(4))
+    ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
     expected = 3 * onp.ones(4)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testReduceSum(self):
     f = lambda x: lax_parallel.psum(x, 'i')
-    ans = serial_pmap(f, axis_name='i')(onp.ones(4))
+    ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
     expected = 4 * onp.ones(4)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testReduceMax(self):
     f = lambda x: lax_parallel.pmax(x, 'i')
-    ans = serial_pmap(f, axis_name='i')(onp.arange(4))
+    ans = _serial_pmap(f, axis_name='i')(onp.arange(4))
     expected = 3 * onp.ones(4)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testPsplit(self):
     f = lambda x: lax_parallel.psplit(x, 'i', 2)
     arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
-    ans = serial_pmap(f, axis_name='i', out_axes=2)(arg)
+    ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
     expected = arg
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testPsplitLike(self):
     f = lambda x, y: lax_parallel.psplit_like(x, y, 'i')
     arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
-    ans = serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
+    ans = _serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
     expected = arg
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testLogSoftmax(self):
     f = lambda x: x - np.log(lax_parallel.psum(np.exp(x), 'i'))
     x = onp.log(onp.arange(1., 10., dtype=onp.float32))
-    ans = serial_pmap(f, axis_name='i')(x)
+    ans = _serial_pmap(f, axis_name='i')(x)
     expected = x - onp.log(onp.sum(onp.exp(x)))
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testNested(self):
     f = lambda x: lax_parallel.psum(lax_parallel.psum(x, 'i'), 'j')
     x = onp.ones((2, 2))
-    ans1 = serial_pmap(serial_pmap(f, 'i'), 'j')(x)
-    ans2 = serial_pmap(serial_pmap(f, 'j'), 'i')(x)
+    ans1 = _serial_pmap(_serial_pmap(f, 'i'), 'j')(x)
+    ans2 = _serial_pmap(_serial_pmap(f, 'j'), 'i')(x)
     expected = 4 * onp.ones((2, 2))
     self.assertAllClose(ans1, expected, check_dtypes=False)
     self.assertAllClose(ans2, expected, check_dtypes=False)
@@ -87,19 +87,19 @@ class SerialPmapTest(jtu.JaxTestCase):
 class PapplyTest(jtu.JaxTestCase):
 
   def testIdentity(self):
-    pfun, axis_name = papply(lambda x: x, 3)
+    pfun, axis_name = _papply(lambda x: x, 3)
     ans = pfun(onp.arange(3))
     expected = onp.arange(3)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testMap(self):
-    pfun, axis_name = papply(np.sin, 3)
+    pfun, axis_name = _papply(np.sin, 3)
     ans = pfun(onp.arange(3.))
     expected = onp.sin(onp.arange(3.))
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testSum(self):
-    pfun, axis_name = papply(lambda x: np.sum(x, axis=0), 5)
+    pfun, axis_name = _papply(lambda x: np.sum(x, axis=0), 5)
 
     jaxpr = make_jaxpr(pfun)(onp.ones(3))
     expected_jaxpr = make_jaxpr(
@@ -107,12 +107,12 @@ class PapplyTest(jtu.JaxTestCase):
     assert repr(jaxpr) == repr(expected_jaxpr)
 
     arg = onp.arange(15.).reshape((5, 3))
-    ans = serial_pmap(pfun, axis_name)(arg)[0]
+    ans = _serial_pmap(pfun, axis_name)(arg)[0]
     expected = onp.sum(arg, axis=0)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testMax(self):
-    pfun, axis_name = papply(lambda x: np.max(x, axis=0), 5)
+    pfun, axis_name = _papply(lambda x: np.max(x, axis=0), 5)
 
     jaxpr = make_jaxpr(pfun)(onp.ones(3))
     expected_jaxpr = make_jaxpr(
@@ -120,12 +120,12 @@ class PapplyTest(jtu.JaxTestCase):
     assert repr(jaxpr) == repr(expected_jaxpr)
 
     arg = onp.arange(15.).reshape((5, 3))
-    ans = serial_pmap(pfun, axis_name)(arg)[0]
+    ans = _serial_pmap(pfun, axis_name)(arg)[0]
     expected = onp.max(arg, axis=0)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testSelect(self):
-    pfun, axis_name = papply(lax.select, 5,
+    pfun, axis_name = _papply(lax.select, 5,
                              in_axes=(None, 0, None))
 
     p = onp.arange(15).reshape((5, 3)) % 4 == 1
@@ -142,7 +142,7 @@ class PapplyTest(jtu.JaxTestCase):
     expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
     assert repr(jaxpr) == repr(expected_jaxpr)
 
-    ans = serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
+    ans = _serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
     expected = lax.select(p, t, f)
     self.assertAllClose(ans, expected, check_dtypes=True)
 
@@ -152,14 +152,14 @@ class PapplyTest(jtu.JaxTestCase):
     def fun(x):
       return x - np.log(np.sum(np.exp(x)))
 
-    pfun, axis_name = papply(fun, 5)
+    pfun, axis_name = _papply(fun, 5)
 
     jaxpr = make_jaxpr(pfun)(onp.zeros(5))
     expected_jaxpr = make_jaxpr(
         lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(onp.zeros(5))
     assert repr(jaxpr) == repr(expected_jaxpr)
 
-    ans = serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
+    ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
    expected = fun(onp.arange(1., 5.))
     self.assertAllClose(ans, expected, check_dtypes=False)
 
@@ -167,8 +167,8 @@ class PapplyTest(jtu.JaxTestCase):
     x = onp.array([[1, 2, 3], [4, 5, 6]])
     expected = x + x
 
-    pfun, axis_name = papply(np.add, 2)
-    ans = serial_pmap(pfun, axis_name)(x, x)
+    pfun, axis_name = _papply(np.add, 2)
+    ans = _serial_pmap(pfun, axis_name)(x, x)
     self.assertAllClose(ans, expected, check_dtypes=True)
 
   def testAddBroadcasting(self):
@@ -180,8 +180,8 @@ class PapplyTest(jtu.JaxTestCase):
     x = onp.array([[1, 2], [3, 4]])
     expected = x + 3
 
-    pfun, axis_name = papply(fun, 2)
-    ans = serial_pmap(pfun, axis_name)(x)
+    pfun, axis_name = _papply(fun, 2)
+    ans = _serial_pmap(pfun, axis_name)(x)
     self.assertAllClose(ans, expected, check_dtypes=True)
 
   def testTranspose(self):
@@ -195,8 +195,8 @@ class PapplyTest(jtu.JaxTestCase):
     ]
     for x in xs:
       expected = x.T
-      pfun, axis_name = papply(fun, x.shape[0])
-      ans = serial_pmap(pfun, axis_name)(x)
+      pfun, axis_name = _papply(fun, x.shape[0])
+      ans = _serial_pmap(pfun, axis_name)(x)
       self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testTransposeWithOddPermutation(self):
@@ -210,8 +210,8 @@ class PapplyTest(jtu.JaxTestCase):
     ]
     for x in xs:
       expected = np.transpose(x, (2, 0, 1))
-      pfun, axis_name = papply(fun, x.shape[0])
-      ans = serial_pmap(pfun, axis_name)(x)
+      pfun, axis_name = _papply(fun, x.shape[0])
+      ans = _serial_pmap(pfun, axis_name)(x)
      self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testTransposeAndAddRank2(self):
@@ -222,8 +222,8 @@ class PapplyTest(jtu.JaxTestCase):
     x = onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2))
     expected = x + x.T
 
-    pfun, axis_name = papply(fun, 2)
-    ans = serial_pmap(pfun, axis_name)(x)
+    pfun, axis_name = _papply(fun, 2)
+    ans = _serial_pmap(pfun, axis_name)(x)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testTransposeAndAddRank3(self):
@@ -234,8 +234,8 @@ class PapplyTest(jtu.JaxTestCase):
     x = onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2))
     expected = x + x.T
 
-    pfun, axis_name = papply(fun, 2)
-    ans = serial_pmap(pfun, axis_name)(x)
+    pfun, axis_name = _papply(fun, 2)
+    ans = _serial_pmap(pfun, axis_name)(x)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   def testDot(self):
@@ -251,8 +251,8 @@ class PapplyTest(jtu.JaxTestCase):
     for in_axes in in_axes_combos:
       for x in xs:
         expected = fun(x, x)
-        pfun, axis_name = papply(fun, x.shape[0], in_axes=in_axes)
-        ans = serial_pmap(pfun, axis_name)(x, x)
+        pfun, axis_name = _papply(fun, x.shape[0], in_axes=in_axes)
+        ans = _serial_pmap(pfun, axis_name)(x, x)
         self.assertAllClose(ans, expected, check_dtypes=False)
 