mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
add pjit forwarding rule
Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
parent
c1f5a32875
commit
0a693faf48
@ -1679,7 +1679,11 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
self.aval = aval
|
||||
|
||||
def full_lower(self):
|
||||
return self
|
||||
var = self._trace.frame.tracer_to_var.get(id(self))
|
||||
if var is None: return self
|
||||
val = self._trace.frame.constvar_to_val.get(var)
|
||||
if val is None: return self
|
||||
return core.full_lower(val)
|
||||
|
||||
def _contents(self):
|
||||
return ()
|
||||
@ -1874,7 +1878,6 @@ def _const_folding_and_forwarding(
|
||||
# if the application trivially maps some inputs to outputs, simplify
|
||||
if eqn.primitive in forwarding_rules and not has_input_effect:
|
||||
fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn)
|
||||
assert (new_eqn is None) == all(v is not None for v in fwd_vars)
|
||||
for v_orig, v_new in zip(eqn.outvars, fwd_vars):
|
||||
if v_new is not None: var_subs[v_orig] = v_new
|
||||
if new_eqn is None: continue
|
||||
|
@ -1578,24 +1578,30 @@ def _pjit_lower_cached(
|
||||
|
||||
|
||||
def pjit_staging_rule(trace, *args, **params):
|
||||
jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
|
||||
params['jaxpr'], params['out_shardings'], params['out_layouts'])
|
||||
params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
|
||||
out_layouts=out_layouts)
|
||||
|
||||
if (params["inline"] and
|
||||
all(is_unspecified(i) for i in params["in_shardings"]) and
|
||||
all(is_unspecified(o) for o in params["out_shardings"]) and
|
||||
all(i is None for i in params["in_layouts"]) and
|
||||
all(o is None for o in params["out_layouts"])):
|
||||
jaxpr = params['jaxpr']
|
||||
|
||||
if config.dynamic_shapes.value:
|
||||
# Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
|
||||
# shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
|
||||
# but redundantly performs abstract evaluation again.
|
||||
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
|
||||
propagate_source_info=False)
|
||||
out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
|
||||
propagate_source_info=False)
|
||||
else:
|
||||
return pe.inline_jaxpr_into_trace(trace, jaxpr.jaxpr, jaxpr.consts, *args)
|
||||
out_tracers = pe.inline_jaxpr_into_trace(
|
||||
trace, jaxpr.jaxpr, jaxpr.consts, *args)
|
||||
elif config.dynamic_shapes.value:
|
||||
source_info = source_info_util.current()
|
||||
out_tracers = []
|
||||
for aval in _out_type(params['jaxpr']):
|
||||
for aval in _out_type(jaxpr):
|
||||
if type(aval) is core.DShapedArray:
|
||||
shape = [args[d.val] if type(d) is core.InDBIdx else
|
||||
out_tracers[d.val] if type(d) is core.OutDBIdx else
|
||||
@ -1604,23 +1610,51 @@ def pjit_staging_rule(trace, *args, **params):
|
||||
out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info))
|
||||
eqn = core.new_jaxpr_eqn(
|
||||
map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params,
|
||||
params['jaxpr'].effects, source_info)
|
||||
jaxpr.effects, source_info)
|
||||
trace.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
elif any(isinstance(c, core.MutableArray) for c in params['jaxpr'].consts):
|
||||
jaxpr, consts = pxla._move_mutable_consts(params['jaxpr'])
|
||||
elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
|
||||
jaxpr, consts = pxla._move_mutable_consts(jaxpr)
|
||||
consts = map(trace.instantiate_const, consts)
|
||||
in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
|
||||
in_layouts = (*params['in_layouts'],) + (None,) * len(consts)
|
||||
donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
|
||||
new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
in_layouts=in_layouts, donated_invars=donated_invars)
|
||||
return trace.default_process_primitive(pjit_p, (*args, *consts), new_params)
|
||||
out_tracers = trace.default_process_primitive(
|
||||
pjit_p, (*args, *consts), new_params)
|
||||
else:
|
||||
return trace.default_process_primitive(pjit_p, args, params)
|
||||
out_tracers = trace.default_process_primitive(pjit_p, args, params)
|
||||
|
||||
out_tracers_ = iter(out_tracers)
|
||||
out_tracers = [args[f] if type(f) is int else next(out_tracers_)
|
||||
for f in in_fwd]
|
||||
assert next(out_tracers_, None) is None
|
||||
return out_tracers
|
||||
pe.custom_staging_rules[pjit_p] = pjit_staging_rule
|
||||
|
||||
|
||||
def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
  """Drop the outputs of `jaxpr` that can be forwarded from an input.

  The forwarding candidates come from `pe._jaxpr_forwarding`. A candidate is
  only honoured when the corresponding out_sharding is unspecified and the
  corresponding out_layout is None; every other output is kept.

  Args:
    jaxpr: closed jaxpr whose `.jaxpr` is analysed for forwarding.
    out_shardings: per-output shardings, aligned with the jaxpr outputs.
    out_layouts: per-output layouts, aligned with the jaxpr outputs.

  Returns:
    A tuple `(jaxpr, in_fwd, out_shardings, out_layouts)` where `jaxpr` has
    been pruned with `pe.prune_closed_jaxpr_outputs` to the kept outputs,
    `in_fwd` has one entry per original output (the forwarding candidate, or
    None when the output is kept), and the sharding/layout lists contain only
    the entries of the kept outputs.
  """
  # NOTE(review): presumably each non-None candidate is the index of the input
  # that the output returns unchanged -- callers index their inputs with it.
  candidates: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr)

  in_fwd = []
  for candidate, sharding, layout in zip(candidates, out_shardings,
                                         out_layouts):
    unconstrained = is_unspecified(sharding) and layout is None
    in_fwd.append(candidate if unconstrained else None)

  # An output stays on the jaxpr exactly when it is not being forwarded.
  keep = [fwd is None for fwd in in_fwd]
  pruned_jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep)
  kept_shardings = [s for s, kept in zip(out_shardings, keep) if kept]
  kept_layouts = [l for l, kept in zip(out_layouts, keep) if kept]
  return pruned_jaxpr, in_fwd, kept_shardings, kept_layouts
|
||||
|
||||
def pjit_forwarding_rule(eqn):
  """Forwarding rule for a pjit equation.

  Args:
    eqn: the pjit equation; its `jaxpr`, `out_shardings` and `out_layouts`
      params are passed to `_pjit_forwarding`.

  Returns:
    A pair `(fwd_vars, new_eqn)`. `fwd_vars` has one entry per original
    output: the equation input that the output forwards, or None when the
    output is still produced by the equation. `new_eqn` is `eqn` with the
    forwarded outputs removed from its outvars, jaxpr, out_shardings and
    out_layouts. It is never None, even when every output is forwarded.
  """
  params = eqn.params
  jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
      params['jaxpr'], params['out_shardings'], params['out_layouts'])

  # Only the outputs that are not forwarded remain on the rewritten equation.
  remaining_outvars = [
      outvar for outvar, fwd in zip(eqn.outvars, in_fwd) if fwd is None]
  # The equation params hold shardings/layouts as tuples, not lists.
  updated_params = dict(
      params,
      jaxpr=jaxpr,
      out_shardings=tuple(out_shardings),
      out_layouts=tuple(out_layouts))
  new_eqn = eqn.replace(params=updated_params, outvars=remaining_outvars)

  fwd_vars = [None if fwd is None else eqn.invars[fwd] for fwd in in_fwd]
  return fwd_vars, new_eqn
|
||||
pe.forwarding_rules[pjit_p] = pjit_forwarding_rule
|
||||
|
||||
|
||||
# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
|
||||
# since it's actually not possible in general to infer the type from the term
|
||||
def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:
|
||||
|
@ -4677,6 +4677,37 @@ class APITest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (list(jnp.arange(float(num_args))),), order=1,
|
||||
modes=['rev'], atol=1e-3, rtol=1e-3)
|
||||
|
||||
@jtu.run_on_devices("cpu")
def test_inner_jit_forwarding_happens(self):
  """The outer jaxpr's only output is the literal 3 passed to an identity jit."""
  identity = jax.jit(lambda x: x)
  closed_jaxpr = jax.make_jaxpr(lambda: identity(3))()
  outvars = closed_jaxpr.jaxpr.outvars
  self.assertLen(outvars, 1)
  out, = outvars
  # The output must be the input literal itself, not a variable bound by an
  # equation.
  self.assertIsInstance(out, core.Literal)
  self.assertEqual(out.val, 3)
|
||||
|
||||
@parameterized.parameters(range(8))
@jtu.run_on_devices("cpu")
def test_inner_jit_forwarding_correctness(self, num_input_fwd):
  """A jitted function and a jit of it agree when some inputs are forwarded.

  `f` shuffles its 8 inputs, returns the first `num_input_fwd` of them
  unchanged together with `sin` of some inputs, and shuffles the outputs.
  The result of `f` is compared element-wise with the result of `jax.jit(f)`.
  """
  num_args = 8
  rng = np.random.RandomState(0)
  # Draw both permutations once, outside of `f`. `rng` is stateful, so calling
  # it inside the traced function would yield different permutations (and hence
  # a different function) every time `f`'s Python body is re-executed.
  in_perm = rng.permutation(num_args)
  out_perm = rng.permutation(num_args)

  @jax.jit
  def f(inputs):
    inputs = [inputs[i] for i in in_perm]
    outputs = (inputs[:num_input_fwd] +
               [jnp.sin(inputs[i]) for i in range(num_args - num_input_fwd)])
    return [outputs[i] for i in out_perm]

  f2 = jax.jit(f)
  inputs = list(jnp.arange(float(num_args)))
  expected = f(inputs)
  ans = f2(inputs)
  # zip() below silently truncates, so check the lengths explicitly first.
  self.assertLen(ans, len(expected))
  for a, b in zip(ans, expected):
    self.assertAllClose(a, b)
|
||||
|
||||
def test_inner_jit_forwarded_consts_stay_const(self):
  """`int()` of an inner identity jit's result works inside an outer jit."""
  identity = jax.jit(lambda x: x)

  def outer():
    # NOTE(review): presumably int() only succeeds here because the forwarded
    # constant 3 is handed back as a concrete value -- confirm.
    return int(identity(3))  # don't crash

  out = jax.jit(outer)()
  self.assertEqual(out, 3)
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -232,7 +232,7 @@ class DynamicShapeStagingTest(jtu.JaxTestCase):
|
||||
def f(n):
|
||||
m = 2 * n
|
||||
x = jnp.zeros(m)
|
||||
return jax.jit(lambda: x)()
|
||||
return jax.jit(jnp.sin)(x)
|
||||
|
||||
# { lambda ; a:i32[]. let
|
||||
# b:i32[] = mul a 2
|
||||
@ -517,7 +517,7 @@ class DynamicShapeStagingTest(jtu.JaxTestCase):
|
||||
self.assertIs(e.aval.shape[0], d)
|
||||
|
||||
def test_jit_abstracted_axes_return_polymorphic_shape(self):
|
||||
f = jax.jit(lambda x: x, abstracted_axes=('n',))
|
||||
f = jax.jit(lambda x: jnp.sin(x), abstracted_axes=('n',))
|
||||
jaxpr = jax.make_jaxpr(f)(jnp.arange(3)) # doesn't crash
|
||||
# { lambda ; a:i32[3]. let
|
||||
# b:i32[3] = pjit[
|
||||
|
@ -61,8 +61,8 @@ class JaxprStatsTest(jtu.JaxTestCase):
|
||||
def test_primitives_by_shape(self):
|
||||
def f(x, y):
|
||||
def sub(x, y):
|
||||
return jnp.sum(jnp.array([x, y])), y
|
||||
s, _ = jit(sub)(x, y)
|
||||
return jnp.sum(jnp.array([x, y]))
|
||||
s = jit(sub)(x, y)
|
||||
return jnp.sin(s) + jnp.cos(y)
|
||||
|
||||
hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)
|
||||
@ -74,7 +74,7 @@ class JaxprStatsTest(jtu.JaxTestCase):
|
||||
f'cos :: float{t}[]',
|
||||
f'reduce_sum :: float{t}[]',
|
||||
f'concatenate :: float{t}[2]',
|
||||
f'pjit :: float{t}[] *',
|
||||
f'pjit :: float{t}[]',
|
||||
]
|
||||
for k in shapes:
|
||||
self.assertEqual(hist[k], 1)
|
||||
|
@ -1282,22 +1282,22 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
y = jnp.array([4.2, 2.4], dtype=jnp.float32)
|
||||
jaxpr = jax.make_jaxpr(g)(x, y)
|
||||
self.assertEqual(
|
||||
jaxpr.pretty_print(),
|
||||
jaxpr.pretty_print(use_color=False),
|
||||
textwrap.dedent("""
|
||||
let f = { lambda ; a:f32[1]. let in (a,) } in
|
||||
let f1 = { lambda ; b:f32[2]. let in (b,) } in
|
||||
let f = { lambda ; a:f32[1]. let in () } in
|
||||
let f1 = { lambda ; b:f32[2]. let in () } in
|
||||
{ lambda ; c:f32[1] d:f32[2]. let
|
||||
e:f32[2] = pjit[
|
||||
name=g
|
||||
jaxpr={ lambda ; g:f32[1] h:f32[2]. let
|
||||
i:f32[1] = pjit[name=f jaxpr=f] g
|
||||
j:f32[1] = pjit[name=f jaxpr=f] g
|
||||
k:f32[1] = mul i j
|
||||
l:f32[2] = pjit[name=f jaxpr=f1] h
|
||||
m:f32[2] = pjit[name=f jaxpr=f1] h
|
||||
n:f32[2] = mul l m
|
||||
o:f32[2] = add k n
|
||||
in (o,) }
|
||||
pjit[name=f jaxpr=f] g
|
||||
pjit[name=f jaxpr=f] g
|
||||
i:f32[1] = mul g g
|
||||
pjit[name=f jaxpr=f1] h
|
||||
pjit[name=f jaxpr=f1] h
|
||||
j:f32[2] = mul h h
|
||||
k:f32[2] = add i j
|
||||
in (k,) }
|
||||
] c d
|
||||
in (e,) }
|
||||
""").strip(),
|
||||
|
@ -1353,13 +1353,13 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
|
||||
def g(x):
|
||||
return jax.jit(lambda x: x)(x)
|
||||
return jax.jit(lambda x: 1. * x)(x)
|
||||
|
||||
jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.)
|
||||
e, = jaxpr.jaxpr.eqns
|
||||
e1, e2 = e.params['jaxpr'].eqns
|
||||
self.assertEmpty(e1.outvars)
|
||||
self.assertEmpty(e2.params['jaxpr'].eqns)
|
||||
self.assertLen(e2.params['jaxpr'].eqns, 1)
|
||||
|
||||
def test_fanout_specs_transpose_to_psum(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
|
Loading…
x
Reference in New Issue
Block a user