[run_state] add pjit run_state discharge rule and basic test

This commit is contained in:
Matthew Johnson 2023-10-04 12:57:17 -07:00
parent d1c5cdc2b1
commit 5715db4832
3 changed files with 44 additions and 1 deletions

View File

@ -61,6 +61,7 @@ from jax._src.sharding_impls import (
AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.state import discharge as state_discharge
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
@ -1825,6 +1826,30 @@ def _pjit_pp_rule(eqn, context, settings):
core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
def _pjit_state_discharge_rule(
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params):
if not (all(map(is_unspecified, in_shardings)) and
all(map(is_unspecified, out_shardings))): raise NotImplementedError
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
num_outs = len(jaxpr.outvars)
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts)
discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars)
new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars)
out_and_ref_vals = pjit_p.bind(
*args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings,
out_shardings=new_out_shardings, **params)
out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs])
ref_vals_iter = iter(ref_vals)
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, state_discharge.AbstractRef)
else None for aval in in_avals)
sentinel = object()
assert next(ref_vals_iter, sentinel) is sentinel
return new_invals, out_vals
state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule)
# -------------------- with_sharding_constraint --------------------
def with_sharding_constraint(x, shardings):

View File

@ -251,7 +251,8 @@ def _closed_call_discharge_rule(
ref_vals_iter = iter(ref_vals)
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef)
else None for aval in in_avals)
assert next(ref_vals_iter, None) is None
sentinel = object()
assert next(ref_vals_iter, sentinel) is sentinel
return new_invals, out_vals
# # `run_state`

View File

@ -1288,6 +1288,23 @@ class RunStateTest(jtu.JaxTestCase):
self.assertEqual(x, 2 + 2 * 3 * 2)
self.assertEqual(y, 2 * 3 * 2)
def test_nontrivial_run_state_jit(self):
def f(refs):
x_ref, y_ref = refs
@jax.jit
def g():
x = x_ref[...] * y_ref[...]
y_ref[...] = x * 2
x_ref[...] = y_ref[...] + x_ref[...]
# x + x * y * 2, x * y * 2
g()
x, y = run_state(f)((2, 3))
self.assertEqual(x, 2 + 2 * 3 * 2)
self.assertEqual(y, 2 * 3 * 2)
def test_simple_run_state_with_multiple_refs(self):
out1, out2 = run_state(lambda _: None)((1, 2))
self.assertEqual(out1, 1)