Split loop cond/body constants into tracers and literal constants.

Only hoist tracer constants into loop-carried variables; leave non-tracers as constants.

This change reduces the number of trivial loop-carried constants, which can be very large in some models.
Peter Hawkins 2019-03-13 10:41:42 -04:00
parent cd1e4601c7
commit 098aedd66c
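
As a rough illustration of the intent (a minimal sketch with hypothetical values, using the `onp`-for-NumPy spelling that appears elsewhere in this file): a plain NumPy constant closed over by the loop functions can now be baked into the compiled loop as a literal, instead of being threaded through the loop carry on every iteration.

    import numpy as onp
    from jax import lax

    scale = onp.float32(3.0)     # a literal (nontracer) constant

    def cond_fun(x):
      return x < 100.

    def body_fun(x):
      return x * scale           # `scale` is closed over, not loop-carried

    # With this change, `scale` becomes a compile-time constant of the XLA
    # While computation rather than an extra loop-carried value.
    print(lax.while_loop(cond_fun, body_fun, onp.float32(1.0)))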


@@ -836,6 +836,20 @@ def sort_key_val(keys, values, dimension=-1):
   sorted_keys, sorted_values = result
   return sorted_keys, sorted_values
 
+class _OpaqueParam(object):
+  """Wrapper that hashes on its identity, instead of its contents.
+
+  Used to pass unhashable parameters as primitive attributes."""
+  __slots__ = ["val", "id"]
+
+  def __init__(self, val):
+    self.val = val
+    self.id = next(_opaque_param_ids)
+
+  def __hash__(self):
+    return self.id
+
+_opaque_param_ids = itertools.count()
 
 def while_loop(cond_fun, body_fun, init_val):
   """Call `body_fun` repeatedly in a loop while `cond_fun` is True.
@@ -875,12 +889,38 @@ def while_loop(cond_fun, body_fun, init_val):
   body_jaxpr, pval_out, body_consts = pe.trace_to_jaxpr(flat_body_fun, (pval_flat,))
   aval_out, _ = pval_out
 
+  # We don't want to promote literal constants as loop arguments; there are
+  # sometimes many of them. We pass tracers as loop arguments, but leave
+  # nontracers as constants. We also sort the constants so the nontracers are
+  # first.
+  def split_tracers_and_nontracers(jaxpr, consts):
+    tracer = []
+    nontracer = []
+    for x in zip(jaxpr.constvars, consts):
+      # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
+      # we don't copy large arrays back to the host. We probably should relax
+      # this and either always copy small constants, or opportunistically use
+      # DeviceArray values for which we already know npy_value.
+      not_literal_const = (isinstance(x[1], core.Tracer) or
+                           isinstance(x[1], xla.DeviceArray))
+      (tracer if not_literal_const else nontracer).append(x)
+    tracer_vars, tracer_consts = unzip2(tracer)
+    nontracer_vars, nontracer_consts = unzip2(nontracer)
+    return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts
+
+  cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
+  cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
+  body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
+  body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split
+
   if out_tree() != in_tree:
     raise TypeError("body_fun input and output must have identical structure")
-  out_flat = while_p.bind(init_val_flat, core.pack(cond_consts),
-                          core.pack(body_consts), aval_out=aval_out,
-                          cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
+  out_flat = while_p.bind(
+      init_val_flat, core.pack(cond_tracer_consts), core.pack(body_tracer_consts),
+      cond_consts=_OpaqueParam(cond_nontracer_consts),
+      body_consts=_OpaqueParam(body_nontracer_consts),
+      aval_out=aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
   return build_tree(out_tree(), out_flat)
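
Note the ordering contract `split_tracers_and_nontracers` establishes: `constvars` is rewritten with the nontracer variables first, which is what lets the translation rule below peel the literal constants off the front with a `[:num_consts]` slice. A self-contained model of the partition (the `is_tracer_like` predicate and sample values are stand-ins):

    def split(constvars, consts, is_tracer_like):
      tracer, nontracer = [], []
      for v, c in zip(constvars, consts):
        (tracer if is_tracer_like(c) else nontracer).append((v, c))
      t_vars, t_consts = [v for v, _ in tracer], [c for _, c in tracer]
      n_vars, n_consts = [v for v, _ in nontracer], [c for _, c in nontracer]
      # nontracers sort to the front, mirroring the helper above
      return n_vars + t_vars, n_consts, t_consts

    constvars, lits, traced = split(["a", "b", "c"], [1.0, "T", 2.0],
                                    lambda c: c == "T")
    assert constvars == ["a", "c", "b"]   # literal constvars come first
    assert lits == [1.0, 2.0] and traced == ["T"]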
@@ -3687,13 +3727,15 @@ ad.primitive_transposes[sort_key_val_p] = _sort_key_val_transpose_rule
 batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule
 
-def _while_loop_abstract_eval(init_val, cond_consts, body_consts, aval_out,
+def _while_loop_abstract_eval(init_val, cond_tracer_consts, body_tracer_consts,
+                              cond_consts, body_consts, aval_out,
                               cond_jaxpr, body_jaxpr):
   return maybe_tracer_tuple_to_abstract_tuple(aval_out)
 
-def _while_loop_translation_rule(c, init_val, cond_consts, body_consts,
+def _while_loop_translation_rule(c, init_val, cond_tracer_consts,
+                                 body_tracer_consts, cond_consts, body_consts,
                                  aval_out, cond_jaxpr, body_jaxpr):
-  loop_carry = c.Tuple(init_val, cond_consts, body_consts)
+  loop_carry = c.Tuple(init_val, cond_tracer_consts, body_tracer_consts)
   shape = c.GetShape(loop_carry)
   loop_carry_var = pe.Var(0, "loop_carry")
@@ -3701,33 +3743,37 @@ def _while_loop_translation_rule(c, init_val, cond_consts, body_consts,
   cond_var = pe.Var(0, "cond_consts")
   body_var = pe.Var(0, "body_consts")
 
+  num_cond_consts = len(cond_consts.val)
   assert len(cond_jaxpr.invars) == 1
   cond_jaxpr_converted = cond_jaxpr.copy()
-  cond_jaxpr_converted.constvars = []
+  cond_jaxpr_converted.constvars = cond_jaxpr.constvars[:num_cond_consts]
   cond_jaxpr_converted.invars = [loop_carry_var]
   cond_jaxpr_converted.eqns = (
     [_unpack_eqn(loop_carry_var, [cond_jaxpr.invars[0], cond_var, body_var]),
-     _unpack_eqn(cond_var, cond_jaxpr.constvars)]
+     _unpack_eqn(cond_var, cond_jaxpr.constvars[num_cond_consts:])]
     + list(cond_jaxpr.eqns))
 
+  num_body_consts = len(body_consts.val)
   assert len(body_jaxpr.invars) == 1
   body_jaxpr_converted = body_jaxpr.copy()
-  body_jaxpr_converted.constvars = []
+  body_jaxpr_converted.constvars = body_jaxpr.constvars[:num_body_consts]
   body_jaxpr_converted.invars = [loop_carry_var]
   body_jaxpr_converted.outvar = outvar
   body_jaxpr_converted.eqns = (
     [_unpack_eqn(loop_carry_var, [body_jaxpr.invars[0], cond_var, body_var]),
-     _unpack_eqn(body_var, body_jaxpr.constvars)]
+     _unpack_eqn(body_var, body_jaxpr.constvars[num_body_consts:])]
     + list(body_jaxpr.eqns) +
     [_pack_eqn([body_jaxpr.outvar, cond_var, body_var], outvar)])
 
-  cond_computation = xla.jaxpr_computation(cond_jaxpr_converted, (), (), shape)
-  body_computation = xla.jaxpr_computation(body_jaxpr_converted, (), (), shape)
+  cond_computation = xla.jaxpr_computation(
+      cond_jaxpr_converted, cond_consts.val, (), shape)
+  body_computation = xla.jaxpr_computation(
+      body_jaxpr_converted, body_consts.val, (), shape)
   full_ans = c.While(cond_computation, body_computation, loop_carry)
   return c.GetTupleElement(full_ans, 0)
 
-def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
-                              body_jaxpr):
+def _while_loop_batching_rule(batched_args, batch_dims, cond_consts,
+                              body_consts, aval_out, cond_jaxpr, body_jaxpr):
   # See https://github.com/google/jax/issues/441 for a discussion.
   # To batch a while_loop, we need to do some masking, since the elements of the
   # batch may run for different numbers of iterations. We perform that masking
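
Conceptually, the rewritten translation rule produces a While whose carry holds only the value and the tracer constants; the literal constants are handed to `xla.jaxpr_computation` and compiled directly into the cond and body computations. A rough Python model of the emitted structure (all names hypothetical):

    def lowered_while(init_val, cond_tracer_consts, body_tracer_consts,
                      cond_fn, body_fn):
      # mirrors c.Tuple(init_val, cond_tracer_consts, body_tracer_consts)
      carry = (init_val, cond_tracer_consts, body_tracer_consts)
      while cond_fn(carry):      # literal consts are baked into cond_fn/body_fn
        carry = body_fn(carry)   # body repacks (new_val, cond_consts, body_consts)
      return carry[0]            # c.GetTupleElement(full_ans, 0)

    # Example: count up by the loop-carried step until reaching the bound.
    print(lowered_while(0, (10,), (2,),
                        lambda c: c[0] < c[1][0],
                        lambda c: (c[0] + c[2][0], c[1], c[2])))  # -> 10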
@@ -3736,15 +3782,15 @@ def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
   # The basic strategy here is to lift `cond_jaxpr` and `body_jaxpr` back into
   # traceable Python functions using `core.eval_jaxpr`. Then we can batch them
   # using `batching.batch_transform` (the transform underlying `api.vmap`). This
-  # code also avoids broadcasting `cond_consts` and `body_consts`.
-  init_val, cond_consts, body_consts = batched_args
-  init_val_bd, cond_consts_bd, body_consts_bd = batch_dims
+  # code also avoids broadcasting `cond_tracer_consts` and `body_tracer_consts`.
+  init_val, cond_tracer_consts, body_tracer_consts = batched_args
+  init_val_bd, cond_tracer_consts_bd, body_tracer_consts_bd = batch_dims
 
   sizes = _reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
   size = sizes.pop()
   assert not sizes
 
-  # TODO(mattjj): if cond_consts_bd is also None, we could keep cond_fun
+  # TODO(mattjj): if cond_tracer_consts_bd is also None, we could keep cond_fun
   # unbatched and avoid the masking logic, but we ignore that optimization
   init_val = batching.bdim_at_front(init_val, init_val_bd, size,
                                     force_broadcast=True)
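
The masking strategy described in the comments can be modeled in plain NumPy: keep iterating while any batch element's predicate holds, and freeze the elements whose own predicate has gone false; this is the role played by the reduce-or over `preds` and by `_jaxtupletree_select` in the next hunk. A toy model, not the actual lowering:

    import numpy as onp

    def batched_while(cond, body, xs):
      while onp.any(cond(xs)):                   # or-reduction over the batch
        xs = onp.where(cond(xs), body(xs), xs)   # select: only update live lanes
      return xs

    print(batched_while(lambda x: x < 5, lambda x: x + 1, onp.array([0, 3, 7])))
    # -> [5 5 7]: each element stops advancing once its own condition fails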
@@ -3752,21 +3798,28 @@ def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
   def batched_cond_fun(batched_loop_carry):
     @lu.wrap_init
-    def lifted(loop_carry, cond_consts):
-      return core.eval_jaxpr(cond_jaxpr, cond_consts, (), loop_carry)
-    f = batching.batch_transform(lifted, size, (init_val_bd, cond_consts_bd), 0)
-    preds = f.call_wrapped((batched_loop_carry, cond_consts))
+    def lifted(loop_carry, cond_tracer_consts):
+      cond_tracer_consts = tuple(x for x in cond_tracer_consts)
+      return core.eval_jaxpr(
+        cond_jaxpr, cond_consts.val + cond_tracer_consts, (), loop_carry)
+    f = batching.batch_transform(lifted, size, (init_val_bd, cond_tracer_consts_bd), 0)
+    preds = f.call_wrapped((batched_loop_carry, cond_tracer_consts))
     return reduce(preds, onp.array(False), bitwise_or, [0])
 
   def batched_body_fun(batched_loop_carry):
     @lu.wrap_init
-    def lifted(loop_carry, cond_consts, body_consts):
-      pred = core.eval_jaxpr(cond_jaxpr, cond_consts, (), loop_carry)
-      new_loop_carry = core.eval_jaxpr(body_jaxpr, body_consts, (), loop_carry)
+    def lifted(loop_carry, cond_tracer_consts, body_tracer_consts):
+      cond_tracer_consts = tuple(x for x in cond_tracer_consts)
+      body_tracer_consts = tuple(x for x in body_tracer_consts)
+      pred = core.eval_jaxpr(
+        cond_jaxpr, cond_consts.val + cond_tracer_consts, (), loop_carry)
+      new_loop_carry = core.eval_jaxpr(
+        body_jaxpr, body_consts.val + body_tracer_consts, (), loop_carry)
       return _jaxtupletree_select(pred, new_loop_carry, loop_carry)
     f = batching.batch_transform(
-      lifted, size, (init_val_bd, cond_consts_bd, body_consts_bd), init_val_bd)
-    return f.call_wrapped((batched_loop_carry, cond_consts, body_consts))
+      lifted, size, (init_val_bd, cond_tracer_consts_bd, body_tracer_consts_bd),
+      init_val_bd)
+    return f.call_wrapped((batched_loop_carry, cond_tracer_consts, body_tracer_consts))
 
   return while_loop(batched_cond_fun, batched_body_fun, init_val), init_val_bd
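
One detail worth flagging in the lifted functions: `eval_jaxpr` is given `cond_consts.val + cond_tracer_consts`, literal constants first, which must line up with the `nontracer_vars + tracer_vars` ordering that `split_tracers_and_nontracers` gave the jaxpr's `constvars`. In miniature (values hypothetical):

    literal_consts = ("w",)        # stored inside the _OpaqueParam
    tracer_consts = ("x", "y")     # loop-carried, possibly batched
    consts = literal_consts + tuple(tracer_consts)
    assert consts == ("w", "x", "y")   # matches constvar order in the jaxpr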