mirror of https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Rollback] Add keep_unused to pjit's API as a step to merge jit and pjit frontend API.

PiperOrigin-RevId: 495179106
This commit is contained in:
parent 08b6b0dd43
commit 82ca823956
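For context: keep_unused controls whether arguments that the compiled function never uses may be pruned from the XLA executable; pruned arguments are neither transferred to the device nor passed to the executable. Although rolled back here, the parameter was later re-landed and is exposed on jax.jit in current JAX. A minimal sketch of the intended semantics, assuming a recent JAX release:

import jax
import jax.numpy as jnp

def f(x, y):
  return y  # x is never used

# Default keep_unused=False: JAX may drop x from the compiled executable.
f_pruned = jax.jit(f)
# keep_unused=True: x is compiled in and transferred even though unused.
f_kept = jax.jit(f, keep_unused=True)

# Both produce the same result; only executable pruning differs.
assert (f_pruned(jnp.arange(8), jnp.arange(4)) ==
        f_kept(jnp.arange(8), jnp.arange(4))).all()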
@@ -1061,8 +1061,7 @@ error_checks[lax.while_p] = while_loop_error_check
 def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                      in_shardings, out_shardings, resource_env,
                      donated_invars, name,
-                     in_positional_semantics, out_positional_semantics,
-                     keep_unused):
+                     in_positional_semantics, out_positional_semantics):
   checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
                                                       enabled_errors)
   out_error = error._add_placeholder_effects(effects)
@@ -1101,8 +1100,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
       donated_invars=new_donated_invars,
       name=name,
       in_positional_semantics=new_positional_sems_in,
-      out_positional_semantics=new_positional_sems_out,
-      keep_unused=keep_unused)
+      out_positional_semantics=new_positional_sems_out)
   err, *out = tree_unflatten(out_tree, err_and_out)
   return out, err
 error_checks[pjit.pjit_p] = pjit_error_check
@@ -3118,7 +3118,6 @@ def _pjit(*args: TfVal,
           name: str,
           in_positional_semantics,
           out_positional_semantics,
-          keep_unused: bool,
           _in_avals: Sequence[core.ShapedArray],
           _out_aval: Sequence[core.ShapedArray]) -> TfVal:
   del donated_invars
@@ -178,7 +178,6 @@ def pjit(
     static_argnums: Union[int, Sequence[int], None] = None,
     static_argnames: Union[str, Iterable[str], None] = None,
     donate_argnums: Union[int, Sequence[int]] = (),
-    keep_unused: bool = False,
 ) -> stages.Wrapped:
   """Makes ``fun`` compiled and automatically partitioned across multiple devices.

@@ -280,15 +279,11 @@ def pjit(
       for example recycling one of your input buffers to store a result. You
       should not reuse buffers that you donate to a computation, JAX will raise
       an error if you try to.
-    keep_unused: If `False` (the default), arguments that JAX determines to be
-      unused by `fun` *may* be dropped from resulting compiled XLA executables.
-      Such arguments will not be transferred to the device nor provided to the
-      underlying executable. If `True`, unused arguments will not be pruned.

       For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
   Returns:
     A wrapped version of ``fun``, set up for just-in-time compilation and
-    automatically partitioned by the mesh available at each call site.
+    automaticly partitioned by the mesh available at each call site.

   For example, a convolution operator can be automatically partitioned over
   an arbitrary set of devices by a single :func:`~pjit` application:
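The buffer-donation text that survives this hunk describes donate_argnums. A minimal sketch of donation with jax.jit (illustrative function names; donation may be a no-op with a warning on some backends, e.g. CPU):

from functools import partial

import jax
import jax.numpy as jnp

# Donating argument 0 lets XLA reuse x's buffer for the output, for
# example recycling the input buffer to store the result.
@partial(jax.jit, donate_argnums=0)
def update(x, delta):
  return x + delta

x = jnp.ones(1024)
x = update(x, 2.0)  # the original buffer for x must not be reused by the caller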
@@ -311,7 +306,7 @@ def pjit(
   if not config.jax_array and (_is_unspecified(in_axis_resources) or
                                _is_unspecified(out_axis_resources)):
     raise ValueError(
-        "in_axis_resources and out_axis_resources should not "
+        "in_axis_resources and out_axis_resouces should not "
         "be the unspecified singleton value. Please enable `jax.Array` to use "
         "this feature. You can use jax.config.update('jax_array', True) or "
         "set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
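The error text in this hunk doubles as a how-to for the (since-removed) jax_array flag. Purely historical usage, reproduced from the message itself:

import jax

# Either of the documented opt-ins of that era:
jax.config.update('jax_array', True)   # in code
# ...or set the environment variable JAX_ARRAY=1 before starting the process.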
@@ -385,7 +380,7 @@ def pjit(
     out_shardings = tree_map(
         lambda x: x if _is_unspecified(x) else
         _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
-    # This check fails extremely rarely and has a huge cost in the dispatch
+    # This check fails extrememly rarely and has a huge cost in the dispatch
     # path. So hide it behind the jax_enable_checks flag.
     if config.jax_enable_checks:
       _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
@@ -428,9 +423,7 @@ def pjit(
         donated_invars=donated_invars,
         name=getattr(flat_fun, '__name__', '<unnamed function>'),
         in_positional_semantics=in_positional_semantics,
-        out_positional_semantics=out_positional_semantics,
-        keep_unused=keep_unused,
-    )
+        out_positional_semantics=out_positional_semantics)
     return (args_flat, local_in_avals, params, in_tree, out_tree(),
             donate_argnums)

@@ -453,7 +446,7 @@ def pjit(
     lowering = _pjit_lower(
         params['jaxpr'], in_shardings, params['out_shardings'],
         params['resource_env'], params['donated_invars'], params['name'],
-        in_is_global, params['keep_unused'], always_lower=True)
+        in_is_global, always_lower=True)

     if kwargs:
       args_kwargs_in_tree = in_tree
@@ -942,8 +935,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
 def _pjit_call_impl(*args, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
-                    in_positional_semantics, out_positional_semantics,
-                    keep_unused):
+                    in_positional_semantics, out_positional_semantics):

   global _most_recent_pjit_call_executable

@@ -958,8 +950,7 @@ def _pjit_call_impl(*args, jaxpr,
     _allow_propagation_to_outputs = False
   compiled = _pjit_lower(
       jaxpr, in_shardings, out_shardings, resource_env,
-      donated_invars, name, in_is_global, keep_unused,
-      always_lower=False).compile(
+      donated_invars, name, in_is_global, always_lower=False).compile(
           _allow_propagation_to_outputs=_allow_propagation_to_outputs)
   _most_recent_pjit_call_executable.value = compiled
   # This check is expensive so only do it if enable_checks is on.
@@ -1028,7 +1019,6 @@ def _pjit_lower_cached(
     donated_invars,
     name: str,
     in_is_global: Sequence[bool],
-    keep_unused: bool,
     always_lower: bool):
   in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
       Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
@@ -1075,7 +1065,7 @@ def _pjit_lower_cached(
     # the arguments just like dispatch.py in `sharded_lowering`.
     return pxla.lower_sharding_computation(
         fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
-        jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
+        jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
         always_lower=always_lower,
         devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))

@@ -1091,8 +1081,7 @@ pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)

 def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, resource_env, donated_invars,
-                   in_positional_semantics, out_positional_semantics,
-                   keep_unused):
+                   in_positional_semantics, out_positional_semantics):
   if not isinstance(ctx.module_context.axis_context,
                     (mlir.SPMDAxisContext, mlir.ShardingContext)):
     raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
@@ -1128,7 +1117,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
                   vals_in, dims_in,
                   jaxpr, in_shardings, out_shardings,
                   resource_env, donated_invars, name, in_positional_semantics,
-                  out_positional_semantics, keep_unused):
+                  out_positional_semantics):
   # batch_jaxpr expects all batching dimensions to be equal to 0
   vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
              else x for x, d in zip(vals_in, dims_in)]
@@ -1156,8 +1145,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
       donated_invars=donated_invars,
       name=name,
       in_positional_semantics=in_positional_semantics,
-      out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused)
+      out_positional_semantics=out_positional_semantics)
   dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
   return vals_out, dims_out
 batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
@@ -1187,7 +1175,7 @@ def _pjit_batcher_for_sharding(
 def _pjit_jvp(primals_in, tangents_in,
               jaxpr, in_shardings, out_shardings,
               resource_env, donated_invars, name, in_positional_semantics,
-              out_positional_semantics, keep_unused):
+              out_positional_semantics):
   is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
   jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
       jaxpr, is_nz_tangents_in, instantiate=False)
@@ -1205,8 +1193,7 @@ def _pjit_jvp(primals_in, tangents_in,
       donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
       name=wrap_name(name, 'jvp'),
       in_positional_semantics=(*in_positional_semantics, *_filter_zeros_in(in_positional_semantics)),
-      out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused)
+      out_positional_semantics=out_positional_semantics)

   primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
   assert len(primals_out) == len(jaxpr.jaxpr.outvars)
@@ -1219,7 +1206,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
 def _pjit_partial_eval(trace, *in_tracers,
                        jaxpr, in_shardings, out_shardings,
                        resource_env, donated_invars, name, in_positional_semantics,
-                       out_positional_semantics, keep_unused):
+                       out_positional_semantics):
   in_pvals = [t.pval for t in in_tracers]

   known_ins = tuple(pv.is_known() for pv in in_pvals)
@@ -1248,8 +1235,7 @@ def _pjit_partial_eval(trace, *in_tracers,
       donated_invars=keep_where(donated_invars, known_ins),
       name=name,
       in_positional_semantics=keep_where(in_positional_semantics, known_ins),
-      out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused)
+      out_positional_semantics=out_positional_semantics)

   if num_residuals:
     in_is_global = _calc_is_global_sequence(
@@ -1258,7 +1244,7 @@ def _pjit_partial_eval(trace, *in_tracers,
         known_params["jaxpr"], known_params["in_shardings"],
         known_params["out_shardings"], known_params["resource_env"],
         known_params["donated_invars"], known_params["name"],
-        in_is_global, known_params['keep_unused'], always_lower=False).compile(
+        in_is_global, always_lower=False).compile(
             _allow_propagation_to_outputs=True,
             _allow_compile_replicated=False)
     da = compiled._device_assignment
@@ -1305,8 +1291,7 @@ def _pjit_partial_eval(trace, *in_tracers,
       name=name,
       in_positional_semantics=(keep_where(
           in_positional_semantics, unknown_ins) + (out_positional_semantics,) * num_residuals),
-      out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused)
+      out_positional_semantics=out_positional_semantics)
   unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
   unknown_tracers_out = [
       pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
@@ -1329,7 +1314,7 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
 def _pjit_transpose(reduce_axes, cts_in, *primals_in,
                     jaxpr, in_shardings, out_shardings,
                     resource_env, donated_invars, name, in_positional_semantics,
-                    out_positional_semantics, keep_unused):
+                    out_positional_semantics):
   def prune_type(ty, xs, maybe_zeros):
     return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)

@@ -1373,8 +1358,7 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
       donated_invars=(False,) * len(primals_and_nz_cts_in),
       name=name,
       in_positional_semantics=transpose_in_positional_semantics,
-      out_positional_semantics=out_positional_semantics,
-      keep_unused=keep_unused)
+      out_positional_semantics=out_positional_semantics)
   return tree_unflatten(cts_out_treedef, nz_cts_out)
 ad.reducing_transposes[pjit_p] = _pjit_transpose

@@ -728,7 +728,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     # execution of the compiled function is blocking, so transferring data
     # to infeed before executing ensures that the execution does not deadlock
     # waiting for the infeed data.
-    logging.info('Transferring to infeed for the jit call')
+    logging.info('Transfering to infeed for the jit call')
     d = devices[0]
     d.transfer_to_infeed((y,))
     d.transfer_to_infeed((z,))
@@ -759,7 +759,7 @@ class PJitTest(jtu.BufferDonationTestCase):
                          partitions=(P(1, nr_devices),))
       return x + y + z + w

-    logging.info('Transferring to infeed for the pjit call')
+    logging.info('Transfering to infeed for the pjit call')
     for didx, d in enumerate(devices):
       # Transfer the whole array to all devices for replicated.
       d.transfer_to_infeed((y,))
@@ -801,7 +801,7 @@ class PJitTest(jtu.BufferDonationTestCase):
           xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
       self.assertAllClose(x, y, check_dtypes=True)

-    logging.info('Transferring from outfeed for the pjit call')
+    logging.info('Transfering from outfeed for the pjit call')
     for didx, d in enumerate(devices):
       # Transfer the whole array from all devices for replicated.
       check_outfeed(d, x)
@@ -2635,7 +2635,7 @@ class ArrayPjitTest(jtu.JaxTestCase):

     with self.assertRaisesRegex(
         ValueError,
-        ("in_axis_resources and out_axis_resources should not "
+        ("in_axis_resources and out_axis_resouces should not "
         "be the unspecified singleton value. Please enable `jax.Array` to use "
         "this feature.")):
       pjit(lambda x: x)
@@ -2815,38 +2815,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
         "pjit does not support kwargs when in_axis_resources is specified."):
       pjit(lambda x: x, in_axis_resources=None)(x=jnp.arange(8.))

-  def test_pjit_keep_unused_true(self):
-    @partial(pjit, keep_unused=True)
-    def f(x, y):
-      return y
-
-    inp = jnp.arange(4)
-    unused_inp = jnp.arange(8)
-    out = f(unused_inp, inp)
-    self.assertArraysEqual(out, inp)
-
-    compiled = f.lower(unused_inp, inp).compile()
-    self.assertEqual(compiled._executable._kept_var_idx, {0, 1})
-    self.assertLen(compiled._executable.in_avals, 2)
-
-    with jtu.count_device_put() as count:
-      _ = f(1, 2)
-    self.assertEqual(count[0], 1)
-
-  def test_pjit_keep_unused_default_false(self):
-    @pjit
-    def f(x, y):
-      return y
-
-    inp = jnp.arange(4)
-    unused_inp = jnp.arange(8)
-    out = f(unused_inp, inp)
-    self.assertArraysEqual(out, inp)
-
-    compiled = f.lower(unused_inp, inp).compile()
-    self.assertEqual(compiled._executable._kept_var_idx, {1})
-    self.assertLen(compiled._executable.in_avals, 1)
-

 class TempSharding(Sharding):

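The two tests removed above pin down the observable behavior of keep_unused. A sketch of the same checks against jax.jit, where the parameter eventually landed; note that _executable and its attributes are private and vary across JAX versions:

import jax
import jax.numpy as jnp

def f(x, y):
  return y  # x is unused

inp = jnp.arange(4)
unused_inp = jnp.arange(8)

pruned = jax.jit(f).lower(unused_inp, inp).compile()
kept = jax.jit(f, keep_unused=True).lower(unused_inp, inp).compile()

# The removed tests asserted, on the private executable object, that
# _kept_var_idx was {1} by default and {0, 1} with keep_unused=True.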