Introducing partial discharge rules and implementations for cond_p

As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])` but each equation is discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special discharge rule that is preferred to the
normal one when present that allows the op to discharge only some of the
references.

This feature is especially useful for pallas kernels because contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.

Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.

PiperOrigin-RevId: 681021324
This commit is contained in:
Christos Perivolaropoulos 2024-10-01 08:03:11 -07:00 committed by jax authors
parent 2228115cf4
commit 84fc011e27
3 changed files with 86 additions and 13 deletions

View File

@ -33,7 +33,7 @@ from jax._src import effects
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import util
from jax._src.state.discharge import register_discharge_rule, discharge_state
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.interpreters import ad
@ -854,19 +854,22 @@ def _cond_lowering(ctx, index, *args, branches):
mlir.register_lowering(cond_p, _cond_lowering)
@register_discharge_rule(cond_p)
def _cond_state_discharge_rule(in_avals, out_avals, *args, branches):
@register_partial_discharge_rule(cond_p)
def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, branches):
assert not should_discharge[0], "Can't discharge the index."
discharged_branches = tuple(
core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ())
core.ClosedJaxpr(
discharge_state(branch.jaxpr, (),
should_discharge=should_discharge[1:])[0], ())
for branch in branches)
out_vals = cond_p.bind(*args, branches=discharged_branches)
out_vals = cond_p.bind(index, *args, branches=discharged_branches)
out_vals, out_ref_vals = util.split_list(
out_vals, [len(out_avals)])
ref_val_iter = iter(out_ref_vals)
new_invals = []
for aval in in_avals:
new_invals.append(
next(ref_val_iter) if isinstance(aval, AbstractRef) else None)
for should, aval in zip(should_discharge, in_avals):
discharged_inval = isinstance(aval, AbstractRef) and should
new_invals.append(next(ref_val_iter) if discharged_inval else None)
return new_invals, out_vals

View File

@ -97,11 +97,35 @@ class DischargeRule(Protocol):
_discharge_rules: dict[core.Primitive, DischargeRule] = {}
class PartialDischargeRule(Protocol):
"""A partial discharge rule.
Exactly like a discharge rule only it accepts a `should_discharge`
argument that indicates which inputs should be discharged and the
return value returns a tuple of which the first element is the new
inputs or none but only the ones that correspond to `True` entries
in `should_charge`.
"""
def __call__(self, should_discharge: Sequence[bool],
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], *args: Any,
**params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]:
...
_partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {}
def register_discharge_rule(prim: core.Primitive):
def register(f: DischargeRule):
_discharge_rules[prim] = f
return register
def register_partial_discharge_rule(prim: core.Primitive):
def register(f: PartialDischargeRule):
_partial_discharge_rules[prim] = f
return register
def _eval_jaxpr_discharge_state(
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
*args: Any):
@ -116,22 +140,33 @@ def _eval_jaxpr_discharge_state(
if d and isinstance(v.aval, AbstractRef)}
for eqn in jaxpr.eqns:
should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars]
if eqn.primitive is core.mutable_array_p:
[invar], [outvar] = eqn.invars, eqn.outvars
ans = env.read(invar)
refs_to_discharge.add(id(outvar.aval))
elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars)
or core.internal_mutable_array_effect in eqn.effects ):
if eqn.primitive not in _discharge_rules:
elif (any(should_discharge)
or core.internal_mutable_array_effect in eqn.effects
):
if eqn.primitive in _partial_discharge_rules:
rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge)
elif eqn.primitive in _discharge_rules:
rule = _discharge_rules[eqn.primitive]
else:
raise NotImplementedError("No state discharge rule implemented for "
f"primitive: {eqn.primitive}")
invals = map(env.read, eqn.invars)
in_avals = [v.aval for v in eqn.invars]
out_avals = [v.aval for v in eqn.outvars]
new_invals, ans = _discharge_rules[eqn.primitive](
new_invals, ans = rule(
in_avals, out_avals, *invals, **eqn.params)
for new_inval, invar in zip(new_invals, eqn.invars):
for invar, should, new_inval in zip(eqn.invars, should_discharge, new_invals):
if new_inval is not None:
if not should:
raise ValueError(
f"Did not ask for inval to be discharged but it was. ({invar=},"
f" {new_inval=})"
)
env.write(invar, new_inval) # type: ignore[arg-type]
else:
# Default primitive rule, similar to `core.eval_jaxpr`. Note that here

View File

@ -739,6 +739,20 @@ class StateDischargeTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.float32)]
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)
def test_partial_discharge(self):
def f(a_ref, b_ref):
a_ref[...] = jnp.array(0., dtype=jnp.float32)
b_ref[...] = jnp.array(1., dtype=jnp.float32)
return a_ref[...], b_ref[...]
scalar_ref = shaped_array_ref((), jnp.float32)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [scalar_ref, scalar_ref])
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns)
self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr))
self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr))
if CAN_USE_HYPOTHESIS:
@ -1061,6 +1075,27 @@ class StateControlFlowTest(jtu.JaxTestCase):
out = jax.jit(f)(False)
self.assertTupleEqual(out, (0., 5.))
def test_cond_discharge(self):
def f0(pred, x_ref, y_ref):
def true_fun():
x_ref[...] = 1.
def false_fun():
y_ref[...] = 2.
lax.cond(pred, true_fun, false_fun)
return x_ref[...], y_ref[...]
ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x)))
f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.))
jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True])
# Effects on y_ref were discharged away but not the effects on x_ref
self.assertEqual(f_jaxpr.effects, {ReadEffect(1), WriteEffect(1), ReadEffect(2), WriteEffect(2)})
self.assertEqual(jaxpr.effects, {ReadEffect(1), WriteEffect(1)})
# x_ref arg is still a reference but y_ref is discharged
self.assertNotIsInstance(jaxpr.invars[2].aval, AbstractRef)
self.assertIsInstance(jaxpr.invars[1].aval, AbstractRef)
# x_ref value is returned as part of the discharged refs set.
self.assertLen(f_jaxpr.out_avals, 2)
self.assertLen(jaxpr.outvars, 3)
def test_cond_with_ref_reuse(self):
def f(pred):
def body(x_ref):