refactor call primitives, simpler param processing (#3491)

Matthew Johnson 2020-06-23 09:39:45 -07:00 committed by GitHub
parent d5a5d301f2
commit 75278309aa
11 changed files with 327 additions and 324 deletions

View File

@ -1150,16 +1150,11 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
for arg in args: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = pxla.xla_pmap(
flat_fun,
*args,
backend=backend,
axis_name=axis_name,
axis_size=local_axis_size,
global_axis_size=axis_size,
devices=tuple(devices) if devices is not None else devices,
name=flat_fun.__name__,
flat_fun, *args, backend=backend, axis_name=axis_name,
axis_size=local_axis_size, global_axis_size=axis_size,
devices=None if devices is None else tuple(devices),
mapped_invars=tuple(axis is not None for axis in in_axes_flat),
donated_invars=tuple(donated_invars))
name=flat_fun.__name__, donated_invars=tuple(donated_invars))
return tree_unflatten(out_tree(), out)
return f_pmapped

View File

@ -1062,7 +1062,7 @@ def canonicalize_shape(shape):
raise TypeError(msg.format(shape))
# ------------------- Call and map -------------------
# ------------------- Call -------------------
def apply_todos(todos, outs):
todos_list = list(todos)
@ -1071,7 +1071,7 @@ def apply_todos(todos, outs):
return outs
@lu.transformation_with_aux
def process_env_traces(post_processor: str, primitive: Primitive,
def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
level: int, params_tuple: tuple, *args):
outs = yield args, {}
params = dict(params_tuple)
@ -1084,41 +1084,59 @@ def process_env_traces(post_processor: str, primitive: Primitive,
break
trace = type(ans._trace)(ans._trace.master, cur_sublevel())
outs = map(trace.full_raise, outs)
post_process = getattr(trace, post_processor)
outs, cur_todo = post_process(primitive, outs, params)
outs, cur_todo = primitive.post_process(trace, outs, params)
todo.append(cur_todo)
yield outs, tuple(todo) # Ensure the aux output is immutable
def _call_bind(processor: str, post_processor: str, primitive: Primitive,
f: lu.WrappedFun, *args, **params):
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
fun: lu.WrappedFun, *args, **params):
params_tuple = tuple(params.items())
top_trace = find_top_trace(args)
level = trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level
params_tuple = tuple(params.items())
f, env_trace_todo = process_env_traces(f, post_processor, primitive, level, params_tuple)
fun, env_trace_todo = process_env_traces(fun, primitive, level, params_tuple)
if top_trace is None:
with new_sublevel():
outs = primitive.impl(f, *args, **params)
outs = primitive.impl(fun, *args, **params)
else:
tracers = map(top_trace.full_raise, args)
process = getattr(top_trace, processor)
outs = map(full_lower, process(primitive, f, tracers, params))
outs = primitive.process(top_trace, fun, tracers, params)
return apply_todos(env_trace_todo(), outs)
call_bind = partial(_call_bind, 'process_call', 'post_process_call')
map_bind = partial(_call_bind, 'process_map', 'post_process_map')
class CallPrimitive(Primitive):
multiple_results = True
call_primitive = True
bind = call_bind
def process(self, trace, fun, tracers, params):
return trace.process_call(self, fun, tracers, params)
def post_process(self, trace, out_tracers, params):
return trace.post_process_call(self, out_tracers, params)
def call_impl(f: lu.WrappedFun, *args, **params):
del params # params parameterize the call primitive, not the function
return f.call_wrapped(*args)
call_p = Primitive('call')
call_p.multiple_results = True
call_p.call_primitive = True
call = partial(call_bind, call_p)
call_p.def_custom_bind(call)
call_p = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)
# ------------------- Map -------------------
class MapPrimitive(Primitive):
multiple_results = True
map_primitive = True
def bind(self, fun, *args, **params):
assert len(params['mapped_invars']) == len(args)
return call_bind(self, fun, *args, **params)
def process(self, trace, fun, tracers, params):
return trace.process_map(self, fun, tracers, params)
def post_process(self, trace, out_tracers, params):
return trace.post_process_map(self, out_tracers, params)
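With this refactor, a new call or map primitive is defined by instantiating CallPrimitive or MapPrimitive and registering an impl; bind then routes through call_bind, which dispatches to the top trace via the primitive's process/post_process methods instead of the old processor-name strings. A minimal sketch of the pattern, assuming the post-refactor jax.core and a hypothetical primitive name (the real instances appear further down for remat_call, invertible_call, sharded_call, xla_call, and xla_pmap):

from jax import core
from jax import linear_util as lu

# Hypothetical call primitive following the new pattern (sketch only).
my_call_p = core.CallPrimitive('my_call')
my_call = my_call_p.bind
my_call_p.def_impl(core.call_impl)  # default impl just runs the wrapped function

# CallPrimitive.bind takes a WrappedFun returning a list of outputs, then the
# flat positional args; keyword arguments become the primitive's params.
out, = my_call(lu.wrap_init(lambda x: [x + 1.]), 2., name='my_call')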
# ------------------- Jaxpr checking -------------------
@ -1168,14 +1186,13 @@ def check_jaxpr(jaxpr: Jaxpr):
try:
_check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
except Exception as e:
exception_type = type(e)
msg_context = f"while checking jaxpr:\n\n{jaxpr}\n"
if len(e.args) == 0:
exception_args = [msg_context]
else:
msg = f"{e.args[0]}\n\n" + msg_context
msg = f"{e.args[0]}\n\n{msg_context}"
exception_args = [msg, *e.args[1:]]
raise exception_type(*exception_args) from e
raise type(e)(*exception_args) from e
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
@ -1203,6 +1220,11 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
map(write, jaxpr.invars, in_avals)
for eqn in jaxpr.eqns:
if eqn.primitive in skip_check_primitives:
map(write, eqn.outvars, [v.aval for v in eqn.outvars]) # skip checking
continue
in_avals = map(read, eqn.invars)
if eqn.primitive.call_primitive:
out_avals = check_call(eqn.primitive, in_avals, eqn.params)
@ -1218,6 +1240,8 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
map(read, jaxpr.outvars)
skip_check_primitives: Set[Primitive] = set()
def check_eqn(prim, in_avals, params):
for jaxpr in jaxprs_in_params(params):
check_jaxpr(jaxpr)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial, update_wrapper, reduce
from functools import update_wrapper, reduce
import inspect
import operator as op
@ -262,25 +262,30 @@ def _flatten_jvp(in_tree, *args):
raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, out_tree
def _custom_jvp_call_bind(prim, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
with core.new_sublevel():
outs = prim.impl(fun, jvp, *args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_jvp_call(prim, fun, jvp, tracers)
return map(core.full_lower, outs)
class CustomJVPCallPrimitive(core.CallPrimitive):
def bind(self, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
fun, env_trace_todo1 = core.process_env_traces(
fun, self, top_trace and top_trace.level, ())
jvp, env_trace_todo2 = core.process_env_traces(
jvp, self, top_trace and top_trace.level, ())
if top_trace is None:
with core.new_sublevel():
outs = self.impl(fun, jvp, *args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if env_trace_todo:
raise core.UnexpectedTracerError
return map(core.full_lower, outs)
def _custom_jvp_call_impl(fun, _, *args):
return fun.call_wrapped(*args)
def impl(self, fun, _, *args):
return fun.call_wrapped(*args)
custom_jvp_call_p = core.Primitive('custom_jvp_call')
custom_jvp_call_p.multiple_results = True
custom_jvp_call = partial(_custom_jvp_call_bind, custom_jvp_call_p)
custom_jvp_call_p.def_custom_bind(custom_jvp_call)
custom_jvp_call_p.def_impl(_custom_jvp_call_impl)
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
custom_jvp_call = custom_jvp_call_p.bind
def custom_jvp_call_jaxpr(fun, jvp, *args):
@ -501,28 +506,25 @@ def _flatten_bwd(in_tree, out_trees, *args):
raise TypeError(msg.format(in_tree2, in_tree)) from None
yield cts_in
def _custom_vjp_call_bind(prim, fun, fwd, bwd, *args, out_trees):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
with core.new_sublevel():
outs = prim.impl(fun, fwd, bwd, *args, out_trees=out_trees)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_vjp_call(prim, fun, fwd, bwd, tracers,
out_trees=out_trees)
outs = map(core.full_lower, outs)
return map(core.full_lower, outs)
def _custom_vjp_call_impl(fun, fwd, bwd, *args, out_trees):
del fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*args)
class CustomVJPCallPrimitive(core.CallPrimitive):
def bind(self, fun, fwd, bwd, *args, out_trees):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
outs = fun.call_wrapped(*args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
out_trees=out_trees)
return map(core.full_lower, outs)
custom_vjp_call_p = core.Primitive('custom_vjp_call')
custom_vjp_call_p.multiple_results = True
custom_vjp_call = partial(_custom_vjp_call_bind, custom_vjp_call_p)
custom_vjp_call_p.def_custom_bind(custom_vjp_call)
custom_vjp_call_p.def_impl(_custom_vjp_call_impl)
def impl(self, fun, fwd, bwd, *args, out_trees):
del fwd, bwd, out_trees
return fun.call_wrapped(*args)
custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call')
custom_vjp_call = custom_vjp_call_p.bind
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]

View File

@ -632,7 +632,8 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn,
pred1_and_token1,
xla.xla_call_p,
dict(call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_before"),
name="cond_before",
donated_invars=(False,) * (cond_nconsts + len(carry_invars) + 1)),
eqn.source_info))
# Make a new cond "lambda pred, carry, token: pred"
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
@ -667,14 +668,19 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn,
new_body_carry2 + [new_body_token2],
xla.xla_call_p,
dict(call_jaxpr=transformed_body_jaxpr.jaxpr,
name="body"),
name="body",
donated_invars=(False,) * (len(new_body_invars_body_constvars) +
len(new_body_invars_carry) +
1 + len(new_body_carry2) + 1)),
eqn.source_info),
core.new_jaxpr_eqn(
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
[new_body_pred2, new_body_token3],
xla.xla_call_p,
dict(call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_body"),
name="cond_body",
donated_invars=(False,) * (len(new_body_invars_cond_constvars) +
len(new_body_carry2) + 1 + 2)),
eqn.source_info)
]
new_body_jaxpr = _mk_typed_jaxpr(

View File

@ -243,12 +243,10 @@ class JVPTrace(Trace):
def process_primitive(self, primitive, tracers, params):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
try:
jvp = primitive_jvps[primitive]
except KeyError as err:
raise NotImplementedError(
"Forward-mode differentiation rule for '{}' not implemented"
.format(primitive)) from err
jvp = primitive_jvps.get(primitive)
if not jvp:
msg = f"Differentiation rule for '{primitive}' not implemented"
raise NotImplementedError(msg)
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
if primitive.multiple_results:
return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
@ -258,52 +256,35 @@ class JVPTrace(Trace):
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, in_tree_def = tree_flatten(tangents)
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
len(primals), in_tree_def)
name = params.get('name', f.__name__)
new_params = dict(params, name=wrap_name(name, 'jvp'))
if 'donated_invars' in new_params:
new_donated_invars = (*params['donated_invars'],
*[m for m, t in zip(params['donated_invars'], tangents)
if type(t) is not Zero])
new_params['donated_invars'] = tuple(new_donated_invars)
len(primals), tangent_tree_def)
nz_tangents = [type(t) is not Zero for t in tangents]
params = dict(params, name=wrap_name(params['name'], 'jvp'))
if isinstance(call_primitive, core.MapPrimitive):
mapped_invars = params['mapped_invars']
mapped_tangents = [m for m, nz in zip(mapped_invars, nz_tangents) if nz]
params = dict(params, mapped_invars=(*mapped_invars, *mapped_tangents))
update_params = call_param_updaters.get(call_primitive)
new_params = update_params(params, nz_tangents) if update_params else params
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
def post_process_call(self, call_primitive, out_tracers, params):
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
out = primals + tangents
out, treedef = tree_flatten((primals, tangents))
del primals, tangents
master = self.master
def todo(x):
n = len(x) // 2
primals, tangents = x[:n], x[n:]
primals, tangents = tree_unflatten(treedef, x)
trace = JVPTrace(master, core.cur_sublevel())
return map(partial(JVPTracer, trace), primals, tangents)
return out, todo
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
# only differs from process_call in that it must update mapped_invars
# TODO de-duplicate code
assert map_primitive.multiple_results
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, in_tree_def = tree_flatten(tangents)
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
len(primals), in_tree_def)
new_name = wrap_name(params.get('name', f.__name__), 'jvp')
new_mapped_invars = (*params['mapped_invars'],
*[m for m, t in zip(params['mapped_invars'], tangents)
if type(t) is not Zero])
new_donated_invars = (*params['donated_invars'],
*[m for m, t in zip(params['donated_invars'], tangents)
if type(t) is not Zero])
new_params = dict(params, name=new_name, mapped_invars=new_mapped_invars,
donated_invars=new_donated_invars)
result = map_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
# The only difference between process_map and process_call is that
# the `mapped_invars` param must be updated; that's handled in process_call.
process_map = process_call
post_process_map = post_process_call
def process_custom_jvp_call(self, _, __, f_jvp, tracers):
@ -363,11 +344,14 @@ def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = raise_to_shaped(get_aval(primal))
tangent_aval = raise_to_shaped(get_aval(tangent))
assert primal_aval == tangent_aval
assert primal_aval == tangent_aval, (primal_aval, tangent_aval)
call_param_updaters: Dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}
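These registries replace the inline 'donated_invars' bookkeeping that process_call and call_transpose used to do: a call primitive with per-argument params registers an updater keyed on the primitive, and the JVP/transpose machinery calls it with a boolean mask describing which extra arguments were appended (e.g. which tangents are nonzero). A hedged sketch of the contract, using a hypothetical primitive; the real registrations for xla_call and xla_pmap appear in xla.py and pxla.py below:

from jax import core
from jax.interpreters import ad

my_call_p = core.CallPrimitive('my_call')  # hypothetical primitive for illustration

def _my_call_jvp_update_params(params, nz_tangents):
  # Nonzero tangents are appended after the primal args, so per-argument params
  # such as donated_invars must be extended to cover them.
  donated_invars = params['donated_invars']
  donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
  return dict(params, donated_invars=(*donated_invars, *donated_tangents))

ad.call_param_updaters[my_call_p] = _my_call_jvp_update_params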
# -------------------- Primitives --------------------
primitive_jvps : Dict[core.Primitive, Callable] = {}
primitive_transposes: Dict[core.Primitive, Callable] = {}
@ -492,13 +476,12 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
params = dict(params, name=wrap_name(params['name'], 'transpose'))
if 'donated_invars' in params:
new_donated_invars = (*[d for d, x in zip(params['donated_invars'], args)
if not is_undefined_primal(x)],
*[False for x in ct if type(x) is not Zero])
params['donated_invars'] = tuple(new_donated_invars)
out_flat = primitive.bind(fun, *all_args, **params)
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct])
out_flat = primitive.bind(fun, *all_args, **new_params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
@ -507,14 +490,12 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_
# backward_pass can only transpose linear computations, but the call_jaxpr embedded in
# remat contains primal (non-linear) equations too. Hence, we have to eliminate those
# (in this case via partial_eval) before we call into backward_pass again.
typed_call_jaxpr = core.TypedJaxpr(
call_jaxpr, [],
[raise_to_shaped(p.aval if is_undefined_primal(p) else get_aval(p)) for p in primals_in],
cotangent_in_avals)
in_avals = [raise_to_shaped(p.aval if is_undefined_primal(p) else get_aval(p))
for p in primals_in]
typed_call_jaxpr = core.TypedJaxpr(call_jaxpr, [], in_avals, cotangent_in_avals)
unknowns = map(is_undefined_primal, primals_in)
primal_jaxpr, tangent_jaxpr, out_unknowns = \
pe.partial_eval_jaxpr(typed_call_jaxpr,
unknowns=map(is_undefined_primal, primals_in),
instantiate=True,
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
trace_type=None)
def do_transpose(primals_in, cotangents_in):
@ -541,12 +522,12 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
new_mapped_invars = (*[m for m, x in zip(params['mapped_invars'], args)
if not is_undefined_primal(x)],
*[True for x in ct if type(x) is not Zero])
new_donated_invars = (*[d for d, x in zip(params['donated_invars'], args)
if not is_undefined_primal(x)],
*[False for x in ct if type(x) is not Zero])
new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
mapped_invars=tuple(new_mapped_invars),
donated_invars=tuple(new_donated_invars))
mapped_invars=new_mapped_invars)
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct])
out_flat = primitive.bind(fun, *all_args, **new_params)
arg_cts = tree_unflatten(out_tree(), out_flat)

View File

@ -36,12 +36,9 @@ zip = safe_zip
# Reverse call primitive
################################################################################
invertible_call_p = core.Primitive('invertible_call')
invertible_call_p.call_primitive = True
invertible_call = partial(core.call_bind, invertible_call_p)
invertible_call_p.def_custom_bind(invertible_call)
invertible_call_p = core.CallPrimitive('invertible_call')
invertible_call = invertible_call_p.bind
invertible_call_p.def_impl(core.call_impl)
invertible_call_p.multiple_results = True
def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params):
uks = [not t.pval.is_known() for t in out_tracers]
@ -67,6 +64,8 @@ def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params)
pe.call_partial_eval_rules[invertible_call_p] = partial(
pe._remat_partial_eval, _invertible_call_make_output_tracers)
# TODO(mattjj): remove this when #3370 lands
core.skip_check_primitives.add(invertible_call_p)
@cache()
def _append_invars(jaxpr, avals):
@ -259,10 +258,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
in_avals = map(abstract, primals_in + primals_out + primals_out)
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
complete_ivjp_flat,
map(PartialVal.unknown, in_avals),
instantiate=True,
stage_out=False)
complete_ivjp_flat, map(PartialVal.unknown, in_avals),
instantiate=True, stage_out=False)
assert not ivjp_jaxpr.constvars # That might happen some time, but don't bother until then
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
ivjp_jaxpr = core.TypedJaxpr(ivjp_jaxpr, [], in_avals, out_avals)
@ -273,10 +270,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
unknowns = (map(ad.is_undefined_primal, primals_in) +
map(ad.is_undefined_primal, primals_out) +
[False] * len(cts_in))
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(ivjp_jaxpr,
unknowns,
instantiate=False,
trace_type=None)
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
ivjp_jaxpr, unknowns, instantiate=False, trace_type=None)
unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
# Make sure we're able to compute all cotangents. We don't really care if we
# can reconstruct the primals or not, although failure to do so might result in

View File

@ -27,7 +27,7 @@ from .. import linear_util as lu
from ..abstract_arrays import ConcreteArray, raise_to_shaped
from ..ad_util import Zero
from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
cache, curry)
cache)
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
AbstractValue, unit, unitvar, abstract_unit,
TypedJaxpr, new_jaxpr_eqn)
@ -73,7 +73,7 @@ class PartialVal(tuple):
return self[1] if self[0] is None else None
def get_aval(self) -> AbstractValue:
"""Get the AbstractValue either directly for unknown values, or from the known constant."""
"""Get AbstractValue directly (if unknown) or from the constant (known)."""
known = self.get_known()
if known is not None:
return get_aval(known)
@ -170,35 +170,29 @@ class JaxprTrace(Trace):
return out_tracer
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
name = params.get('name', f.__name__)
if (self.master.trace_type is StagingJaxprTrace
and primitive in staged_out_calls):
tracers = map(self.instantiate_const_abstracted, tracers)
params = dict(params, name=name)
if primitive in call_partial_eval_rules:
return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)
@curry
def modify_aval(modify, args):
pval, is_mapped = args
if pval.is_known() or not is_mapped:
return pval
return PartialVal((modify(params['axis_size'], pval[0]), pval[1]))
in_pvals = [t.pval for t in tracers]
if primitive.map_primitive:
in_pvals = map(modify_aval(core.mapped_aval), zip(in_pvals, params['mapped_invars']))
mapped_aval = partial(core.mapped_aval, params['axis_size'])
in_pvals = [pval if pval.is_known() or not is_mapped
else PartialVal.unknown(mapped_aval(pval[0]))
for pval, is_mapped in zip(in_pvals, params['mapped_invars'])]
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
f, in_pvals, partial(primitive.bind, **params))
if primitive.map_primitive:
out_pvals = map(modify_aval(core.unmapped_aval),
[(pval, True) for pval in out_pvals])
unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
out_pvals = [pval if pval.is_known()
else PartialVal.unknown(unmapped_aval(pval[0]))
for pval in out_pvals]
# Don't bother if the traced jaxpr is trivial. Simply evaluate it in here.
# XXX: We don't allow this fast path for map primitives, because this simplification might
# e.g. reduce the number of required devices if someone pmaps an identity function.
if not primitive.map_primitive and not jaxpr.eqns:
# Avoid staging out trivial calls, but maps may involve broadcasting.
if not jaxpr.eqns and not primitive.map_primitive:
env = {core.unitvar: core.unit}
map(env.setdefault, jaxpr.invars, (*env_tracers, *tracers))
map(env.setdefault, jaxpr.constvars, consts)
@ -212,94 +206,76 @@ class JaxprTrace(Trace):
out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
jaxpr = _drop_invars(jaxpr, in_knowns)
jaxpr = _dce_untyped_jaxpr(jaxpr, out_unknowns, drop_outputs=True)
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
# Known tracers get propagated as if they were constants
known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals if pval.is_known()]
known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
if pval.is_known()]
# Unknown tracers need to have the jaxpr set up as their recipe
unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals if not pval.is_known()]
unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
if not pval.is_known()]
unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
const_tracers = map(self.new_instantiated_const, consts)
new_params = dict(params, call_jaxpr=lifted_jaxpr)
if 'donated_invars' in params:
new_donated_invars = ((False,) * len(const_tracers) +
(False,) * len(env_tracers) +
tuple(v for v, t in zip(params['donated_invars'], tracers) if not t.pval.is_known()))
new_params['donated_invars'] = new_donated_invars
in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)
# Set up new params
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
if primitive.map_primitive:
mapped_invars = params['mapped_invars']
new_mapped_invars = ((True,) * len(const_tracers) +
(False,) * len(env_tracers) +
tuple(v for v, t in zip(params['mapped_invars'], tracers) if not t.pval.is_known()))
new_params['mapped_invars'] = new_mapped_invars
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, unknown_tracers_in)),
unknown_tracers_out, primitive, new_params,
source_info_util.current())
for t in unknown_tracers_out:
t.recipe = eqn
tuple(v for v, t in zip(mapped_invars, tracers)
if not t.pval.is_known()))
new_params = dict(new_params, mapped_invars=new_mapped_invars)
update_params = call_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, [not t.pval.is_known() for t in tracers])
eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
source_info_util.current())
for t in unknown_tracers_out: t.recipe = eqn
return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)
def post_process_call(self, call_primitive, out_tracers, params):
process_map = process_call
# We use post_process_call to handle both call and map primitives.
def post_process_call(self, primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
del consts, out_pv_consts
master = self.master
if primitive.map_primitive:
sz = params['axis_size']
out_pvs = [None if pv is None else core.unmapped_aval(sz, pv)
for pv in out_pvs]
def todo(x):
n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(master, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
env_tracers = map(trace.full_raise, env)
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
invars = tuple(it.chain(const_tracers, env_tracers))
new_params = dict(params, call_jaxpr=lifted_jaxpr)
if 'donated_invars' in params:
new_params['donated_invars'] = (False,) * len(invars)
# The `jaxpr` already contains the env_vars at start of invars
eqn = new_eqn_recipe(invars, out_tracers, call_primitive, new_params,
in_tracers = (*const_tracers, *map(trace.full_raise, env))
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
if primitive.map_primitive:
new_mapped_invars = (True,) * len(const_tracers) + (False,) * len(env)
new_params = dict(new_params, mapped_invars=new_mapped_invars)
update_params = call_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, [])
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
source_info_util.current())
for t in out_tracers:
t.recipe = eqn
return out_tracers
return out, todo
process_map = process_call
def post_process_map(self, map_primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs_reduced, out_pv_consts = unzip2(t.pval for t in out_tracers)
out_pvs = [None if pv is None
else core.unmapped_aval(params['axis_size'], pv)
for pv in out_pvs_reduced]
out = out_pv_consts + consts
del consts, out_pv_consts
master = self.master
def todo(x):
n = len(jaxpr.outvars)
out_pv_consts, consts = x[:n], x[n:]
trace = JaxprTrace(master, core.cur_sublevel())
const_tracers = map(trace.new_instantiated_const, consts)
# The `jaxpr` already contains the env_vars at start of invars
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
new_donated_invars = (False,) * (len(const_tracers) + len(env))
new_mapped_invars = (True,) * len(const_tracers) + (False,) * len(env)
new_params = dict(params, donated_invars=tuple(new_donated_invars),
mapped_invars=tuple(new_mapped_invars),
call_jaxpr=lifted_jaxpr)
env_tracers = map(trace.full_raise, env)
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
out_tracers, map_primitive, new_params,
source_info_util.current())
for t in out_tracers:
t.recipe = eqn
return out_tracers
return out, todo
post_process_map = post_process_call
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# See comment at top of `JaxprTrace`. This method should be reachable
@ -333,23 +309,24 @@ class StagingJaxprTrace(JaxprTrace):
@lu.transformation_with_aux
def partial_eval_wrapper(avals: Sequence[Optional[AbstractValue]], *consts):
py_args = (map(PartialVal, zip(avals, consts)),)
jaxpr, (out_pvals, consts, env) = yield py_args, {}
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
py_args = map(PartialVal, zip(pvs, consts))
jaxpr, (out_pvals, consts, env) = yield (py_args,), {}
out_pvs, out_consts = unzip2(out_pvals)
out = tuple(out_consts) + tuple(consts) # TODO: can consts be traced?
out = tuple(out_consts) + tuple(consts)
yield out, (out_pvs, jaxpr, env)
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
staged_out_calls: Set[core.Primitive] = set()
call_param_updaters: Dict[core.Primitive, Callable] = {}
def abstract_eval_fun(fun, *avals, **params):
pvals_in = [PartialVal.unknown(a) for a in avals]
_, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
instantiate=True, stage_out=True)
instantiate=True, stage_out=True)
avals_out, _ = unzip2(pvals_out)
for aval_out in avals_out:
assert isinstance(aval_out, AbstractValue) # instantiate=True
@ -406,15 +383,12 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
-> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
"""Traces a function into a Jaxpr, given PartialVals for inputs.
`trace_type` can be one of `StagingJaxprTrace` or `JaxprTrace` (see
comments for that class).
Returns (`jaxpr`, `out_pvals`, `consts`).
The `jaxpr` contains only the computation that depends on unknown inputs.
The `out_pvals` are the PartialVal for the outputs. The intermediate
values that depend only on known inputs and are needed to compute the output
of `jaxpr` are in `consts` and are passed in as the constvars of
the `jaxpr`. The handling of the known outputs depends on `instantiate`.
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
computation that depends on unknown inputs. The `out_pvals` are the PartialVal
for the outputs. The intermediate values that depend only on known inputs and
are needed to compute the output of `jaxpr` are in `consts` and are passed in
as the constvars of the `jaxpr`. The handling of the known outputs depends on
`instantiate`.
For example, given `fun` defined as follows::
@ -425,11 +399,11 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
computation that depends on unknown inputs is `ui + ka` and will be the only
computation in the body of the `jaxpr`. This computation depends on the
known intermediate value `ka`, which will be computed statically. Currently,
such constants are either embedded in the Jaxpr if they are scalars, or
passed as a constvar to `jaxpr`, and then the value of the actual constant
will be in `consts`:
computation in the body of the `jaxpr`. This computation depends on the known
intermediate value `ka`, which will be computed statically. Currently, such
constants are either embedded in the Jaxpr if they are scalars, or passed as a
constvar to `jaxpr`, and then the value of the actual constant will be in
`consts`:
When `instantiate=False` we get::
@ -437,7 +411,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
{ lambda ka ; ki ui.
let c = add ui ka
in (*, c) } # known outputs are `*`
out_pvals = [known(6), unknown(ShapedArray)] # the known outputs are known PartialVal
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
consts = [3] # the constant for `ka`
When `instantiate=True` we get::
@ -446,7 +420,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
{ lambda ka kb ; ki ui.
let c = add ui ka
in (kb, c) } # known output are explicit
out_pvals = [abstract(ConcreteArray(6)), abstract(ShapedArray)] # all are unknown PartialVal
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
consts = [3, 6] # values for `ka` and `kb` constvars
"""
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
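A hedged usage sketch of the API documented above, with a hypothetical two-input function in the spirit of the docstring's example (the exact fun from the docstring is not shown here, so the values are illustrative rather than the documented ones):

import numpy as onp
import jax.numpy as jnp  # installs arithmetic rules used while tracing
from jax import linear_util as lu
from jax.abstract_arrays import ShapedArray
from jax.interpreters import partial_eval as pe

def fun(ki, ui):                        # hypothetical: ki known, ui unknown
  ka = jnp.add(ki, 2.)                  # depends only on the known input
  return jnp.add(ki, 3.), jnp.add(ui, ka)

pvals = [pe.PartialVal.known(onp.float32(1.)),
         pe.PartialVal.unknown(ShapedArray((), onp.float32))]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(lu.wrap_init(fun), pvals,
                                             instantiate=False)
# jaxpr contains only the computation that depends on ui; the known output and
# the value of ka come back through out_pvals and consts, as described above.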
@ -509,19 +483,20 @@ def new_eqn_recipe(invars: Sequence[JaxprTracer],
if primitive.call_primitive or primitive.map_primitive:
assert "call_jaxpr" in params
if primitive.map_primitive:
assert "mapped_invars" in params
assert "donated_invars" in params
assert ("mapped_invars" in params and
len(params["mapped_invars"]) == len(params["call_jaxpr"].invars))
assert ("donated_invars" in params and
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
params, source_info)
def recipe_to_eqn(unused_var: Callable[[], core.Var],
getvar: Callable[[JaxprTracer], core.Atom],
def recipe_to_eqn(getvar: Callable[[JaxprTracer], core.Atom],
recipe: JaxprEqnRecipe) -> core.JaxprEqn:
_, in_tracers, out_tracer_refs, primitive, params, source_info = recipe
out_tracers = [t_ref() for t_ref in out_tracer_refs]
invars = [getvar(t) for t in in_tracers]
outvars = [unused_var() if t is None else cast(core.Var, getvar(t))
outvars = [core.dropvar if t is None else cast(core.Var, getvar(t))
for t in out_tracers]
return new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
@ -564,7 +539,7 @@ def tracers_to_jaxpr(
recipe = t.recipe
if isinstance(recipe, JaxprEqnRecipe):
if recipe.eqn_id not in processed_eqn_ids:
eqns.append(recipe_to_eqn(lambda: core.dropvar, getvar, recipe))
eqns.append(recipe_to_eqn(getvar, recipe))
processed_eqn_ids.add(recipe.eqn_id)
elif isinstance(recipe, LambdaBinding):
if not any(t is in_tracer for in_tracer in in_tracers):
@ -586,7 +561,7 @@ def tracers_to_jaxpr(
env_vars, env_vals = unzip2(env.items())
const_vars, const_vals = unzip2(consts.items())
# The env_vars are pre-pended to the invars
jaxpr = Jaxpr(const_vars, list(it.chain(env_vars, invars)), list(map(getvar, out_tracers)), eqns)
jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
core.skip_checks or core.check_jaxpr(jaxpr)
return jaxpr, const_vals, env_vals
@ -682,12 +657,9 @@ def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, Abst
return (abstract_unit, aval) if unknown else (aval, abstract_unit)
remat_call_p = core.Primitive('remat_call')
remat_call_p.call_primitive = True
remat_call = partial(core.call_bind, remat_call_p)
remat_call_p.def_custom_bind(remat_call)
remat_call_p = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True
# We reuse the _remat_partial_eval function both for remat_call and for
# invertible_call, both of which in a sense stage out operations to

View File

@ -11,24 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of pmap and related functionality.
"""Implementation of pmap and related functionality."""
Note on ShardingSpecs and spec_to_indices():
A ShardingSpec describes at a high level how a logical array is sharded across
devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
describe how to shard inputs to a parallel computation). spec_to_indices()
encodes exactly how a given ShardingSpec is translated to device buffers,
i.e. how the sharded array is "laid out" across devices. Given a sequence of
devices, we shard the data across the devices in row-major order, with
replication treated as an extra inner dimension.
For example, given the logical data array [1, 2, 3, 4], if we were to partition
this array 4 ways with a replication factor of 2, for a total of 8 devices, the
data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
This encoding is assumed by various parts of the system, e.g. generating
replica groups for collective operations.
"""
# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.
from collections import defaultdict
from contextlib import contextmanager
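The row-major layout with replication as an extra inner dimension can be spelled out in a small plain-Python sketch (an illustration of the description above, not the real ShardingSpec/spec_to_indices machinery):

# 4-way partition of [1, 2, 3, 4] with replication factor 2 -> 8 devices.
data = [1, 2, 3, 4]
num_shards, replication = 4, 2
# Row-major over (shard, replica): each shard goes to `replication` consecutive devices.
per_device = [data[d // replication] for d in range(num_shards * replication)]
assert per_device == [1, 1, 2, 2, 3, 3, 4, 4]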
@ -64,7 +62,7 @@ xops = xc.ops
FLAGS = flags.FLAGS
_map = safe_map
unsafe_map, map = map, safe_map
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
@ -120,7 +118,7 @@ class ShardingSpec:
def __repr__(self):
return ("ShardingSpec(shards_per_axis=%s, is_axis_materialized=%s, "
"replication_factor=%s)" %
"replication_factors=%s)" %
(self.shards_per_axis, self.is_axis_materialized,
self.replication_factors))
@ -186,7 +184,7 @@ def spec_to_indices(shape: Tuple[int, ...],
def _axis_indices(axis_size, num_shards, is_materialized):
if not is_materialized:
assert axis_size == num_shards
assert axis_size == num_shards, f'{axis_size} != {num_shards}'
return list(range(axis_size))
if num_shards == 1:
return [slice(None)]
@ -785,7 +783,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
tuple_args = len(sharded_avals) > 100 # pass long arg lists as tuple for TPU
c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
xla_consts = _map(partial(xb.constant, c), consts)
xla_consts = map(partial(xb.constant, c), consts)
replicated = [not m for m in mapped_invars]
xla_args = xla._xla_callable_args(c, sharded_avals, tuple_args, replicated,
arg_parts)
@ -1115,22 +1113,23 @@ def execute_replicated(compiled,
return out_handler(out_bufs)
xla_pmap_p = core.Primitive('xla_pmap')
xla_pmap_p.map_primitive = True
xla_pmap_p.multiple_results = True
xla_pmap = partial(core.map_bind, xla_pmap_p)
xla_pmap_p.def_custom_bind(xla_pmap)
xla_pmap_p = core.MapPrimitive('xla_pmap')
xla_pmap = xla_pmap_p.bind
xla_pmap_p.def_impl(xla_pmap_impl)
pe.staged_out_calls.add(xla_pmap_p)
# Set param update handlers to update `donated_invars` just like xla_call_p
pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
ad.call_transpose_param_updaters[xla_pmap_p] = \
ad.call_transpose_param_updaters[xla.xla_call_p]
def _pmap_translation_rule(c, axis_env,
in_nodes, name_stack, axis_name, axis_size,
global_axis_size, devices, name,
call_jaxpr, *, backend=None, mapped_invars,
donated_invars):
if any(donated_invars):
raise ValueError("Donating buffers passed to a a pmap nested inside a jit "
"or another pmap is not supported.")
del donated_invars # Unused.
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
if axis_env.names and devices is not None:

View File

@ -187,11 +187,8 @@ def _sharded_call_impl(fun, *args, num_partitions, in_parts, out_parts_thunk,
return compiled_fun(*args)
sharded_call_p = core.Primitive("sharded_call")
sharded_call_p.call_primitive = True
sharded_call_p.multiple_results = True
sharded_call = partial(core.call_bind, sharded_call_p)
sharded_call_p.def_custom_bind(sharded_call)
sharded_call_p = core.CallPrimitive("sharded_call")
sharded_call = sharded_call_p.bind
sharded_call_p.def_impl(_sharded_call_impl)
xla.call_translations[sharded_call_p] = _sharded_jit_translation_rule

View File

@ -35,13 +35,16 @@ from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
from ..core import Literal, pp_eqn_compact
from ..pprint_util import pp
from ..util import (partial, partialmethod, cache, prod, unzip2, memoize,
extend_name_stack, wrap_name, safe_zip)
extend_name_stack, wrap_name, safe_zip, safe_map)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from . import partial_eval as pe
from . import ad
from . import masking
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
xe = xc._xla
xops = xc._xla.ops
@ -63,7 +66,6 @@ flags.DEFINE_bool('jax_log_compiles',
bool_env('JAX_LOG_COMPILES', False),
'Print a message each time a `jit` computation is compiled.')
def _map(f, *xs): return tuple(map(f, *xs))
def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
@ -199,7 +201,7 @@ def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
return True
for param in params.values():
if isinstance(param, tuple):
if any(_map(_param_uses_outfeed, param)):
if any(unsafe_map(_param_uses_outfeed, param)):
return True
elif _param_uses_outfeed(param):
return True
@ -223,7 +225,7 @@ def arg_spec(x):
def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
return compiled_fun(*args)
@cache()
@ -243,8 +245,8 @@ def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
if not prim.multiple_results:
handle_result = aval_to_result_handler(device, aval_out)
else:
handlers = tuple(map(partial(aval_to_result_handler, device), aval_out))
handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs))
handlers = map(partial(aval_to_result_handler, device), aval_out)
handle_result = lambda xs: tuple(h(x) for h, x in unsafe_zip(handlers, xs))
tuple_args = len(avals) > 100
if prim in initial_style_translations:
nreps = initial_style_primitive_replicas(params)
@ -386,8 +388,8 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
env = {}
write(core.unitvar, _make_unit(c))
_map(write, jaxpr.constvars, consts)
_map(write, jaxpr.invars, args)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
frame = source_info_util.user_frame(eqn.source_info)
c.set_op_metadata(xc.OpMetadata(
@ -396,7 +398,7 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
eqn.primitive.name, eqn.params)),
source_file=frame.file_name if frame else None,
source_line=frame.line_num if frame else None))
in_nodes = list(map(read, eqn.invars))
in_nodes = map(read, eqn.invars)
if eqn.primitive in backend_specific_translations[platform]:
rule = backend_specific_translations[platform][eqn.primitive]
ans = rule(c, *in_nodes, **eqn.params)
@ -432,8 +434,8 @@ def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
c.get_shape(ans) # force xla to do shape error checking
out_nodes = xla_destructure(c, ans) if eqn.primitive.multiple_results else [ans]
c.clear_op_metadata()
_map(write, eqn.outvars, out_nodes)
return _map(read, jaxpr.outvars)
map(write, eqn.outvars, out_nodes)
return map(read, jaxpr.outvars)
def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
@ -470,7 +472,7 @@ def axis_read(axis_env, axis_name):
def axis_groups(axis_env, name):
if isinstance(name, (list, tuple)):
mesh_axes = tuple(map(partial(axis_read, axis_env), name))
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
else:
mesh_axes = (axis_read(axis_env, name),)
return _axis_groups(axis_env.nreps, axis_env.sizes, mesh_axes)
@ -483,7 +485,7 @@ def _axis_groups(nrep, mesh_spec, mesh_axes):
groups = onp.reshape(
onp.moveaxis(iota, mesh_axes, onp.arange(len(mesh_axes))),
(prod(onp.take(full_spec, mesh_axes)), -1))
return tuple(map(tuple, groups.T))
return tuple(unsafe_map(tuple, groups.T))
def jaxpr_replicas(jaxpr):
"""The number of replicas needed for a jaxpr.
@ -536,7 +538,8 @@ def jaxpr_collectives(jaxpr):
### xla_call underlying jit
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, *map(arg_spec, args))
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
try:
return compiled_fun(*args)
except FloatingPointError:
@ -597,8 +600,8 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
jaxpr, pvals, consts = pe.trace_to_jaxpr(
fun, pvals, instantiate=False, stage_out=True, bottom=True)
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
jaxpr, uses_outfeed = apply_outfeed_rewriter(jaxpr)
_map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
nreps = jaxpr_replicas(jaxpr)
device = _xla_callable_device(nreps, backend, device, arg_devices)
@ -635,7 +638,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
c = xb.make_computation_builder("jit_{}".format(fun.__name__))
xla_consts = _map(partial(xb.constant, c), consts)
xla_consts = map(partial(xb.constant, c), consts)
xla_args = _xla_callable_args(c, abstract_args, tuple_args)
out_nodes = jaxpr_subcomp(
c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts,
@ -779,8 +782,8 @@ def _execute_replicated(compiled: XlaExecutable, uses_outfeed: bool,
def _execute_trivial(jaxpr, device: Optional[Device], consts, handlers, *args):
env = {core.unitvar: core.unit}
_map(env.setdefault, jaxpr.invars, args)
_map(env.setdefault, jaxpr.constvars, consts)
map(env.setdefault, jaxpr.invars, args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
for v in jaxpr.outvars]
return [_copy_device_array_to_device(x, device) if type(x) is DeviceArray
@ -800,24 +803,43 @@ def _get_device(device, backend):
out, = compiled.local_devices()
return out
xla_call_p = core.Primitive('xla_call')
xla_call_p.call_primitive = True
xla_call_p.multiple_results = True
xla_call = partial(core.call_bind, xla_call_p)
xla_call_p.def_custom_bind(xla_call)
xla_call_p = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
xla_call_p.def_impl(_xla_call_impl)
pe.staged_out_calls.add(xla_call_p)
def _xla_call_partial_eval_update_params(params, in_unknowns):
call_jaxpr = params['call_jaxpr']
donated_invars = params['donated_invars']
if not in_unknowns and donated_invars:
# JaxprTrace.post_process_call creates a call with no input tracers
new_donated_invars = (False,) * len(call_jaxpr.invars)
else:
# JaxprTrace.process_call drops known input tracers
donated_invars = [d for d, uk in zip(donated_invars, in_unknowns) if uk]
new_donated_invars = ((False,) * (len(call_jaxpr.invars) - len(donated_invars))
+ tuple(donated_invars))
return dict(params, donated_invars=new_donated_invars)
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
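A worked example of the updater above, restated standalone with illustrative values (DummyJaxpr is a stand-in, since only call_jaxpr.invars is consulted): with two hoisted constants and one unknown original argument whose buffer was donated, the donation flag follows that argument to the end of the staged call's invars.

# Standalone restatement of the 'else' branch above with illustrative values.
class DummyJaxpr:
  invars = ['c1', 'c2', 'u']            # 2 hoisted consts + 1 unknown arg

params = dict(call_jaxpr=DummyJaxpr(), donated_invars=(False, True))
in_unknowns = [False, True]             # only the second original arg is unknown

donated_invars = [d for d, uk in zip(params['donated_invars'], in_unknowns) if uk]
new_donated_invars = ((False,) * (len(params['call_jaxpr'].invars) - len(donated_invars))
                      + tuple(donated_invars))
assert new_donated_invars == (False, False, True)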
def _xla_call_jvp_update_params(params, nz_tangents):
donated_invars = params['donated_invars']
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
donated_cotangents = [False for nz in nonzero_cts if nz]
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
def _xla_call_translation_rule(c, axis_env,
in_nodes, name_stack, backend, name,
call_jaxpr, device=None, donated_invars=None):
del device # Ignored.
if donated_invars is None:
donated_invars = (False,) * len(in_nodes)
elif any(donated_invars):
raise ValueError("Donating buffers passed to a jit nested inside a jit or "
"pmap is not supported.")
call_jaxpr, donated_invars, device=None):
del device, donated_invars # Ignored.
subc = xb.make_computation_builder(f"jit_{name}")
args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
@ -874,7 +896,7 @@ def lower_fun(fun, multiple_results):
wrapped_fun = _tuple_output(wrapped_fun)
jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
stage_out=True)
consts = _map(partial(xb.constant, c), consts)
consts = map(partial(xb.constant, c), consts)
outs = jaxpr_subcomp(c, jaxpr, None, AxisEnv(1), consts, '', *xla_args)
if multiple_results:
return xops.Tuple(c, outs)
@ -895,7 +917,7 @@ def lower_fun_initial_style(fun):
pvals = [pe.PartialVal.unknown(a) for a in avals]
jaxpr, _, consts = pe.trace_to_jaxpr(
lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True)
consts = _map(partial(xb.constant, c), consts)
consts = map(partial(xb.constant, c), consts)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack,
*xla_args)
return xops.Tuple(c, outs)

View File

@ -259,7 +259,7 @@ class APITest(jtu.JaxTestCase):
"Abstract evaluation for 'foo' not implemented")
jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
"Forward-mode differentiation rule for 'foo' not implemented")
"Differentiation rule for 'foo' not implemented")
foo_p.def_abstract_eval(lambda x: x)
@ -1759,7 +1759,9 @@ class RematTest(jtu.JaxTestCase):
@jax.util.curry
def call(f, *args):
return jax.core.call(jax.linear_util.wrap_init(lambda *args: [f(*args)]), *args)[0]
return jax.core.call(
jax.linear_util.wrap_init(lambda *args: [f(*args)]),
*args, name='foo')[0]
f = call(add_one)
g = jax.remat(lambda x: add_one(f(x)))
@ -3198,8 +3200,12 @@ class BufferDonationTest(jtu.JaxTestCase):
def test_jit_nested_donate_ignored(self):
jit_fun = jit(lambda x: jit(lambda y: y ** 2, donate_argnums=0)(x))
a = jax.device_put(jnp.array(1))
with self.assertRaisesRegex(ValueError, "nested.*not supported"):
jit_fun(a)
# NOTE(mattjj): stopped raising error here and instead just ignored
# with self.assertRaisesRegex(ValueError, "nested.*not supported"):
# jit_fun(a)
jit_fun(a) # doesn't crash
def test_jnp_array_copy(self):
# https://github.com/google/jax/issues/3412
@ -3232,8 +3238,12 @@ class BufferDonationTest(jtu.JaxTestCase):
def test_pmap_nested_donate_raises(self):
pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x))
a = api.pmap(lambda x: x)(jnp.array([1]))
with self.assertRaisesRegex(ValueError, "nested.*not supported"):
pmap_fun(a)
# NOTE(mattjj): stopped raising error here and instead just ignored
# with self.assertRaisesRegex(ValueError, "nested.*not supported"):
# pmap_fun(a)
pmap_fun(a) # doesn't crash
assertDeleted = lambda self, x: self._assertDeleted(x, True)
assertNotDeleted = lambda self, x: self._assertDeleted(x, False)