mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Introducing partial discharge rules and implementations for cond_p
As things stand you can partially discharge a jaxpr with `discharge_state(should_discharge=[...])` but each equation is discharges *all* its arguments. This means that primitives like `scan_p` and `cond_p` discharge all references they refer to (no pun intended) regardless of whether the user asked for it. We provide a special discharge rule that is preferred to the normal one when present that allows the op to discharge only some of the references. This feature is especially useful for pallas kernels because contrary to all other contexts where jaxprs are expected to eventually be fully discharged, pallas kernels lower references all the way to the runtime as pointers or MLIR memrefs. Here we implement the partial discharge rule for `cond_p` and will implement it for others in due course. PiperOrigin-RevId: 681021324
This commit is contained in:
parent
2228115cf4
commit
84fc011e27
@ -33,7 +33,7 @@ from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.state.discharge import register_discharge_rule, discharge_state
|
||||
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
|
||||
from jax._src.state.types import AbstractRef, RefEffect
|
||||
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
|
||||
from jax._src.interpreters import ad
|
||||
@ -854,19 +854,22 @@ def _cond_lowering(ctx, index, *args, branches):
|
||||
|
||||
mlir.register_lowering(cond_p, _cond_lowering)
|
||||
|
||||
@register_discharge_rule(cond_p)
|
||||
def _cond_state_discharge_rule(in_avals, out_avals, *args, branches):
|
||||
@register_partial_discharge_rule(cond_p)
|
||||
def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, branches):
|
||||
assert not should_discharge[0], "Can't discharge the index."
|
||||
discharged_branches = tuple(
|
||||
core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ())
|
||||
core.ClosedJaxpr(
|
||||
discharge_state(branch.jaxpr, (),
|
||||
should_discharge=should_discharge[1:])[0], ())
|
||||
for branch in branches)
|
||||
out_vals = cond_p.bind(*args, branches=discharged_branches)
|
||||
out_vals = cond_p.bind(index, *args, branches=discharged_branches)
|
||||
out_vals, out_ref_vals = util.split_list(
|
||||
out_vals, [len(out_avals)])
|
||||
ref_val_iter = iter(out_ref_vals)
|
||||
new_invals = []
|
||||
for aval in in_avals:
|
||||
new_invals.append(
|
||||
next(ref_val_iter) if isinstance(aval, AbstractRef) else None)
|
||||
for should, aval in zip(should_discharge, in_avals):
|
||||
discharged_inval = isinstance(aval, AbstractRef) and should
|
||||
new_invals.append(next(ref_val_iter) if discharged_inval else None)
|
||||
return new_invals, out_vals
|
||||
|
||||
|
||||
|
@ -97,11 +97,35 @@ class DischargeRule(Protocol):
|
||||
|
||||
_discharge_rules: dict[core.Primitive, DischargeRule] = {}
|
||||
|
||||
class PartialDischargeRule(Protocol):
|
||||
"""A partial discharge rule.
|
||||
|
||||
Exactly like a discharge rule only it accepts a `should_discharge`
|
||||
argument that indicates which inputs should be discharged and the
|
||||
return value returns a tuple of which the first element is the new
|
||||
inputs or none but only the ones that correspond to `True` entries
|
||||
in `should_charge`.
|
||||
"""
|
||||
|
||||
def __call__(self, should_discharge: Sequence[bool],
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
out_avals: Sequence[core.AbstractValue], *args: Any,
|
||||
**params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]:
|
||||
...
|
||||
|
||||
_partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {}
|
||||
|
||||
def register_discharge_rule(prim: core.Primitive):
|
||||
def register(f: DischargeRule):
|
||||
_discharge_rules[prim] = f
|
||||
return register
|
||||
|
||||
def register_partial_discharge_rule(prim: core.Primitive):
|
||||
def register(f: PartialDischargeRule):
|
||||
_partial_discharge_rules[prim] = f
|
||||
return register
|
||||
|
||||
|
||||
def _eval_jaxpr_discharge_state(
|
||||
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
|
||||
*args: Any):
|
||||
@ -116,22 +140,33 @@ def _eval_jaxpr_discharge_state(
|
||||
if d and isinstance(v.aval, AbstractRef)}
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars]
|
||||
if eqn.primitive is core.mutable_array_p:
|
||||
[invar], [outvar] = eqn.invars, eqn.outvars
|
||||
ans = env.read(invar)
|
||||
refs_to_discharge.add(id(outvar.aval))
|
||||
elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars)
|
||||
or core.internal_mutable_array_effect in eqn.effects ):
|
||||
if eqn.primitive not in _discharge_rules:
|
||||
elif (any(should_discharge)
|
||||
or core.internal_mutable_array_effect in eqn.effects
|
||||
):
|
||||
if eqn.primitive in _partial_discharge_rules:
|
||||
rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge)
|
||||
elif eqn.primitive in _discharge_rules:
|
||||
rule = _discharge_rules[eqn.primitive]
|
||||
else:
|
||||
raise NotImplementedError("No state discharge rule implemented for "
|
||||
f"primitive: {eqn.primitive}")
|
||||
invals = map(env.read, eqn.invars)
|
||||
in_avals = [v.aval for v in eqn.invars]
|
||||
out_avals = [v.aval for v in eqn.outvars]
|
||||
new_invals, ans = _discharge_rules[eqn.primitive](
|
||||
new_invals, ans = rule(
|
||||
in_avals, out_avals, *invals, **eqn.params)
|
||||
for new_inval, invar in zip(new_invals, eqn.invars):
|
||||
for invar, should, new_inval in zip(eqn.invars, should_discharge, new_invals):
|
||||
if new_inval is not None:
|
||||
if not should:
|
||||
raise ValueError(
|
||||
f"Did not ask for inval to be discharged but it was. ({invar=},"
|
||||
f" {new_inval=})"
|
||||
)
|
||||
env.write(invar, new_inval) # type: ignore[arg-type]
|
||||
else:
|
||||
# Default primitive rule, similar to `core.eval_jaxpr`. Note that here
|
||||
|
@ -739,6 +739,20 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
in_avals = [shaped_array_ref((), jnp.float32)]
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)
|
||||
|
||||
def test_partial_discharge(self):
|
||||
def f(a_ref, b_ref):
|
||||
a_ref[...] = jnp.array(0., dtype=jnp.float32)
|
||||
b_ref[...] = jnp.array(1., dtype=jnp.float32)
|
||||
return a_ref[...], b_ref[...]
|
||||
|
||||
scalar_ref = shaped_array_ref((), jnp.float32)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [scalar_ref, scalar_ref])
|
||||
|
||||
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
|
||||
prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns)
|
||||
self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr))
|
||||
self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr))
|
||||
|
||||
if CAN_USE_HYPOTHESIS:
|
||||
|
||||
@ -1061,6 +1075,27 @@ class StateControlFlowTest(jtu.JaxTestCase):
|
||||
out = jax.jit(f)(False)
|
||||
self.assertTupleEqual(out, (0., 5.))
|
||||
|
||||
def test_cond_discharge(self):
|
||||
def f0(pred, x_ref, y_ref):
|
||||
def true_fun():
|
||||
x_ref[...] = 1.
|
||||
def false_fun():
|
||||
y_ref[...] = 2.
|
||||
lax.cond(pred, true_fun, false_fun)
|
||||
return x_ref[...], y_ref[...]
|
||||
ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x)))
|
||||
f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.))
|
||||
jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True])
|
||||
# Effects on y_ref were discharged away but not the effects on x_ref
|
||||
self.assertEqual(f_jaxpr.effects, {ReadEffect(1), WriteEffect(1), ReadEffect(2), WriteEffect(2)})
|
||||
self.assertEqual(jaxpr.effects, {ReadEffect(1), WriteEffect(1)})
|
||||
# x_ref arg is still a reference but y_ref is discharged
|
||||
self.assertNotIsInstance(jaxpr.invars[2].aval, AbstractRef)
|
||||
self.assertIsInstance(jaxpr.invars[1].aval, AbstractRef)
|
||||
# x_ref value is returned as part of the discharged refs set.
|
||||
self.assertLen(f_jaxpr.out_avals, 2)
|
||||
self.assertLen(jaxpr.outvars, 3)
|
||||
|
||||
def test_cond_with_ref_reuse(self):
|
||||
def f(pred):
|
||||
def body(x_ref):
|
||||
|
Loading…
x
Reference in New Issue
Block a user