mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[run_state] add pjit run_state discharge rule and basic test
This commit is contained in:
parent
d1c5cdc2b1
commit
5715db4832
@ -61,6 +61,7 @@ from jax._src.sharding_impls import (
|
||||
AUTO, UNSPECIFIED, UnspecifiedValue,
|
||||
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
|
||||
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import (
|
||||
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
|
||||
@ -1825,6 +1826,30 @@ def _pjit_pp_rule(eqn, context, settings):
|
||||
core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
|
||||
|
||||
|
||||
|
||||
def _pjit_state_discharge_rule(
|
||||
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params):
|
||||
if not (all(map(is_unspecified, in_shardings)) and
|
||||
all(map(is_unspecified, out_shardings))): raise NotImplementedError
|
||||
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
|
||||
num_outs = len(jaxpr.outvars)
|
||||
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts)
|
||||
discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
|
||||
new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars)
|
||||
new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars)
|
||||
out_and_ref_vals = pjit_p.bind(
|
||||
*args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings,
|
||||
out_shardings=new_out_shardings, **params)
|
||||
out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs])
|
||||
ref_vals_iter = iter(ref_vals)
|
||||
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, state_discharge.AbstractRef)
|
||||
else None for aval in in_avals)
|
||||
sentinel = object()
|
||||
assert next(ref_vals_iter, sentinel) is sentinel
|
||||
return new_invals, out_vals
|
||||
state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule)
|
||||
|
||||
|
||||
# -------------------- with_sharding_constraint --------------------
|
||||
|
||||
def with_sharding_constraint(x, shardings):
|
||||
|
@ -251,7 +251,8 @@ def _closed_call_discharge_rule(
|
||||
ref_vals_iter = iter(ref_vals)
|
||||
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef)
|
||||
else None for aval in in_avals)
|
||||
assert next(ref_vals_iter, None) is None
|
||||
sentinel = object()
|
||||
assert next(ref_vals_iter, sentinel) is sentinel
|
||||
return new_invals, out_vals
|
||||
|
||||
# # `run_state`
|
||||
|
@ -1288,6 +1288,23 @@ class RunStateTest(jtu.JaxTestCase):
|
||||
self.assertEqual(x, 2 + 2 * 3 * 2)
|
||||
self.assertEqual(y, 2 * 3 * 2)
|
||||
|
||||
def test_nontrivial_run_state_jit(self):
|
||||
def f(refs):
|
||||
x_ref, y_ref = refs
|
||||
|
||||
@jax.jit
|
||||
def g():
|
||||
x = x_ref[...] * y_ref[...]
|
||||
y_ref[...] = x * 2
|
||||
x_ref[...] = y_ref[...] + x_ref[...]
|
||||
# x + x * y * 2, x * y * 2
|
||||
|
||||
g()
|
||||
|
||||
x, y = run_state(f)((2, 3))
|
||||
self.assertEqual(x, 2 + 2 * 3 * 2)
|
||||
self.assertEqual(y, 2 * 3 * 2)
|
||||
|
||||
def test_simple_run_state_with_multiple_refs(self):
|
||||
out1, out2 = run_state(lambda _: None)((1, 2))
|
||||
self.assertEqual(out1, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user