Add support for inline to pjit. This is to merge the jit and pjit frontend API.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 495726005
This commit is contained in:
Yash Katariya 2022-12-15 16:25:45 -08:00 committed by jax authors
parent a99e3e7079
commit 3b9088f9a3
4 changed files with 58 additions and 13 deletions

View File

@ -1115,7 +1115,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
in_positional_semantics, out_positional_semantics,
keep_unused):
keep_unused, inline):
checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
enabled_errors)
out_error = error._add_placeholder_effects(effects)
@ -1155,7 +1155,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
name=name,
in_positional_semantics=new_positional_sems_in,
out_positional_semantics=new_positional_sems_out,
keep_unused=keep_unused)
keep_unused=keep_unused,
inline=inline)
err, *out = tree_unflatten(out_tree, err_and_out)
return out, err
error_checks[pjit.pjit_p] = pjit_error_check

View File

@ -3119,6 +3119,7 @@ def _pjit(*args: TfVal,
in_positional_semantics,
out_positional_semantics,
keep_unused: bool,
inline: bool,
_in_avals: Sequence[core.ShapedArray],
_out_aval: Sequence[core.ShapedArray]) -> TfVal:
del donated_invars

View File

@ -185,6 +185,7 @@ def pjit(
keep_unused: bool = False,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
inline: bool = False,
) -> stages.Wrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
@ -474,6 +475,7 @@ def pjit(
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused,
inline=inline,
)
return (args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)
@ -1007,7 +1009,7 @@ def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
in_positional_semantics, out_positional_semantics,
keep_unused):
keep_unused, inline):
global _most_recent_pjit_call_executable
@ -1144,6 +1146,18 @@ def _pjit_lower_cached(
devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))
def pjit_staging_rule(trace, *args, **params):
if (params["inline"] and
all(_is_unspecified(i) for i in params["in_shardings"]) and
all(_is_unspecified(o) for o in params["out_shardings"])):
jaxpr = params['jaxpr']
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
else:
return trace.default_process_primitive(pjit_p, args, params)
pe.custom_staging_rules[pjit_p] = pjit_staging_rule
def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
out_positional_semantics, **_):
if jaxpr.effects:
@ -1156,7 +1170,7 @@ pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, resource_env, donated_invars,
in_positional_semantics, out_positional_semantics,
keep_unused):
keep_unused, inline):
if not isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext)):
raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
@ -1192,7 +1206,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
vals_in, dims_in,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused):
out_positional_semantics, keep_unused, inline):
# batch_jaxpr expects all batching dimensions to be equal to 0
vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(vals_in, dims_in)]
@ -1221,7 +1235,8 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
name=name,
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
keep_unused=keep_unused,
inline=inline)
dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
return vals_out, dims_out
batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
@ -1251,7 +1266,7 @@ def _pjit_batcher_for_sharding(
def _pjit_jvp(primals_in, tangents_in,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused):
out_positional_semantics, keep_unused, inline):
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
jaxpr, is_nz_tangents_in, instantiate=False)
@ -1270,7 +1285,8 @@ def _pjit_jvp(primals_in, tangents_in,
name=wrap_name(name, 'jvp'),
in_positional_semantics=(*in_positional_semantics, *_filter_zeros_in(in_positional_semantics)),
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
keep_unused=keep_unused,
inline=inline)
primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
assert len(primals_out) == len(jaxpr.jaxpr.outvars)
@ -1283,7 +1299,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused):
out_positional_semantics, keep_unused, inline):
in_pvals = [t.pval for t in in_tracers]
known_ins = tuple(pv.is_known() for pv in in_pvals)
@ -1313,7 +1329,8 @@ def _pjit_partial_eval(trace, *in_tracers,
name=name,
in_positional_semantics=keep_where(in_positional_semantics, known_ins),
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
keep_unused=keep_unused,
inline=inline)
if num_residuals:
in_is_global = _calc_is_global_sequence(
@ -1370,7 +1387,8 @@ def _pjit_partial_eval(trace, *in_tracers,
in_positional_semantics=(keep_where(
in_positional_semantics, unknown_ins) + (out_positional_semantics,) * num_residuals),
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
keep_unused=keep_unused,
inline=inline)
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
unknown_tracers_out = [
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
@ -1393,7 +1411,7 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
def _pjit_transpose(reduce_axes, cts_in, *primals_in,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused):
out_positional_semantics, keep_unused, inline):
def prune_type(ty, xs, maybe_zeros):
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
@ -1438,7 +1456,8 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
name=name,
in_positional_semantics=transpose_in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
keep_unused=keep_unused,
inline=inline)
return tree_unflatten(cts_out_treedef, nz_cts_out)
ad.reducing_transposes[pjit_p] = _pjit_transpose

View File

@ -2795,20 +2795,29 @@ class ArrayPjitTest(jtu.JaxTestCase):
return x, y, z
o1, o2, o3 = f(a, y=b, z=c)
cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
self.assertArraysEqual(o1, a)
self.assertArraysEqual(o2, b)
self.assertArraysEqual(o3, c)
o4, o5, o6 = f(x=a, y=b, z=c)
cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
self.assertArraysEqual(o4, a)
self.assertArraysEqual(o5, b)
self.assertArraysEqual(o6, c)
self.assertEqual(cache_info2.hits, cache_info1.hits)
self.assertEqual(cache_info2.misses, cache_info1.misses + 1)
o7, o8, o9 = f(a, b, c)
cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
self.assertArraysEqual(o7, a)
self.assertArraysEqual(o8, b)
self.assertArraysEqual(o9, c)
self.assertEqual(cache_info3.hits, cache_info2.hits)
self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
def test_pjit_kwargs_axis_resources_error(self):
with self.assertRaisesRegex(
ValueError,
@ -2961,6 +2970,21 @@ class ArrayPjitTest(jtu.JaxTestCase):
"device is also specified as an argument to jit."):
pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
def test_pjit_inline(self):
@partial(pjit, inline=False)
def f(x):
return x * 2
jaxpr = jax.make_jaxpr(f)(3)
self.assertIn('pjit', str(jaxpr))
@partial(pjit, inline=True)
def g(x):
return x * 2
jaxpr = jax.make_jaxpr(g)(3)
self.assertNotIn('pjit', str(jaxpr))
class TempSharding(Sharding):