From 2fc64bee13e3b1a986851b95bb5721ca66a8aa63 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 13 Feb 2023 10:53:21 -0800
Subject: [PATCH] Change the `axis_resources` argument of `with_sharding_constraint` to `shardings` to match `pjit` and `jit`.

PiperOrigin-RevId: 509275107
---
 CHANGELOG.md       |  5 ++++
 jax/_src/pjit.py   | 60 +++++++++++++++++++++++++++++++++-------------
 tests/pjit_test.py | 15 +++++++++++-
 3 files changed, 63 insertions(+), 17 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 05a8f8b09..1d8ae2893 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,11 @@ Remember to align the itemized text with the first line of an item within a list
     `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
     The merge must be disabled via an environment variable since it affects JAX
     at import time so it needs to be disabled before jax is imported.
+  * `axis_resources` argument of `with_sharding_constraint` is deprecated.
+    Please use `shardings` instead. There is no change needed if you were using
+    `axis_resources` as an arg. If you were using it as a kwarg, then please
+    use `shardings` instead. `axis_resources` will be removed after 3 months
+    from Feb 13, 2023.
 
 ## jaxlib 0.4.4
   * Breaking changes
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index d0028aac5..7ea6f38e1 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1975,37 +1975,65 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
 
 # -------------------- with_sharding_constraint --------------------
 
-def with_sharding_constraint(x, axis_resources):
+def _resolve_wsc_args(axis_resources, shardings):
+  if not _is_unspecified(axis_resources) and not _is_unspecified(shardings):
+    raise ValueError(
+        'Setting both axis_resources and shardings is not '
+        'allowed. axis_resources is deprecated. Please use shardings.')
+  if _is_unspecified(axis_resources) and _is_unspecified(shardings):
+    raise ValueError(
+        'Not specifying shardings to `with_sharding_constraint` is not allowed. '
+        'Please specify the shardings argument with a concrete sharding. Note '
+        'that axis_resources is deprecated, so use the shardings argument.')
+
+  if not _is_unspecified(axis_resources):
+    final_shardings = axis_resources
+  else:
+    final_shardings = shardings
+  return final_shardings
+
+
+# TODO(yashkatariya): Remove the axis_resources argument and make the signature
+# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
+# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
+def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
+                             shardings=_UNSPECIFIED):
+  final_shardings = _resolve_wsc_args(axis_resources, shardings)
   x_flat, tree = tree_flatten(x)
-  axis_resources, _, _ = _prepare_axis_resources(
-      axis_resources, "axis_resources", allow_unconstrained_dims=True)
-  axis_resources_flat = tuple(
-      flatten_axes("with_sharding_constraint sharding", tree, axis_resources))
+  user_shardings, _, _ = _prepare_axis_resources(
+      final_shardings, "shardings", allow_unconstrained_dims=True)
+  del final_shardings
+
+  user_shardings_flat = tuple(
+      flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
+  del user_shardings
+
   resource_env = pxla.thread_resources.env
   mesh = resource_env.physical_mesh
 
   if config.jax_array:
-    sharding_flat = [_create_sharding_for_array(mesh, a)
-                     for a in axis_resources_flat]
-    unconstrained_dims = [
-        get_unconstrained_dims(s) if isinstance(s, NamedSharding) else {}
-        for s in sharding_flat
-    ]
+    shardings_flat = [_create_sharding_for_array(mesh, a)
+                      for a in user_shardings_flat]
+    unconstrained_dims = [get_unconstrained_dims(s)
+                          if isinstance(s, NamedSharding) else {}
+                          for s in shardings_flat]
   else:
-    sharding_flat = [pxla.create_mesh_pspec_sharding(mesh, a.user_spec, a)
-                     for a in axis_resources_flat]
+    shardings_flat = [pxla.create_mesh_pspec_sharding(mesh, a.user_spec, a)
+                      for a in user_shardings_flat]
     # Calculate unconstrained_dims from NamedSharding because that information
     # is lost when converted to OpSharding. Bind unconstrained_dims to
     # with_sharding_constraint primitive.
-    unconstrained_dims = [get_unconstrained_dims(s) for s in sharding_flat]
+    unconstrained_dims = [get_unconstrained_dims(s) for s in shardings_flat]
 
-  pjit_check_aval_sharding(sharding_flat, x_flat, "with_sharding_constraint arguments",
+  del user_shardings_flat
+
+  pjit_check_aval_sharding(shardings_flat, x_flat, "with_sharding_constraint arguments",
                            allow_uneven_sharding=True)
 
   outs = [sharding_constraint_p.bind(xf, sharding=to_op_sharding_sharding(i, xf.ndim),
                                      resource_env=resource_env,
                                      unconstrained_dims=ud)
-          for xf, i, ud in safe_zip(x_flat, sharding_flat, unconstrained_dims)]
+          for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)]
   return tree_unflatten(tree, outs)
 
 def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 47ef6506f..d5d0d8ebe 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3348,6 +3348,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
         "Setting both out_shardings and out_axis_resources is not allowed"):
       pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))
 
+  def test_set_none_wsc_axis_resources_and_shardings(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Not specifying shardings to `with_sharding_constraint` is not allowed."):
+      pjit(jax.lax.with_sharding_constraint(jnp.arange(8)))
+
+  def test_set_both_wsc_axis_resources_and_shardings(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Setting both axis_resources and shardings is not allowed"):
+      pjit(jax.lax.with_sharding_constraint(
+          jnp.arange(8), axis_resources=P('x'), shardings=P('x')))
+
 
 class TempSharding(Sharding):
 
@@ -3520,7 +3533,7 @@ class PJitErrorTest(jtu.JaxTestCase):
   @jtu.with_mesh([('x', 2)])
   def testConstraintShardsXMapAxis(self):
     spec = P('x')
-    f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
+    f = xmap(lambda x: with_sharding_constraint(x, spec),
              in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
     x = jnp.arange(4).reshape((2, 2))
     error = (r"with_sharding_constraint input has an axis resources specification of " +