mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add support for `inline` to `pjit`. This is to merge the `jit` and `pjit` frontend APIs.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 495726005
This commit is contained in:
parent
a99e3e7079
commit
3b9088f9a3
@ -1115,7 +1115,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name,
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
keep_unused):
|
||||
keep_unused, inline):
|
||||
checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
|
||||
enabled_errors)
|
||||
out_error = error._add_placeholder_effects(effects)
|
||||
@ -1155,7 +1155,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
name=name,
|
||||
in_positional_semantics=new_positional_sems_in,
|
||||
out_positional_semantics=new_positional_sems_out,
|
||||
keep_unused=keep_unused)
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
err, *out = tree_unflatten(out_tree, err_and_out)
|
||||
return out, err
|
||||
error_checks[pjit.pjit_p] = pjit_error_check
|
||||
|
@ -3119,6 +3119,7 @@ def _pjit(*args: TfVal,
|
||||
in_positional_semantics,
|
||||
out_positional_semantics,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: Sequence[core.ShapedArray]) -> TfVal:
|
||||
del donated_invars
|
||||
|
@ -185,6 +185,7 @@ def pjit(
|
||||
keep_unused: bool = False,
|
||||
device: Optional[xc.Device] = None,
|
||||
backend: Optional[str] = None,
|
||||
inline: bool = False,
|
||||
) -> stages.Wrapped:
|
||||
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
|
||||
|
||||
@ -474,6 +475,7 @@ def pjit(
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
)
|
||||
return (args_flat, local_in_avals, params, in_tree, out_tree(),
|
||||
donate_argnums)
|
||||
@ -1007,7 +1009,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name,
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
keep_unused):
|
||||
keep_unused, inline):
|
||||
|
||||
global _most_recent_pjit_call_executable
|
||||
|
||||
@ -1144,6 +1146,18 @@ def _pjit_lower_cached(
|
||||
devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))
|
||||
|
||||
|
||||
def pjit_staging_rule(trace, *args, **params):
  """Staging rule for ``pjit_p`` that optionally inlines the wrapped jaxpr.

  Args:
    trace: the trace being staged into; used for the non-inlined fallback.
    *args: operands of the ``pjit_p`` application.
    **params: the primitive's parameters. This rule reads ``inline``,
      ``in_shardings``, ``out_shardings`` and ``jaxpr``.

  Returns:
    The result of evaluating the inner jaxpr directly on ``args`` when
    inlining applies, otherwise whatever
    ``trace.default_process_primitive`` returns for ``pjit_p``.
  """
  # Inlining is only taken when it was requested *and* every input and
  # output sharding passes `_is_unspecified`.
  can_inline = (
      params["inline"]
      and all(_is_unspecified(s) for s in params["in_shardings"])
      and all(_is_unspecified(s) for s in params["out_shardings"])
  )
  if not can_inline:
    return trace.default_process_primitive(pjit_p, args, params)

  closed_jaxpr = params['jaxpr']
  return core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
|
||||
|
||||
pe.custom_staging_rules[pjit_p] = pjit_staging_rule
|
||||
|
||||
|
||||
def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
|
||||
out_positional_semantics, **_):
|
||||
if jaxpr.effects:
|
||||
@ -1156,7 +1170,7 @@ pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
|
||||
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
|
||||
out_shardings, resource_env, donated_invars,
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
keep_unused):
|
||||
keep_unused, inline):
|
||||
if not isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext)):
|
||||
raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
|
||||
@ -1192,7 +1206,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
vals_in, dims_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused):
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
# batch_jaxpr expects all batching dimensions to be equal to 0
|
||||
vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
|
||||
else x for x, d in zip(vals_in, dims_in)]
|
||||
@ -1221,7 +1235,8 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
name=name,
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused)
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
|
||||
return vals_out, dims_out
|
||||
batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
|
||||
@ -1251,7 +1266,7 @@ def _pjit_batcher_for_sharding(
|
||||
def _pjit_jvp(primals_in, tangents_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused):
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
|
||||
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
|
||||
jaxpr, is_nz_tangents_in, instantiate=False)
|
||||
@ -1270,7 +1285,8 @@ def _pjit_jvp(primals_in, tangents_in,
|
||||
name=wrap_name(name, 'jvp'),
|
||||
in_positional_semantics=(*in_positional_semantics, *_filter_zeros_in(in_positional_semantics)),
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused)
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
|
||||
primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
|
||||
assert len(primals_out) == len(jaxpr.jaxpr.outvars)
|
||||
@ -1283,7 +1299,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
|
||||
def _pjit_partial_eval(trace, *in_tracers,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused):
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
in_pvals = [t.pval for t in in_tracers]
|
||||
|
||||
known_ins = tuple(pv.is_known() for pv in in_pvals)
|
||||
@ -1313,7 +1329,8 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
name=name,
|
||||
in_positional_semantics=keep_where(in_positional_semantics, known_ins),
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused)
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
|
||||
if num_residuals:
|
||||
in_is_global = _calc_is_global_sequence(
|
||||
@ -1370,7 +1387,8 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
in_positional_semantics=(keep_where(
|
||||
in_positional_semantics, unknown_ins) + (out_positional_semantics,) * num_residuals),
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused)
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
|
||||
unknown_tracers_out = [
|
||||
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
|
||||
@ -1393,7 +1411,7 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
|
||||
def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused):
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
def prune_type(ty, xs, maybe_zeros):
|
||||
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
|
||||
|
||||
@ -1438,7 +1456,8 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
name=name,
|
||||
in_positional_semantics=transpose_in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused)
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
return tree_unflatten(cts_out_treedef, nz_cts_out)
|
||||
ad.reducing_transposes[pjit_p] = _pjit_transpose
|
||||
|
||||
|
@ -2795,20 +2795,29 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
return x, y, z
|
||||
|
||||
o1, o2, o3 = f(a, y=b, z=c)
|
||||
cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertArraysEqual(o1, a)
|
||||
self.assertArraysEqual(o2, b)
|
||||
self.assertArraysEqual(o3, c)
|
||||
|
||||
o4, o5, o6 = f(x=a, y=b, z=c)
|
||||
cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertArraysEqual(o4, a)
|
||||
self.assertArraysEqual(o5, b)
|
||||
self.assertArraysEqual(o6, c)
|
||||
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits)
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses + 1)
|
||||
|
||||
o7, o8, o9 = f(a, b, c)
|
||||
cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertArraysEqual(o7, a)
|
||||
self.assertArraysEqual(o8, b)
|
||||
self.assertArraysEqual(o9, c)
|
||||
|
||||
self.assertEqual(cache_info3.hits, cache_info2.hits)
|
||||
self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
|
||||
|
||||
def test_pjit_kwargs_axis_resources_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -2961,6 +2970,21 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
"device is also specified as an argument to jit."):
|
||||
pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
|
||||
|
||||
def test_pjit_inline(self):
  """Check that ``inline`` decides whether 'pjit' shows up in a printed jaxpr."""
  def f(x):
    return x * 2

  # inline=False: the printed jaxpr is expected to mention 'pjit'.
  f = pjit(f, inline=False)
  not_inlined_text = str(jax.make_jaxpr(f)(3))
  self.assertIn('pjit', not_inlined_text)

  def g(x):
    return x * 2

  # inline=True: the printed jaxpr is expected not to mention 'pjit'.
  g = pjit(g, inline=True)
  inlined_text = str(jax.make_jaxpr(g)(3))
  self.assertNotIn('pjit', inlined_text)
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user