[Rollback] Add keep_unused to pjit's API as a step toward merging the jit and pjit frontend APIs.

PiperOrigin-RevId: 495179106
Yash Katariya 2022-12-13 18:30:33 -08:00 committed by jax authors
parent 08b6b0dd43
commit 82ca823956
4 changed files with 25 additions and 76 deletions
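For context on what is being rolled back: keep_unused controls whether arguments that the traced function never reads may be pruned from the compiled executable. The sketch below is illustrative only and mirrors the tests deleted at the end of this diff; it uses jax.jit, which exposes the same parameter (the commit's stated motivation is aligning the jit and pjit frontends), and _executable._kept_var_idx is a private attribute whose exact form may differ across JAX versions.

# Illustrative sketch only: mirrors the deleted test_pjit_keep_unused_* tests
# using jax.jit. `_executable._kept_var_idx` is private and assumed here.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, keep_unused=True)
def f_keep(x, y):
  return y  # x is never used

@jax.jit
def f_prune(x, y):
  return y  # x is never used

unused_inp, inp = jnp.arange(8), jnp.arange(4)

# With keep_unused=True both arguments stay in the executable's signature.
kept = f_keep.lower(unused_inp, inp).compile()
print(kept._executable._kept_var_idx)    # expected: {0, 1}

# By default (keep_unused=False) the unused argument may be pruned, so it is
# neither transferred to the device nor passed to the underlying executable.
pruned = f_prune.lower(unused_inp, inp).compile()
print(pruned._executable._kept_var_idx)  # expected: {1}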

View File

@@ -1061,8 +1061,7 @@ error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
in_positional_semantics, out_positional_semantics,
keep_unused):
in_positional_semantics, out_positional_semantics):
checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
enabled_errors)
out_error = error._add_placeholder_effects(effects)
@@ -1101,8 +1100,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
donated_invars=new_donated_invars,
name=name,
in_positional_semantics=new_positional_sems_in,
out_positional_semantics=new_positional_sems_out,
keep_unused=keep_unused)
out_positional_semantics=new_positional_sems_out)
err, *out = tree_unflatten(out_tree, err_and_out)
return out, err
error_checks[pjit.pjit_p] = pjit_error_check
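The hunk above only adjusts the signature of checkify's pjit rule; the user-facing path it serves is running functionalized checks through a pjit-compiled function. A rough sketch of that usage follows, assuming a single-device mesh and import paths appropriate to this JAX era (both may differ in other versions).

# Sketch of the checkify-over-pjit path that pjit_error_check supports.
# Import locations and the Mesh setup are assumptions for illustration.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh

def f(x):
  checkify.check(jnp.all(x > 0), "x must be strictly positive")
  return jnp.log(x)

with Mesh(np.asarray(jax.devices()[:1]), ('x',)):
  checked = checkify.checkify(
      pjit(f, in_axis_resources=None, out_axis_resources=None))
  err, out = checked(jnp.arange(1., 5.))
  err.throw()  # no-op here: every input is positive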

View File

@@ -3118,7 +3118,6 @@ def _pjit(*args: TfVal,
name: str,
in_positional_semantics,
out_positional_semantics,
keep_unused: bool,
_in_avals: Sequence[core.ShapedArray],
_out_aval: Sequence[core.ShapedArray]) -> TfVal:
del donated_invars

View File

@@ -178,7 +178,6 @@ def pjit(
static_argnums: Union[int, Sequence[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
donate_argnums: Union[int, Sequence[int]] = (),
keep_unused: bool = False,
) -> stages.Wrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
@@ -280,15 +279,11 @@ def pjit(
for example recycling one of your input buffers to store a result. You
should not reuse buffers that you donate to a computation, JAX will raise
an error if you try to.
keep_unused: If `False` (the default), arguments that JAX determines to be
unused by `fun` *may* be dropped from resulting compiled XLA executables.
Such arguments will not be transferred to the device nor provided to the
underlying executable. If `True`, unused arguments will not be pruned.
For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation and
automatically partitioned by the mesh available at each call site.
automaticly partitioned by the mesh available at each call site.
For example, a convolution operator can be automatically partitioned over
an arbitrary set of devices by a single :func:`~pjit` application:
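The docstring's own example lives in unchanged context lines and is not part of this hunk. For orientation only, a minimal partitioned computation in the pjit style of this era might look like the sketch below; the device count, the mesh axis name, and the import paths are assumptions, not code from this commit.

# Not the docstring's example; a minimal sketch of era-appropriate pjit usage.
# Assumes at least 2 devices and these import paths for PartitionSpec/Mesh.
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, PartitionSpec as P
from jax.experimental.maps import Mesh

devices = np.asarray(jax.devices()[:2])

@partial(pjit, in_axis_resources=(P('data', None), None),
         out_axis_resources=P('data', None))
def matmul(x, w):
  return x @ w  # rows of x (and of the result) sharded over the 'data' axis

with Mesh(devices, ('data',)):
  x = jnp.ones((8, 4))
  w = jnp.ones((4, 2))
  y = matmul(x, w)
  print(y.shape)  # (8, 2)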
@@ -311,7 +306,7 @@ def pjit(
if not config.jax_array and (_is_unspecified(in_axis_resources) or
_is_unspecified(out_axis_resources)):
raise ValueError(
"in_axis_resources and out_axis_resources should not "
"in_axis_resources and out_axis_resouces should not "
"be the unspecified singleton value. Please enable `jax.Array` to use "
"this feature. You can use jax.config.update('jax_array', True) or "
"set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
@@ -385,7 +380,7 @@ def pjit(
out_shardings = tree_map(
lambda x: x if _is_unspecified(x) else
_create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
# This check fails extremely rarely and has a huge cost in the dispatch
# This check fails extrememly rarely and has a huge cost in the dispatch
# path. So hide it behind the jax_enable_checks flag.
if config.jax_enable_checks:
_maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
@@ -428,9 +423,7 @@ def pjit(
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unnamed function>'),
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused,
)
out_positional_semantics=out_positional_semantics)
return (args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)
@@ -453,7 +446,7 @@ def pjit(
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
in_is_global, params['keep_unused'], always_lower=True)
in_is_global, always_lower=True)
if kwargs:
args_kwargs_in_tree = in_tree
@@ -942,8 +935,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
in_positional_semantics, out_positional_semantics,
keep_unused):
in_positional_semantics, out_positional_semantics):
global _most_recent_pjit_call_executable
@@ -958,8 +950,7 @@ def _pjit_call_impl(*args, jaxpr,
_allow_propagation_to_outputs = False
compiled = _pjit_lower(
jaxpr, in_shardings, out_shardings, resource_env,
donated_invars, name, in_is_global, keep_unused,
always_lower=False).compile(
donated_invars, name, in_is_global, always_lower=False).compile(
_allow_propagation_to_outputs=_allow_propagation_to_outputs)
_most_recent_pjit_call_executable.value = compiled
# This check is expensive so only do it if enable_checks is on.
@@ -1028,7 +1019,6 @@ def _pjit_lower_cached(
donated_invars,
name: str,
in_is_global: Sequence[bool],
keep_unused: bool,
always_lower: bool):
in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
@@ -1075,7 +1065,7 @@ def _pjit_lower_cached(
# the arguments just like dispatch.py in `sharded_lowering`.
return pxla.lower_sharding_computation(
fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
always_lower=always_lower,
devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))
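The hunk above is the behavioural core of the rollback: with the parameter gone from pjit's signature, lower_sharding_computation is now always called with keep_unused=True, so a pjit computation never prunes unused arguments. A hedged way to observe this, mirroring the tests removed at the end of this diff (the single-device mesh and the private _kept_var_idx attribute are assumptions for illustration):

# Sketch: after this rollback, even a default pjit call keeps arguments the
# function never reads. `_executable._kept_var_idx` is a private attribute.
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh

@partial(pjit, in_axis_resources=None, out_axis_resources=None)
def f(unused, y):
  return y

with Mesh(np.asarray(jax.devices()[:1]), ('x',)):
  compiled = f.lower(jnp.arange(8), jnp.arange(4)).compile()
  # Both inputs remain in the executable's signature; nothing is pruned.
  print(compiled._executable._kept_var_idx)  # expected: {0, 1}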
@@ -1091,8 +1081,7 @@ pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, resource_env, donated_invars,
in_positional_semantics, out_positional_semantics,
keep_unused):
in_positional_semantics, out_positional_semantics):
if not isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext)):
raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
@@ -1128,7 +1117,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
vals_in, dims_in,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused):
out_positional_semantics):
# batch_jaxpr expects all batching dimensions to be equal to 0
vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(vals_in, dims_in)]
@@ -1156,8 +1145,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
donated_invars=donated_invars,
name=name,
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
out_positional_semantics=out_positional_semantics)
dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
return vals_out, dims_out
batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
@@ -1187,7 +1175,7 @@ def _pjit_batcher_for_sharding(
def _pjit_jvp(primals_in, tangents_in,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused):
out_positional_semantics):
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
jaxpr, is_nz_tangents_in, instantiate=False)
@@ -1205,8 +1193,7 @@ def _pjit_jvp(primals_in, tangents_in,
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
name=wrap_name(name, 'jvp'),
in_positional_semantics=(*in_positional_semantics, *_filter_zeros_in(in_positional_semantics)),
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
out_positional_semantics=out_positional_semantics)
primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
assert len(primals_out) == len(jaxpr.jaxpr.outvars)
@@ -1219,7 +1206,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused):
out_positional_semantics):
in_pvals = [t.pval for t in in_tracers]
known_ins = tuple(pv.is_known() for pv in in_pvals)
@@ -1248,8 +1235,7 @@ def _pjit_partial_eval(trace, *in_tracers,
donated_invars=keep_where(donated_invars, known_ins),
name=name,
in_positional_semantics=keep_where(in_positional_semantics, known_ins),
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
out_positional_semantics=out_positional_semantics)
if num_residuals:
in_is_global = _calc_is_global_sequence(
@@ -1258,7 +1244,7 @@ def _pjit_partial_eval(trace, *in_tracers,
known_params["jaxpr"], known_params["in_shardings"],
known_params["out_shardings"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global, known_params['keep_unused'], always_lower=False).compile(
in_is_global, always_lower=False).compile(
_allow_propagation_to_outputs=True,
_allow_compile_replicated=False)
da = compiled._device_assignment
@@ -1305,8 +1291,7 @@ def _pjit_partial_eval(trace, *in_tracers,
name=name,
in_positional_semantics=(keep_where(
in_positional_semantics, unknown_ins) + (out_positional_semantics,) * num_residuals),
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
out_positional_semantics=out_positional_semantics)
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
unknown_tracers_out = [
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
@@ -1329,7 +1314,7 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
def _pjit_transpose(reduce_axes, cts_in, *primals_in,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics, keep_unused):
out_positional_semantics):
def prune_type(ty, xs, maybe_zeros):
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
@@ -1373,8 +1358,7 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
donated_invars=(False,) * len(primals_and_nz_cts_in),
name=name,
in_positional_semantics=transpose_in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused)
out_positional_semantics=out_positional_semantics)
return tree_unflatten(cts_out_treedef, nz_cts_out)
ad.reducing_transposes[pjit_p] = _pjit_transpose

View File

@@ -728,7 +728,7 @@ class PJitTest(jtu.BufferDonationTestCase):
# execution of the compiled function is blocking, so transferring data
# to infeed before executing ensures that the execution does not deadlock
# waiting for the infeed data.
logging.info('Transferring to infeed for the jit call')
logging.info('Transfering to infeed for the jit call')
d = devices[0]
d.transfer_to_infeed((y,))
d.transfer_to_infeed((z,))
@@ -759,7 +759,7 @@ class PJitTest(jtu.BufferDonationTestCase):
partitions=(P(1, nr_devices),))
return x + y + z + w
logging.info('Transferring to infeed for the pjit call')
logging.info('Transfering to infeed for the pjit call')
for didx, d in enumerate(devices):
# Transfer the whole array to all devices for replicated.
d.transfer_to_infeed((y,))
@@ -801,7 +801,7 @@ class PJitTest(jtu.BufferDonationTestCase):
xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
self.assertAllClose(x, y, check_dtypes=True)
logging.info('Transferring from outfeed for the pjit call')
logging.info('Transfering from outfeed for the pjit call')
for didx, d in enumerate(devices):
# Transfer the whole array from all devices for replicated.
check_outfeed(d, x)
@@ -2635,7 +2635,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
("in_axis_resources and out_axis_resources should not "
("in_axis_resources and out_axis_resouces should not "
"be the unspecified singleton value. Please enable `jax.Array` to use "
"this feature.")):
pjit(lambda x: x)
@@ -2815,38 +2815,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
"pjit does not support kwargs when in_axis_resources is specified."):
pjit(lambda x: x, in_axis_resources=None)(x=jnp.arange(8.))
def test_pjit_keep_unused_true(self):
@partial(pjit, keep_unused=True)
def f(x, y):
return y
inp = jnp.arange(4)
unused_inp = jnp.arange(8)
out = f(unused_inp, inp)
self.assertArraysEqual(out, inp)
compiled = f.lower(unused_inp, inp).compile()
self.assertEqual(compiled._executable._kept_var_idx, {0, 1})
self.assertLen(compiled._executable.in_avals, 2)
with jtu.count_device_put() as count:
_ = f(1, 2)
self.assertEqual(count[0], 1)
def test_pjit_keep_unused_default_false(self):
@pjit
def f(x, y):
return y
inp = jnp.arange(4)
unused_inp = jnp.arange(8)
out = f(unused_inp, inp)
self.assertArraysEqual(out, inp)
compiled = f.lower(unused_inp, inp).compile()
self.assertEqual(compiled._executable._kept_var_idx, {1})
self.assertLen(compiled._executable.in_avals, 1)
class TempSharding(Sharding):