diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 64c4c0009..23122784f 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -1679,7 +1679,11 @@ class DynamicJaxprTracer(core.Tracer):
     self.aval = aval
 
   def full_lower(self):
-    return self
+    var = self._trace.frame.tracer_to_var.get(id(self))
+    if var is None: return self
+    val = self._trace.frame.constvar_to_val.get(var)
+    if val is None: return self
+    return core.full_lower(val)
 
   def _contents(self):
     return ()
@@ -1874,7 +1878,6 @@ def _const_folding_and_forwarding(
     # if the application trivially maps some inputs to outputs, simplify
     if eqn.primitive in forwarding_rules and not has_input_effect:
       fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn)
-      assert (new_eqn is None) == all(v is not None for v in fwd_vars)
       for v_orig, v_new in zip(eqn.outvars, fwd_vars):
         if v_new is not None: var_subs[v_orig] = v_new
       if new_eqn is None: continue
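A minimal sketch of what the `full_lower` change enables, mirroring the `test_inner_jit_forwarded_consts_stay_const` test added below: once an inner jit's output is forwarded as a trace-time constant, `full_lower` recovers the concrete Python value, so calling `int(...)` on it during tracing succeeds.

```python
import jax

# The inner jit's output is just its (constant) input, so it is forwarded
# as the constant 3; full_lower then yields the concrete value, letting
# int() run at trace time instead of raising a tracer error.
out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))()
assert out == 3
```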
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index c3d5022bf..1a30df3fc 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1578,24 +1578,30 @@ def _pjit_lower_cached(
 
 
 def pjit_staging_rule(trace, *args, **params):
+  jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
+      params['jaxpr'], params['out_shardings'], params['out_layouts'])
+  params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
+                out_layouts=out_layouts)
+
   if (params["inline"] and
       all(is_unspecified(i) for i in params["in_shardings"]) and
       all(is_unspecified(o) for o in params["out_shardings"]) and
       all(i is None for i in params["in_layouts"]) and
       all(o is None for o in params["out_layouts"])):
-    jaxpr = params['jaxpr']
+
     if config.dynamic_shapes.value:
       # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
       # shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
       # but redundantly performs abstract evaluation again.
-      return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
-                             propagate_source_info=False)
+      out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
+                                    propagate_source_info=False)
     else:
-      return pe.inline_jaxpr_into_trace(trace, jaxpr.jaxpr, jaxpr.consts, *args)
+      out_tracers = pe.inline_jaxpr_into_trace(
+          trace, jaxpr.jaxpr, jaxpr.consts, *args)
   elif config.dynamic_shapes.value:
     source_info = source_info_util.current()
     out_tracers = []
-    for aval in _out_type(params['jaxpr']):
+    for aval in _out_type(jaxpr):
       if type(aval) is core.DShapedArray:
         shape = [args[d.val] if type(d) is core.InDBIdx else
                  out_tracers[d.val] if type(d) is core.OutDBIdx else
@@ -1604,23 +1610,51 @@ def pjit_staging_rule(trace, *args, **params):
       out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info))
     eqn = core.new_jaxpr_eqn(
       map(trace.getvar, args), map(trace.makevar, out_tracers),
       pjit_p, params,
-      params['jaxpr'].effects, source_info)
+      jaxpr.effects, source_info)
     trace.frame.add_eqn(eqn)
-    return out_tracers
-  elif any(isinstance(c, core.MutableArray) for c in params['jaxpr'].consts):
-    jaxpr, consts = pxla._move_mutable_consts(params['jaxpr'])
+  elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
+    jaxpr, consts = pxla._move_mutable_consts(jaxpr)
     consts = map(trace.instantiate_const, consts)
     in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
     in_layouts = (*params['in_layouts'],) + (None,) * len(consts)
     donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
     new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
                       in_layouts=in_layouts, donated_invars=donated_invars)
-    return trace.default_process_primitive(pjit_p, (*args, *consts), new_params)
+    out_tracers = trace.default_process_primitive(
+        pjit_p, (*args, *consts), new_params)
   else:
-    return trace.default_process_primitive(pjit_p, args, params)
+    out_tracers = trace.default_process_primitive(pjit_p, args, params)
+
+  out_tracers_ = iter(out_tracers)
+  out_tracers = [args[f] if type(f) is int else next(out_tracers_)
+                 for f in in_fwd]
+  assert next(out_tracers_, None) is None
+  return out_tracers
 pe.custom_staging_rules[pjit_p] = pjit_staging_rule
 
+def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
+  in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr)
+  in_fwd = [fwd if is_unspecified(os) and ol is None else None for fwd, os, ol
+            in zip(in_fwd, out_shardings, out_layouts)]
+  keep = [f is None for f in in_fwd]
+  jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep)
+  out_shardings = [o for o, k in zip(out_shardings, keep) if k]
+  out_layouts = [o for o, k in zip(out_layouts, keep) if k]
+  return jaxpr, in_fwd, out_shardings, out_layouts
+
+def pjit_forwarding_rule(eqn):
+  jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
+      eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts'])
+  new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None]
+  new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=(*out_shardings,),
+                    out_layouts=(*out_layouts,))
+  new_eqn = eqn.replace(params=new_params, outvars=new_outvars)
+  fwd_vars = [eqn.invars[f] if f is not None else None for f in in_fwd]
+  return fwd_vars, new_eqn
+pe.forwarding_rules[pjit_p] = pjit_forwarding_rule
+
+
 # TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
 # since it's actually not possible in general to infer the type from the term
 def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:
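The new `_pjit_forwarding` helper marks each pjit output that is just a forwarded input (and carries no explicit sharding or layout), prunes those outputs from the closed jaxpr, and `pjit_staging_rule` splices the corresponding arguments back into the returned tracers. A rough sketch of the observable effect, assuming the behavior this patch adds:

```python
import jax
import jax.numpy as jnp

def inner(x):
  return x, jnp.sin(x)  # first output is just the input, so it forwards

jaxpr = jax.make_jaxpr(lambda x: jax.jit(inner)(x))(1.)
eqn, = jaxpr.jaxpr.eqns  # the staged pjit equation
# Only the sin output survives pruning; the identity output is replaced
# by the outer argument itself rather than appearing in the inner jaxpr.
assert len(eqn.params['jaxpr'].jaxpr.outvars) == 1
```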
diff --git a/tests/api_test.py b/tests/api_test.py
index 9810e5008..71dba9535 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -4677,6 +4677,37 @@ class APITest(jtu.JaxTestCase):
     jtu.check_grads(f, (list(jnp.arange(float(num_args))),), order=1,
                     modes=['rev'], atol=1e-3, rtol=1e-3)
 
+  @jtu.run_on_devices("cpu")
+  def test_inner_jit_forwarding_happens(self):
+    jaxpr = jax.make_jaxpr(lambda: jax.jit(lambda x: x)(3))()
+    self.assertLen(jaxpr.jaxpr.outvars, 1)
+    self.assertIsInstance(jaxpr.jaxpr.outvars[0], core.Literal)
+    self.assertEqual(jaxpr.jaxpr.outvars[0].val, 3)
+
+  @parameterized.parameters(range(8))
+  @jtu.run_on_devices("cpu")
+  def test_inner_jit_forwarding_correctness(self, num_input_fwd):
+    num_args = 8
+    rng = np.random.RandomState(0)
+
+    @jax.jit
+    def f(inputs):
+      inputs = [inputs[i] for i in rng.permutation(num_args)]
+      outputs = (inputs[:num_input_fwd] +
+                 [jnp.sin(inputs[i]) for i in range(num_args - num_input_fwd)])
+      return [outputs[i] for i in rng.permutation(num_args)]
+
+    f2 = jax.jit(f)
+    inputs = list(jnp.arange(float(num_args)))
+    expected = f(inputs)
+    ans = f2(inputs)
+    for a, b in zip(ans, expected):
+      self.assertAllClose(a, b)
+
+  def test_inner_jit_forwarded_consts_stay_const(self):
+    out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))()  # don't crash
+    self.assertEqual(out, 3)
+
 
 class RematTest(jtu.JaxTestCase):
 
diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py
index 9d5e9f671..e75b7c4c9 100644
--- a/tests/dynamic_api_test.py
+++ b/tests/dynamic_api_test.py
@@ -232,7 +232,7 @@ class DynamicShapeStagingTest(jtu.JaxTestCase):
     def f(n):
       m = 2 * n
       x = jnp.zeros(m)
-      return jax.jit(lambda: x)()
+      return jax.jit(jnp.sin)(x)
 
     # { lambda ; a:i32[]. let
     #     b:i32[] = mul a 2
@@ -517,7 +517,7 @@ class DynamicShapeStagingTest(jtu.JaxTestCase):
     self.assertIs(e.aval.shape[0], d)
 
   def test_jit_abstracted_axes_return_polymorphic_shape(self):
-    f = jax.jit(lambda x: x, abstracted_axes=('n',))
+    f = jax.jit(lambda x: jnp.sin(x), abstracted_axes=('n',))
     jaxpr = jax.make_jaxpr(f)(jnp.arange(3))  # doesn't crash
     # { lambda ; a:i32[3]. let
     #     b:i32[3] = pjit[
diff --git a/tests/jaxpr_util_test.py b/tests/jaxpr_util_test.py
index a8df5a325..4597ce6bd 100644
--- a/tests/jaxpr_util_test.py
+++ b/tests/jaxpr_util_test.py
@@ -61,8 +61,8 @@ class JaxprStatsTest(jtu.JaxTestCase):
   def test_primitives_by_shape(self):
     def f(x, y):
       def sub(x, y):
-        return jnp.sum(jnp.array([x, y])), y
-      s, _ = jit(sub)(x, y)
+        return jnp.sum(jnp.array([x, y]))
+      s = jit(sub)(x, y)
       return jnp.sin(s) + jnp.cos(y)
 
     hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)
@@ -74,7 +74,7 @@ class JaxprStatsTest(jtu.JaxTestCase):
       f'cos :: float{t}[]',
       f'reduce_sum :: float{t}[]',
       f'concatenate :: float{t}[2]',
-      f'pjit :: float{t}[] *',
+      f'pjit :: float{t}[]',
     ]
     for k in shapes:
       self.assertEqual(hist[k], 1)
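The test updates above all follow from the same behavior change: a jit of an identity function now forwards every output, so the staged pjit equation and its inner jaxpr can end up empty, and tests that need an equation to inspect switch to `jnp.sin` instead. A sketch of the difference, under the same assumption about this patch's forwarding:

```python
import jax
import jax.numpy as jnp

identity = jax.make_jaxpr(lambda x: jax.jit(lambda x: x)(x))(1.)
eqn, = identity.jaxpr.eqns
assert not eqn.outvars                        # every pjit output forwarded
assert not eqn.params['jaxpr'].jaxpr.outvars  # inner jaxpr fully pruned

sine = jax.make_jaxpr(lambda x: jax.jit(jnp.sin)(x))(1.)
eqn, = sine.jaxpr.eqns
assert len(eqn.outvars) == 1                  # sin output cannot forward
```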
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 208e272a1..0a5820829 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1282,22 +1282,22 @@ class PJitTest(jtu.BufferDonationTestCase):
     y = jnp.array([4.2, 2.4], dtype=jnp.float32)
     jaxpr = jax.make_jaxpr(g)(x, y)
     self.assertEqual(
-        jaxpr.pretty_print(),
+        jaxpr.pretty_print(use_color=False),
         textwrap.dedent("""
-            let f = { lambda ; a:f32[1]. let  in (a,) } in
-            let f1 = { lambda ; b:f32[2]. let  in (b,) } in
+            let f = { lambda ; a:f32[1]. let  in () } in
+            let f1 = { lambda ; b:f32[2]. let  in () } in
             { lambda ; c:f32[1] d:f32[2]. let
                 e:f32[2] = pjit[
                   name=g
                   jaxpr={ lambda ; g:f32[1] h:f32[2]. let
-                      i:f32[1] = pjit[name=f jaxpr=f] g
-                      j:f32[1] = pjit[name=f jaxpr=f] g
-                      k:f32[1] = mul i j
-                      l:f32[2] = pjit[name=f jaxpr=f1] h
-                      m:f32[2] = pjit[name=f jaxpr=f1] h
-                      n:f32[2] = mul l m
-                      o:f32[2] = add k n
-                    in (o,) }
+                      pjit[name=f jaxpr=f] g
+                      pjit[name=f jaxpr=f] g
+                      i:f32[1] = mul g g
+                      pjit[name=f jaxpr=f1] h
+                      pjit[name=f jaxpr=f1] h
+                      j:f32[2] = mul h h
+                      k:f32[2] = add i j
+                    in (k,) }
                 ] c d
               in (e,) }
           """).strip(),
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index b02cd17e7..9b3076434 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -1353,13 +1353,13 @@ class ShardMapTest(jtu.JaxTestCase):
 
     @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
     def g(x):
-      return jax.jit(lambda x: x)(x)
+      return jax.jit(lambda x: 1. * x)(x)
 
     jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.)
     e, = jaxpr.jaxpr.eqns
     e1, e2 = e.params['jaxpr'].eqns
     self.assertEmpty(e1.outvars)
-    self.assertEmpty(e2.params['jaxpr'].eqns)
+    self.assertLen(e2.params['jaxpr'].eqns, 1)
 
   def test_fanout_specs_transpose_to_psum(self):
     mesh = jtu.create_global_mesh((4,), ('x',))
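The shard_map change is the same story under autodiff: with an identity inner jit, the transposed inner jaxpr would now be forwarded away entirely, so the test multiplies by `1.` to keep one `mul` equation to assert on. Forwarding also composes through nested jits; a small sketch, again assuming this patch's behavior (the `is` identity check in particular is an assumption about how `make_jaxpr` binds forwarded outputs):

```python
import jax

# An outer jit of an inner identity jit is itself an identity at the
# jaxpr level: the inner jit's forwarded output is forwarded again by
# the outer jit, leaving only zero-output pjit equations behind.
jaxpr = jax.make_jaxpr(jax.jit(jax.jit(lambda x: x)))(1.)
assert jaxpr.jaxpr.outvars[0] is jaxpr.jaxpr.invars[0]
```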