add pjit forwarding rule

Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
Matthew Johnson 2024-05-24 01:14:16 +00:00
parent c1f5a32875
commit 0a693faf48
7 changed files with 99 additions and 31 deletions

View File

@ -1679,7 +1679,11 @@ class DynamicJaxprTracer(core.Tracer):
self.aval = aval
def full_lower(self):
  """Return the known constant this tracer stands for, or the tracer itself.

  The tracer is looked up in the current frame's ``tracer_to_var`` table
  (keyed by ``id(self)``). If it is bound to a variable and that variable has
  an entry in ``constvar_to_val``, the recorded value is passed through
  ``core.full_lower`` and returned; otherwise ``self`` is returned unchanged.
  """
  frame = self._trace.frame
  var = frame.tracer_to_var.get(id(self))
  if var is None:
    # This tracer has no variable recorded in the frame.
    return self
  val = frame.constvar_to_val.get(var)
  if val is None:
    # The variable has no recorded constant value.
    return self
  return core.full_lower(val)
def _contents(self):
return ()
@ -1874,7 +1878,6 @@ def _const_folding_and_forwarding(
# if the application trivially maps some inputs to outputs, simplify
if eqn.primitive in forwarding_rules and not has_input_effect:
fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn)
assert (new_eqn is None) == all(v is not None for v in fwd_vars)
for v_orig, v_new in zip(eqn.outvars, fwd_vars):
if v_new is not None: var_subs[v_orig] = v_new
if new_eqn is None: continue

View File

@ -1578,24 +1578,30 @@ def _pjit_lower_cached(
def pjit_staging_rule(trace, *args, **params):
jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
params['jaxpr'], params['out_shardings'], params['out_layouts'])
params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
out_layouts=out_layouts)
if (params["inline"] and
all(is_unspecified(i) for i in params["in_shardings"]) and
all(is_unspecified(o) for o in params["out_shardings"]) and
all(i is None for i in params["in_layouts"]) and
all(o is None for o in params["out_layouts"])):
jaxpr = params['jaxpr']
if config.dynamic_shapes.value:
# inline_jaxpr_into_trace doesn't handle dynamic shapes. If dynamic shapes
# are enabled, use eval_jaxpr instead, which uses the tracing machinery
# but redundantly performs abstract evaluation again.
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
propagate_source_info=False)
out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
propagate_source_info=False)
else:
return pe.inline_jaxpr_into_trace(trace, jaxpr.jaxpr, jaxpr.consts, *args)
out_tracers = pe.inline_jaxpr_into_trace(
trace, jaxpr.jaxpr, jaxpr.consts, *args)
elif config.dynamic_shapes.value:
source_info = source_info_util.current()
out_tracers = []
for aval in _out_type(params['jaxpr']):
for aval in _out_type(jaxpr):
if type(aval) is core.DShapedArray:
shape = [args[d.val] if type(d) is core.InDBIdx else
out_tracers[d.val] if type(d) is core.OutDBIdx else
@ -1604,23 +1610,51 @@ def pjit_staging_rule(trace, *args, **params):
out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info))
eqn = core.new_jaxpr_eqn(
map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params,
params['jaxpr'].effects, source_info)
jaxpr.effects, source_info)
trace.frame.add_eqn(eqn)
return out_tracers
elif any(isinstance(c, core.MutableArray) for c in params['jaxpr'].consts):
jaxpr, consts = pxla._move_mutable_consts(params['jaxpr'])
elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
jaxpr, consts = pxla._move_mutable_consts(jaxpr)
consts = map(trace.instantiate_const, consts)
in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
in_layouts = (*params['in_layouts'],) + (None,) * len(consts)
donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
in_layouts=in_layouts, donated_invars=donated_invars)
return trace.default_process_primitive(pjit_p, (*args, *consts), new_params)
out_tracers = trace.default_process_primitive(
pjit_p, (*args, *consts), new_params)
else:
return trace.default_process_primitive(pjit_p, args, params)
out_tracers = trace.default_process_primitive(pjit_p, args, params)
out_tracers_ = iter(out_tracers)
out_tracers = [args[f] if type(f) is int else next(out_tracers_)
for f in in_fwd]
assert next(out_tracers_, None) is None
return out_tracers
pe.custom_staging_rules[pjit_p] = pjit_staging_rule
def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
  """Prune the outputs of a pjit body that merely forward one of its inputs.

  Args:
    jaxpr: closed jaxpr of the pjit body (its ``.jaxpr`` is inspected).
    out_shardings: per-output shardings, aligned with the jaxpr's outputs.
    out_layouts: per-output layouts, aligned with the jaxpr's outputs.

  Returns:
    A 4-tuple ``(jaxpr, in_fwd, out_shardings, out_layouts)``. ``in_fwd`` has
    one entry per *original* output: an input index when that output is
    forwarded from an input (callers index ``args`` / ``eqn.invars`` with it),
    or ``None`` when the output is kept. The returned jaxpr, shardings and
    layouts describe only the kept outputs; shardings and layouts are returned
    as tuples, matching how the other pjit params are built.
  """
  in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr)
  # An output is only treated as forwarded when nothing was requested for it:
  # its sharding must be unspecified and its layout must be None.
  in_fwd = [
      fwd if is_unspecified(sharding) and layout is None else None
      for fwd, sharding, layout in zip(in_fwd, out_shardings, out_layouts)]
  keep = [fwd is None for fwd in in_fwd]
  jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep)
  out_shardings = tuple(s for s, k in zip(out_shardings, keep) if k)
  out_layouts = tuple(l for l, k in zip(out_layouts, keep) if k)
  return jaxpr, in_fwd, out_shardings, out_layouts
def pjit_forwarding_rule(eqn):
  """Forwarding rule for pjit equations.

  Uses ``_pjit_forwarding`` to find outputs of ``eqn`` that are forwarded
  inputs, and returns ``(fwd_vars, new_eqn)``: ``fwd_vars`` has one entry per
  original output (the forwarded input variable, or ``None`` if the output is
  still produced by the equation), and ``new_eqn`` is ``eqn`` with the pruned
  jaxpr, shardings, layouts and only the non-forwarded output variables.
  """
  params = eqn.params
  pruned_jaxpr, in_fwd, kept_shardings, kept_layouts = _pjit_forwarding(
      params['jaxpr'], params['out_shardings'], params['out_layouts'])

  kept_outvars = []
  for outvar, fwd in zip(eqn.outvars, in_fwd):
    if fwd is None:
      kept_outvars.append(outvar)

  updated_params = dict(
      params,
      jaxpr=pruned_jaxpr,
      out_shardings=tuple(kept_shardings),
      out_layouts=tuple(kept_layouts),
  )
  new_eqn = eqn.replace(params=updated_params, outvars=kept_outvars)

  fwd_vars = [None if fwd is None else eqn.invars[fwd] for fwd in in_fwd]
  return fwd_vars, new_eqn
pe.forwarding_rules[pjit_p] = pjit_forwarding_rule
# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
# since it's actually not possible in general to infer the type from the term
def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:

View File

@ -4677,6 +4677,37 @@ class APITest(jtu.JaxTestCase):
jtu.check_grads(f, (list(jnp.arange(float(num_args))),), order=1,
modes=['rev'], atol=1e-3, rtol=1e-3)
@jtu.run_on_devices("cpu")
def test_inner_jit_forwarding_happens(self):
  """The jaxpr of an identity inner jit on 3 has the literal 3 as its output."""
  def outer():
    return jax.jit(lambda x: x)(3)

  closed_jaxpr = jax.make_jaxpr(outer)()
  outvars = closed_jaxpr.jaxpr.outvars
  self.assertLen(outvars, 1)
  self.assertIsInstance(outvars[0], core.Literal)
  self.assertEqual(outvars[0].val, 3)
@parameterized.parameters(range(8))
@jtu.run_on_devices("cpu")
def test_inner_jit_forwarding_correctness(self, num_input_fwd):
  """A jitted function gives the same results when wrapped in another jit.

  ``f`` shuffles its inputs, passes ``num_input_fwd`` of them straight through
  alongside ``sin`` of others, and shuffles the outputs. The test checks that
  calling ``f`` through an extra ``jax.jit`` layer returns the same values.
  """
  num_args = 8
  rng = np.random.RandomState(0)
  # Draw both permutations once, up front. Drawing them inside `f` would
  # advance the RNG every time `f` is traced, so `f` and `f2` would only agree
  # as long as `f` is never re-traced.
  in_perm = rng.permutation(num_args)
  out_perm = rng.permutation(num_args)

  @jax.jit
  def f(inputs):
    inputs = [inputs[i] for i in in_perm]
    outputs = (inputs[:num_input_fwd] +
               [jnp.sin(inputs[i]) for i in range(num_args - num_input_fwd)])
    return [outputs[i] for i in out_perm]

  f2 = jax.jit(f)
  inputs = list(jnp.arange(float(num_args)))
  expected = f(inputs)
  ans = f2(inputs)
  for a, b in zip(ans, expected):
    self.assertAllClose(a, b)
def test_inner_jit_forwarded_consts_stay_const(self):
  """int() on the result of an identity inner jit works inside an outer jit."""
  def outer():
    forwarded = jax.jit(lambda x: x)(3)
    return int(forwarded)

  out = jax.jit(outer)()  # don't crash
  self.assertEqual(out, 3)
class RematTest(jtu.JaxTestCase):

View File

@ -232,7 +232,7 @@ class DynamicShapeStagingTest(jtu.JaxTestCase):
def f(n):
m = 2 * n
x = jnp.zeros(m)
return jax.jit(lambda: x)()
return jax.jit(jnp.sin)(x)
# { lambda ; a:i32[]. let
# b:i32[] = mul a 2
@ -517,7 +517,7 @@ class DynamicShapeStagingTest(jtu.JaxTestCase):
self.assertIs(e.aval.shape[0], d)
def test_jit_abstracted_axes_return_polymorphic_shape(self):
f = jax.jit(lambda x: x, abstracted_axes=('n',))
f = jax.jit(lambda x: jnp.sin(x), abstracted_axes=('n',))
jaxpr = jax.make_jaxpr(f)(jnp.arange(3)) # doesn't crash
# { lambda ; a:i32[3]. let
# b:i32[3] = pjit[

View File

@ -61,8 +61,8 @@ class JaxprStatsTest(jtu.JaxTestCase):
def test_primitives_by_shape(self):
def f(x, y):
def sub(x, y):
return jnp.sum(jnp.array([x, y])), y
s, _ = jit(sub)(x, y)
return jnp.sum(jnp.array([x, y]))
s = jit(sub)(x, y)
return jnp.sin(s) + jnp.cos(y)
hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)
@ -74,7 +74,7 @@ class JaxprStatsTest(jtu.JaxTestCase):
f'cos :: float{t}[]',
f'reduce_sum :: float{t}[]',
f'concatenate :: float{t}[2]',
f'pjit :: float{t}[] *',
f'pjit :: float{t}[]',
]
for k in shapes:
self.assertEqual(hist[k], 1)

View File

@ -1282,22 +1282,22 @@ class PJitTest(jtu.BufferDonationTestCase):
y = jnp.array([4.2, 2.4], dtype=jnp.float32)
jaxpr = jax.make_jaxpr(g)(x, y)
self.assertEqual(
jaxpr.pretty_print(),
jaxpr.pretty_print(use_color=False),
textwrap.dedent("""
let f = { lambda ; a:f32[1]. let in (a,) } in
let f1 = { lambda ; b:f32[2]. let in (b,) } in
let f = { lambda ; a:f32[1]. let in () } in
let f1 = { lambda ; b:f32[2]. let in () } in
{ lambda ; c:f32[1] d:f32[2]. let
e:f32[2] = pjit[
name=g
jaxpr={ lambda ; g:f32[1] h:f32[2]. let
i:f32[1] = pjit[name=f jaxpr=f] g
j:f32[1] = pjit[name=f jaxpr=f] g
k:f32[1] = mul i j
l:f32[2] = pjit[name=f jaxpr=f1] h
m:f32[2] = pjit[name=f jaxpr=f1] h
n:f32[2] = mul l m
o:f32[2] = add k n
in (o,) }
pjit[name=f jaxpr=f] g
pjit[name=f jaxpr=f] g
i:f32[1] = mul g g
pjit[name=f jaxpr=f1] h
pjit[name=f jaxpr=f1] h
j:f32[2] = mul h h
k:f32[2] = add i j
in (k,) }
] c d
in (e,) }
""").strip(),

View File

@ -1353,13 +1353,13 @@ class ShardMapTest(jtu.JaxTestCase):
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
def g(x):
return jax.jit(lambda x: x)(x)
return jax.jit(lambda x: 1. * x)(x)
jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.)
e, = jaxpr.jaxpr.eqns
e1, e2 = e.params['jaxpr'].eqns
self.assertEmpty(e1.outvars)
self.assertEmpty(e2.params['jaxpr'].eqns)
self.assertLen(e2.params['jaxpr'].eqns, 1)
def test_fanout_specs_transpose_to_psum(self):
mesh = jtu.create_global_mesh((4,), ('x',))