mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00
[Rollback 2] Add keep_unused
to pjit
's API as a step to merge jit
and pjit
frontend API.
PiperOrigin-RevId: 495756613
This commit is contained in:
parent
1598c52faf
commit
ecaa215043
@ -1115,7 +1115,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name,
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
keep_unused, inline):
|
||||
inline):
|
||||
checked_jaxpr, (out_tree, effects) = checkify_jaxpr(jaxpr, error,
|
||||
enabled_errors)
|
||||
out_error = error._add_placeholder_effects(effects)
|
||||
@ -1155,7 +1155,6 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
name=name,
|
||||
in_positional_semantics=new_positional_sems_in,
|
||||
out_positional_semantics=new_positional_sems_out,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
err, *out = tree_unflatten(out_tree, err_and_out)
|
||||
return out, err
|
||||
|
@ -497,13 +497,13 @@ class Compiled(Stage):
|
||||
else:
|
||||
raise
|
||||
outs = tree_util.tree_unflatten(params.out_tree, out_flat)
|
||||
return outs, out_flat, args_flat
|
||||
return outs, out_flat
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self._cpp_call is not None:
|
||||
return self._cpp_call(*args, **kwargs)
|
||||
|
||||
outs, _, _ = Compiled.call(self._params, *args, **kwargs)
|
||||
outs, _ = Compiled.call(self._params, *args, **kwargs)
|
||||
return outs
|
||||
|
||||
|
||||
|
@ -3118,7 +3118,6 @@ def _pjit(*args: TfVal,
|
||||
name: str,
|
||||
in_positional_semantics,
|
||||
out_positional_semantics,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: Sequence[core.ShapedArray]) -> TfVal:
|
||||
|
@ -112,7 +112,7 @@ def _python_pjit_helper(infer_params, *args, **kwargs):
|
||||
_check_arg(arg)
|
||||
out_flat = pjit_p.bind(*args_flat, **params)
|
||||
outs = tree_unflatten(out_tree, out_flat)
|
||||
return outs, out_flat, out_tree, args_flat
|
||||
return outs, out_flat, out_tree
|
||||
|
||||
def _python_pjit(fun: Callable, infer_params):
|
||||
|
||||
@ -133,8 +133,7 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
|
||||
def cache_miss(*args, **kwargs):
|
||||
global _most_recent_pjit_call_executable
|
||||
|
||||
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
|
||||
infer_params, *args, **kwargs)
|
||||
outs, out_flat, out_tree = _python_pjit_helper(infer_params, *args, **kwargs)
|
||||
|
||||
executable = _most_recent_pjit_call_executable.value
|
||||
_most_recent_pjit_call_executable.value = None
|
||||
@ -151,11 +150,9 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
|
||||
if use_fastpath:
|
||||
out_avals = [o.aval for o in out_flat]
|
||||
out_committed = [o._committed for o in out_flat]
|
||||
kept_var_bitvec = [i in executable._kept_var_idx
|
||||
for i in range(len(args_flat))]
|
||||
fastpath_data = pxla._MeshExecutableFastpathData(
|
||||
executable.xla_executable, out_tree, executable._in_shardings,
|
||||
executable._out_shardings, out_avals, out_committed, kept_var_bitvec)
|
||||
executable._out_shardings, out_avals, out_committed)
|
||||
else:
|
||||
fastpath_data = None
|
||||
|
||||
@ -182,7 +179,6 @@ def pjit(
|
||||
static_argnums: Union[int, Sequence[int], None] = None,
|
||||
static_argnames: Union[str, Iterable[str], None] = None,
|
||||
donate_argnums: Union[int, Sequence[int]] = (),
|
||||
keep_unused: bool = False,
|
||||
device: Optional[xc.Device] = None,
|
||||
backend: Optional[str] = None,
|
||||
inline: bool = False,
|
||||
@ -288,10 +284,6 @@ def pjit(
|
||||
should not reuse buffers that you donate to a computation, JAX will raise
|
||||
an error if you try to.
|
||||
For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
|
||||
keep_unused: If `False` (the default), arguments that JAX determines to be
|
||||
unused by `fun` *may* be dropped from resulting compiled XLA executables.
|
||||
Such arguments will not be transferred to the device nor provided to the
|
||||
underlying executable. If `True`, unused arguments will not be pruned.
|
||||
device: This argument is deprecated. Please put your arguments on the
|
||||
device you want before passing them to jit.
|
||||
Optional, the Device the jitted function will run on. (Available devices
|
||||
@ -304,7 +296,7 @@ def pjit(
|
||||
``'tpu'``.
|
||||
Returns:
|
||||
A wrapped version of ``fun``, set up for just-in-time compilation and
|
||||
automatically partitioned by the mesh available at each call site.
|
||||
automaticly partitioned by the mesh available at each call site.
|
||||
|
||||
For example, a convolution operator can be automatically partitioned over
|
||||
an arbitrary set of devices by a single :func:`~pjit` application:
|
||||
@ -327,7 +319,7 @@ def pjit(
|
||||
if not config.jax_array and (_is_unspecified(in_axis_resources) or
|
||||
_is_unspecified(out_axis_resources)):
|
||||
raise ValueError(
|
||||
"in_axis_resources and out_axis_resources should not "
|
||||
"in_axis_resources and out_axis_resouces should not "
|
||||
"be the unspecified singleton value. Please enable `jax.Array` to use "
|
||||
"this feature. You can use jax.config.update('jax_array', True) or "
|
||||
"set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
|
||||
@ -430,7 +422,7 @@ def pjit(
|
||||
out_shardings = tree_map(
|
||||
lambda x: x if _is_unspecified(x) else
|
||||
_create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
|
||||
# This check fails extremely rarely and has a huge cost in the dispatch
|
||||
# This check fails extrememly rarely and has a huge cost in the dispatch
|
||||
# path. So hide it behind the jax_enable_checks flag.
|
||||
if config.jax_enable_checks:
|
||||
_maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
|
||||
@ -474,13 +466,12 @@ def pjit(
|
||||
name=getattr(flat_fun, '__name__', '<unnamed function>'),
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
)
|
||||
return (args_flat, local_in_avals, params, in_tree, out_tree(),
|
||||
donate_argnums)
|
||||
|
||||
if FLAGS.experimental_cpp_pjit and xc._version >= 111:
|
||||
if FLAGS.experimental_cpp_pjit and xc._version >= 96:
|
||||
wrapped = _cpp_pjit(fun, infer_params, static_argnums)
|
||||
else:
|
||||
wrapped = _python_pjit(fun, infer_params)
|
||||
@ -499,7 +490,7 @@ def pjit(
|
||||
lowering = _pjit_lower(
|
||||
params['jaxpr'], in_shardings, params['out_shardings'],
|
||||
params['resource_env'], params['donated_invars'], params['name'],
|
||||
in_is_global, params['keep_unused'], always_lower=True)
|
||||
in_is_global, always_lower=True)
|
||||
|
||||
if kwargs:
|
||||
args_kwargs_in_tree = in_tree
|
||||
@ -1009,7 +1000,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name,
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
keep_unused, inline):
|
||||
inline):
|
||||
|
||||
global _most_recent_pjit_call_executable
|
||||
|
||||
@ -1024,8 +1015,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
_allow_propagation_to_outputs = False
|
||||
compiled = _pjit_lower(
|
||||
jaxpr, in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name, in_is_global, keep_unused,
|
||||
always_lower=False).compile(
|
||||
donated_invars, name, in_is_global, always_lower=False).compile(
|
||||
_allow_propagation_to_outputs=_allow_propagation_to_outputs)
|
||||
_most_recent_pjit_call_executable.value = compiled
|
||||
# This check is expensive so only do it if enable_checks is on.
|
||||
@ -1094,7 +1084,6 @@ def _pjit_lower_cached(
|
||||
donated_invars,
|
||||
name: str,
|
||||
in_is_global: Sequence[bool],
|
||||
keep_unused: bool,
|
||||
always_lower: bool):
|
||||
in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
|
||||
Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
|
||||
@ -1141,7 +1130,7 @@ def _pjit_lower_cached(
|
||||
# the arguments just like dispatch.py in `sharded_lowering`.
|
||||
return pxla.lower_sharding_computation(
|
||||
fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
|
||||
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
|
||||
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
|
||||
always_lower=always_lower,
|
||||
devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))
|
||||
|
||||
@ -1170,7 +1159,7 @@ pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
|
||||
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
|
||||
out_shardings, resource_env, donated_invars,
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
keep_unused, inline):
|
||||
inline):
|
||||
if not isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext)):
|
||||
raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
|
||||
@ -1206,7 +1195,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
vals_in, dims_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
out_positional_semantics, inline):
|
||||
# batch_jaxpr expects all batching dimensions to be equal to 0
|
||||
vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
|
||||
else x for x, d in zip(vals_in, dims_in)]
|
||||
@ -1235,7 +1224,6 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
name=name,
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
|
||||
return vals_out, dims_out
|
||||
@ -1266,7 +1254,7 @@ def _pjit_batcher_for_sharding(
|
||||
def _pjit_jvp(primals_in, tangents_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
out_positional_semantics, inline):
|
||||
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
|
||||
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
|
||||
jaxpr, is_nz_tangents_in, instantiate=False)
|
||||
@ -1285,7 +1273,6 @@ def _pjit_jvp(primals_in, tangents_in,
|
||||
name=wrap_name(name, 'jvp'),
|
||||
in_positional_semantics=(*in_positional_semantics, *_filter_zeros_in(in_positional_semantics)),
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
|
||||
primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
|
||||
@ -1299,7 +1286,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
|
||||
def _pjit_partial_eval(trace, *in_tracers,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
out_positional_semantics, inline):
|
||||
in_pvals = [t.pval for t in in_tracers]
|
||||
|
||||
known_ins = tuple(pv.is_known() for pv in in_pvals)
|
||||
@ -1329,7 +1316,6 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
name=name,
|
||||
in_positional_semantics=keep_where(in_positional_semantics, known_ins),
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
|
||||
if num_residuals:
|
||||
@ -1339,7 +1325,7 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
known_params["jaxpr"], known_params["in_shardings"],
|
||||
known_params["out_shardings"], known_params["resource_env"],
|
||||
known_params["donated_invars"], known_params["name"],
|
||||
in_is_global, known_params['keep_unused'], always_lower=False).compile(
|
||||
in_is_global, always_lower=False).compile(
|
||||
_allow_propagation_to_outputs=True,
|
||||
_allow_compile_replicated=False)
|
||||
da = compiled._device_assignment
|
||||
@ -1387,7 +1373,6 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
in_positional_semantics=(keep_where(
|
||||
in_positional_semantics, unknown_ins) + (out_positional_semantics,) * num_residuals),
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
|
||||
unknown_tracers_out = [
|
||||
@ -1411,7 +1396,7 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
|
||||
def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
out_positional_semantics, inline):
|
||||
def prune_type(ty, xs, maybe_zeros):
|
||||
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
|
||||
|
||||
@ -1456,7 +1441,6 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
name=name,
|
||||
in_positional_semantics=transpose_in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
return tree_unflatten(cts_out_treedef, nz_cts_out)
|
||||
ad.reducing_transposes[pjit_p] = _pjit_transpose
|
||||
|
@ -3498,7 +3498,6 @@ class _MeshExecutableFastpathData(NamedTuple):
|
||||
out_shardings: Sequence[Any]
|
||||
out_avals: Sequence[Any]
|
||||
out_committed: Sequence[bool]
|
||||
kept_var_bitvec: Iterable[bool]
|
||||
|
||||
|
||||
class MeshExecutable(stages.XlaExecutable):
|
||||
@ -3578,24 +3577,26 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
not self.unsafe_call.has_host_callbacks):
|
||||
return None
|
||||
|
||||
if not flags.FLAGS.experimental_cpp_pjit or xc._version < 111:
|
||||
if not flags.FLAGS.experimental_cpp_pjit or xc._version < 96:
|
||||
return None
|
||||
|
||||
def aot_cache_miss(*args, **kwargs):
|
||||
params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
|
||||
outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
|
||||
outs, out_flat = stages.Compiled.call(params, *args, **kwargs)
|
||||
|
||||
use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
|
||||
|
||||
if use_fastpath:
|
||||
out_avals = [o.aval for o in out_flat]
|
||||
out_committed = [o._committed for o in out_flat]
|
||||
kept_var_bitvec = [i in self._kept_var_idx
|
||||
for i in range(len(args_flat))]
|
||||
fastpath_data = _MeshExecutableFastpathData(
|
||||
self.xla_executable, out_tree, self._in_shardings,
|
||||
self._out_shardings, out_avals, out_committed, kept_var_bitvec)
|
||||
fastpath_data = _MeshExecutableFastpathData(self.xla_executable,
|
||||
out_tree,
|
||||
self._in_shardings,
|
||||
self._out_shardings,
|
||||
out_avals, out_committed)
|
||||
else:
|
||||
fastpath_data = None
|
||||
|
||||
return outs, fastpath_data
|
||||
|
||||
if xc._version < 108:
|
||||
|
@ -728,7 +728,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
# execution of the compiled function is blocking, so transferring data
|
||||
# to infeed before executing ensures that the execution does not deadlock
|
||||
# waiting for the infeed data.
|
||||
logging.info('Transferring to infeed for the jit call')
|
||||
logging.info('Transfering to infeed for the jit call')
|
||||
d = devices[0]
|
||||
d.transfer_to_infeed((y,))
|
||||
d.transfer_to_infeed((z,))
|
||||
@ -759,7 +759,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
partitions=(P(1, nr_devices),))
|
||||
return x + y + z + w
|
||||
|
||||
logging.info('Transferring to infeed for the pjit call')
|
||||
logging.info('Transfering to infeed for the pjit call')
|
||||
for didx, d in enumerate(devices):
|
||||
# Transfer the whole array to all devices for replicated.
|
||||
d.transfer_to_infeed((y,))
|
||||
@ -801,7 +801,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
|
||||
self.assertAllClose(x, y, check_dtypes=True)
|
||||
|
||||
logging.info('Transferring from outfeed for the pjit call')
|
||||
logging.info('Transfering from outfeed for the pjit call')
|
||||
for didx, d in enumerate(devices):
|
||||
# Transfer the whole array from all devices for replicated.
|
||||
check_outfeed(d, x)
|
||||
@ -2548,9 +2548,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out2.shape, (8, 2))
|
||||
|
||||
@jax_array(True)
|
||||
def test_single_device_pjit_cpp_dispatch(self):
|
||||
if xla_extension_version < 111:
|
||||
self.skipTest('Does not work for xla_extension_version < 111')
|
||||
def test_single_device_pjit_perf(self):
|
||||
if xla_extension_version < 103:
|
||||
self.skipTest('Does not work for xla_extension_version < 103')
|
||||
|
||||
shape = (8, 2)
|
||||
mesh = jtu.create_global_mesh((1,), ('x',))
|
||||
@ -2579,8 +2579,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
@jax_array(True)
|
||||
def test_single_device_add_single_compile(self):
|
||||
if xla_extension_version < 111:
|
||||
self.skipTest('Does not work for xla_extension_version < 111')
|
||||
if xla_extension_version < 103:
|
||||
self.skipTest('Does not work for xla_extension_version < 103')
|
||||
|
||||
f1 = pjit(lambda x, y: x + y)
|
||||
a = jax.device_put(jnp.array([1, 2, 3], dtype=jnp.float32),
|
||||
@ -2635,7 +2635,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
("in_axis_resources and out_axis_resources should not "
|
||||
("in_axis_resources and out_axis_resouces should not "
|
||||
"be the unspecified singleton value. Please enable `jax.Array` to use "
|
||||
"this feature.")):
|
||||
pjit(lambda x: x)
|
||||
@ -2824,46 +2824,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
"pjit does not support kwargs when in_axis_resources is specified."):
|
||||
pjit(lambda x: x, in_axis_resources=None)(x=jnp.arange(8.))
|
||||
|
||||
def test_pjit_keep_unused_true(self):
|
||||
@partial(pjit, keep_unused=True)
|
||||
def f(x, y, z, a, b, c): # pylint: disable=unused-argument
|
||||
return c @ c.T
|
||||
|
||||
inp = jnp.arange(4)
|
||||
unused_inp = jnp.arange(8)
|
||||
|
||||
out = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
|
||||
# Run it again to take the C++ dispatch.
|
||||
out_again = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
|
||||
|
||||
self.assertArraysEqual(out, inp @ inp.T)
|
||||
self.assertArraysEqual(out_again, inp @ inp.T)
|
||||
|
||||
compiled = f.lower(
|
||||
unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp).compile()
|
||||
self.assertEqual(compiled._executable._kept_var_idx, {0, 1, 2, 3, 4, 5})
|
||||
self.assertLen(compiled._executable.in_avals, 6)
|
||||
|
||||
def test_pjit_keep_unused_default_false(self):
|
||||
@pjit
|
||||
def f(x, y, z, a, b, c): # pylint: disable=unused-argument
|
||||
return c @ c.T
|
||||
|
||||
inp = jax.device_put(jnp.arange(4), jax.devices()[0])
|
||||
unused_inp = jax.device_put(jnp.arange(8), jax.devices()[0])
|
||||
|
||||
out = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
|
||||
# Run it again to take the C++ dispatch.
|
||||
out_again = f(unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp)
|
||||
|
||||
self.assertArraysEqual(out, inp @ inp.T)
|
||||
self.assertArraysEqual(out_again, inp @ inp.T)
|
||||
|
||||
compiled = f.lower(
|
||||
unused_inp, unused_inp, unused_inp, unused_inp, unused_inp, inp).compile()
|
||||
self.assertEqual(compiled._executable._kept_var_idx, {5})
|
||||
self.assertLen(compiled._executable.in_avals, 1)
|
||||
|
||||
def test_pjit_with_device_arg(self):
|
||||
def mul(x):
|
||||
return x @ x.T
|
||||
|
Loading…
x
Reference in New Issue
Block a user