mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Split loop cond/body constants into tracers and literal constants.
Only hoist tracer constants into loop-carried variables; leave non-tracers as constants. This change attempts to reduce the sometimes very large number of trivial loop-carried constants in some models.
This commit is contained in:
parent
cd1e4601c7
commit
098aedd66c
109
jax/lax.py
109
jax/lax.py
@ -836,6 +836,20 @@ def sort_key_val(keys, values, dimension=-1):
|
||||
sorted_keys, sorted_values = result
|
||||
return sorted_keys, sorted_values
|
||||
|
||||
|
||||
# Monotonic source of unique ids; each _OpaqueParam draws one at construction.
_opaque_param_ids = itertools.count()


class _OpaqueParam(object):
  """Wrapper that hashes on its identity, instead of its contents.

  Used to pass unhashable parameters as primitive attributes. Two wrappers
  are never equal to each other, even when they wrap the same value; a wrapper
  is only equal to itself.

  Attributes are documented here because `__slots__` forbids per-attribute
  docstrings:
    val: the wrapped (possibly unhashable) value.
    id: unique integer taken from `_opaque_param_ids`; used as the hash.
  """
  __slots__ = ["val", "id"]

  def __init__(self, val):
    self.val = val
    self.id = next(_opaque_param_ids)

  def __hash__(self):
    return self.id

  # Equality is spelled out so that it visibly agrees with the identity-based
  # hash above, and so `!=` is consistent on interpreters that do not derive
  # `__ne__` from `__eq__`.
  def __eq__(self, other):
    return self is other

  def __ne__(self, other):
    return self is not other
|
||||
|
||||
|
||||
def while_loop(cond_fun, body_fun, init_val):
|
||||
"""Call `body_fun` repeatedly in a loop while `cond_fun` is True.
|
||||
|
||||
@ -875,12 +889,38 @@ def while_loop(cond_fun, body_fun, init_val):
|
||||
body_jaxpr, pval_out, body_consts = pe.trace_to_jaxpr(flat_body_fun, (pval_flat,))
|
||||
aval_out, _ = pval_out
|
||||
|
||||
# We don't want to promote literal constants as loop arguments; there are
# sometimes many of them. We pass tracers as loop arguments, but leave
# nontracers as constants. We also sort the constants so the nontracers are
# first.
def split_tracers_and_nontracers(jaxpr, consts):
  """Partition a jaxpr's constants into literal and non-literal groups.

  Args:
    jaxpr: object whose `constvars` pair up positionally with `consts`.
    consts: the constant values, in the same order as `jaxpr.constvars`.

  Returns:
    A triple `(constvars, nontracer_consts, tracer_consts)`. `constvars` holds
    the variables of the literal (non-tracer) constants first, followed by the
    variables of the tracer/DeviceArray constants; the two value sequences
    follow the same relative order as their variables.
  """
  tracer = []
  nontracer = []
  for var, const in zip(jaxpr.constvars, consts):
    # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
    # we don't copy large arrays back to the host. We probably should relax
    # this and either always copy small constants, or opportunistically use
    # DeviceArray values for which we already know npy_value.
    not_literal_const = isinstance(const, (core.Tracer, xla.DeviceArray))
    (tracer if not_literal_const else nontracer).append((var, const))
  # NOTE(review): the `+` below relies on unzip2 returning the same sequence
  # type for both halves (presumably tuples) -- confirm against unzip2.
  tracer_vars, tracer_consts = unzip2(tracer)
  nontracer_vars, nontracer_consts = unzip2(nontracer)
  return nontracer_vars + tracer_vars, nontracer_consts, tracer_consts
|
||||
|
||||
cond_split = split_tracers_and_nontracers(cond_jaxpr, cond_consts)
|
||||
cond_jaxpr.constvars, cond_nontracer_consts, cond_tracer_consts = cond_split
|
||||
body_split = split_tracers_and_nontracers(body_jaxpr, body_consts)
|
||||
body_jaxpr.constvars, body_nontracer_consts, body_tracer_consts = body_split
|
||||
|
||||
|
||||
if out_tree() != in_tree:
|
||||
raise TypeError("body_fun input and output must have identical structure")
|
||||
|
||||
out_flat = while_p.bind(init_val_flat, core.pack(cond_consts),
|
||||
core.pack(body_consts), aval_out=aval_out,
|
||||
cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
|
||||
out_flat = while_p.bind(
|
||||
init_val_flat, core.pack(cond_tracer_consts), core.pack(body_tracer_consts),
|
||||
cond_consts=_OpaqueParam(cond_nontracer_consts),
|
||||
body_consts=_OpaqueParam(body_nontracer_consts),
|
||||
aval_out=aval_out, cond_jaxpr=cond_jaxpr, body_jaxpr=body_jaxpr)
|
||||
return build_tree(out_tree(), out_flat)
|
||||
|
||||
|
||||
@ -3687,13 +3727,15 @@ ad.primitive_transposes[sort_key_val_p] = _sort_key_val_transpose_rule
|
||||
batching.primitive_batchers[sort_key_val_p] = _sort_key_val_batch_rule
|
||||
|
||||
|
||||
def _while_loop_abstract_eval(init_val, cond_consts, body_consts, aval_out,
|
||||
def _while_loop_abstract_eval(init_val, cond_tracer_consts, body_tracer_consts,
|
||||
cond_consts, body_consts, aval_out,
|
||||
cond_jaxpr, body_jaxpr):
|
||||
return maybe_tracer_tuple_to_abstract_tuple(aval_out)
|
||||
|
||||
def _while_loop_translation_rule(c, init_val, cond_consts, body_consts,
|
||||
def _while_loop_translation_rule(c, init_val, cond_tracer_consts,
|
||||
body_tracer_consts, cond_consts, body_consts,
|
||||
aval_out, cond_jaxpr, body_jaxpr):
|
||||
loop_carry = c.Tuple(init_val, cond_consts, body_consts)
|
||||
loop_carry = c.Tuple(init_val, cond_tracer_consts, body_tracer_consts)
|
||||
shape = c.GetShape(loop_carry)
|
||||
|
||||
loop_carry_var = pe.Var(0, "loop_carry")
|
||||
@ -3701,33 +3743,37 @@ def _while_loop_translation_rule(c, init_val, cond_consts, body_consts,
|
||||
cond_var = pe.Var(0, "cond_consts")
|
||||
body_var = pe.Var(0, "body_consts")
|
||||
|
||||
num_cond_consts = len(cond_consts.val)
|
||||
assert len(cond_jaxpr.invars) == 1
|
||||
cond_jaxpr_converted = cond_jaxpr.copy()
|
||||
cond_jaxpr_converted.constvars = []
|
||||
cond_jaxpr_converted.constvars = cond_jaxpr.constvars[:num_cond_consts]
|
||||
cond_jaxpr_converted.invars = [loop_carry_var]
|
||||
cond_jaxpr_converted.eqns = (
|
||||
[_unpack_eqn(loop_carry_var, [cond_jaxpr.invars[0], cond_var, body_var]),
|
||||
_unpack_eqn(cond_var, cond_jaxpr.constvars)]
|
||||
_unpack_eqn(cond_var, cond_jaxpr.constvars[num_cond_consts:])]
|
||||
+ list(cond_jaxpr.eqns))
|
||||
|
||||
num_body_consts = len(body_consts.val)
|
||||
assert len(body_jaxpr.invars) == 1
|
||||
body_jaxpr_converted = body_jaxpr.copy()
|
||||
body_jaxpr_converted.constvars = []
|
||||
body_jaxpr_converted.constvars = body_jaxpr.constvars[:num_body_consts]
|
||||
body_jaxpr_converted.invars = [loop_carry_var]
|
||||
body_jaxpr_converted.outvar = outvar
|
||||
body_jaxpr_converted.eqns = (
|
||||
[_unpack_eqn(loop_carry_var, [body_jaxpr.invars[0], cond_var, body_var]),
|
||||
_unpack_eqn(body_var, body_jaxpr.constvars)]
|
||||
_unpack_eqn(body_var, body_jaxpr.constvars[num_body_consts:])]
|
||||
+ list(body_jaxpr.eqns) +
|
||||
[_pack_eqn([body_jaxpr.outvar, cond_var, body_var], outvar)])
|
||||
|
||||
cond_computation = xla.jaxpr_computation(cond_jaxpr_converted, (), (), shape)
|
||||
body_computation = xla.jaxpr_computation(body_jaxpr_converted, (), (), shape)
|
||||
cond_computation = xla.jaxpr_computation(
|
||||
cond_jaxpr_converted, cond_consts.val, (), shape)
|
||||
body_computation = xla.jaxpr_computation(
|
||||
body_jaxpr_converted, body_consts.val, (), shape)
|
||||
full_ans = c.While(cond_computation, body_computation, loop_carry)
|
||||
return c.GetTupleElement(full_ans, 0)
|
||||
|
||||
def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
|
||||
body_jaxpr):
|
||||
def _while_loop_batching_rule(batched_args, batch_dims, cond_consts,
|
||||
body_consts, aval_out, cond_jaxpr, body_jaxpr):
|
||||
# See https://github.com/google/jax/issues/441 for a discussion.
|
||||
# To batch a while_loop, we need to do some masking, since the elements of the
|
||||
# batch may run for different numbers of iterations. We perform that masking
|
||||
@ -3736,15 +3782,15 @@ def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
|
||||
# The basic strategy here is to lift `cond_jaxpr` and `body_jaxpr` back into
|
||||
# traceable Python functions using `core.eval_jaxpr`. Then we can batch them
|
||||
# using `batching.batch_transform` (the transform underlying `api.vmap`). This
|
||||
# code also avoids broadcasting `cond_consts` and `body_consts`.
|
||||
init_val, cond_consts, body_consts = batched_args
|
||||
init_val_bd, cond_consts_bd, body_consts_bd = batch_dims
|
||||
# code also avoids broadcasting `cond_tracer_consts` and `body_tracer_consts`.
|
||||
init_val, cond_tracer_consts, body_tracer_consts = batched_args
|
||||
init_val_bd, cond_tracer_consts_bd, body_tracer_consts_bd = batch_dims
|
||||
|
||||
sizes = _reduce(set.union, map(batching.dimsize, batch_dims, batched_args))
|
||||
size = sizes.pop()
|
||||
assert not sizes
|
||||
|
||||
# TODO(mattjj): if cond_consts_bd is also None, we could keep cond_fun
|
||||
# TODO(mattjj): if cond_tracer_consts_bd is also None, we could keep cond_fun
|
||||
# unbatched and avoid the masking logic, but we ignore that optimization
|
||||
init_val = batching.bdim_at_front(init_val, init_val_bd, size,
|
||||
force_broadcast=True)
|
||||
@ -3752,21 +3798,28 @@ def _while_loop_batching_rule(batched_args, batch_dims, aval_out, cond_jaxpr,
|
||||
|
||||
def batched_cond_fun(batched_loop_carry):
|
||||
@lu.wrap_init
|
||||
def lifted(loop_carry, cond_consts):
|
||||
return core.eval_jaxpr(cond_jaxpr, cond_consts, (), loop_carry)
|
||||
f = batching.batch_transform(lifted, size, (init_val_bd, cond_consts_bd), 0)
|
||||
preds = f.call_wrapped((batched_loop_carry, cond_consts))
|
||||
def lifted(loop_carry, cond_tracer_consts):
|
||||
cond_tracer_consts = tuple(x for x in cond_tracer_consts)
|
||||
return core.eval_jaxpr(
|
||||
cond_jaxpr, cond_consts.val + cond_tracer_consts, (), loop_carry)
|
||||
f = batching.batch_transform(lifted, size, (init_val_bd, cond_tracer_consts_bd), 0)
|
||||
preds = f.call_wrapped((batched_loop_carry, cond_tracer_consts))
|
||||
return reduce(preds, onp.array(False), bitwise_or, [0])
|
||||
|
||||
def batched_body_fun(batched_loop_carry):
|
||||
@lu.wrap_init
|
||||
def lifted(loop_carry, cond_consts, body_consts):
|
||||
pred = core.eval_jaxpr(cond_jaxpr, cond_consts, (), loop_carry)
|
||||
new_loop_carry = core.eval_jaxpr(body_jaxpr, body_consts, (), loop_carry)
|
||||
def lifted(loop_carry, cond_tracer_consts, body_tracer_consts):
|
||||
cond_tracer_consts = tuple(x for x in cond_tracer_consts)
|
||||
body_tracer_consts = tuple(x for x in body_tracer_consts)
|
||||
pred = core.eval_jaxpr(
|
||||
cond_jaxpr, cond_consts.val + cond_tracer_consts, (), loop_carry)
|
||||
new_loop_carry = core.eval_jaxpr(
|
||||
body_jaxpr, body_consts.val + body_tracer_consts, (), loop_carry)
|
||||
return _jaxtupletree_select(pred, new_loop_carry, loop_carry)
|
||||
f = batching.batch_transform(
|
||||
lifted, size, (init_val_bd, cond_consts_bd, body_consts_bd), init_val_bd)
|
||||
return f.call_wrapped((batched_loop_carry, cond_consts, body_consts))
|
||||
lifted, size, (init_val_bd, cond_tracer_consts_bd, body_tracer_consts_bd),
|
||||
init_val_bd)
|
||||
return f.call_wrapped((batched_loop_carry, cond_tracer_consts, body_tracer_consts))
|
||||
|
||||
return while_loop(batched_cond_fun, batched_body_fun, init_val), init_val_bd
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user