diff --git a/CHANGELOG.md b/CHANGELOG.md index bb50d4902..c1575fbb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,16 +11,18 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * The following APIs have been removed after a 3 month deprecation period, in accordance with the {ref}`api-compatibility` policy: - - `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`. - - `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh` - - `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`. - - `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`. - - `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects + * `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`. + * `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh` + * `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`. + * `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`. + * `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects as input and remove the optional `in_shardings` argument to `pjit`. - - `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`. - - `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh` - - `jax.interpreters.xla.Device`: use `jax.Device`. - - `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead, + * `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`. + * `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh` + * `jax.interpreters.xla.Device`: use `jax.Device`. + * `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead + * `axis_resources` argument of `with_sharding_constraint` is removed. Please + use `shardings` instead. ## jaxlib 0.4.11 diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 337d44d12..d360631bc 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1786,32 +1786,7 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule # -------------------- with_sharding_constraint -------------------- -def _resolve_wsc_args(axis_resources, shardings): - if not is_unspecified(axis_resources) and not is_unspecified(shardings): - raise ValueError( - 'Setting both axis_resources and shardings is not ' - 'allowed. axis_resources is deprecated. Please use shardings.') - if is_unspecified(axis_resources) and is_unspecified(shardings): - raise ValueError( - 'Not specifying shardings to `with_sharding_constraint` is not allowed. ' - 'Please specify the shardings argument with a concrete sharding. Note ' - 'that axis_resources is deprecated, so use the shardings argument.') - - if not is_unspecified(axis_resources): - warnings.warn( - 'axis_resources is deprecated. Please use shardings argument instead.', - DeprecationWarning) - final_shardings = axis_resources - else: - final_shardings = shardings - return final_shardings - - -# TODO(yashkatariya): Remove the axis_resources argument and make the signature -# `with_sharding_constraint(x, shardings)` with no defaults after deprecation -# period is finished. The deprecation period expires 3 months from Feb 13, 2023. -def with_sharding_constraint(x, shardings=UNSPECIFIED, - axis_resources=UNSPECIFIED): +def with_sharding_constraint(x, shardings): """Mechanism to constrain the sharding of an Array inside a jitted computation This is a strict constraint for the GSPMD partitioner and not a hint. For examples @@ -1821,17 +1796,15 @@ def with_sharding_constraint(x, shardings=UNSPECIFIED, x: PyTree of jax.Arrays which will have their shardings constrainted shardings: PyTree of sharding specifications. Valid values are the same as for the ``in_shardings`` argument of :func:`jax.experimental.pjit`. - axis_resources: (deprecated) use shardings instead. Returns: x_with_shardings: PyTree of jax.Arrays with specified sharding constraints. .. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ - final_shardings = _resolve_wsc_args(axis_resources, shardings) x_flat, tree = tree_flatten(x) user_shardings, _, _ = prepare_axis_resources( - final_shardings, "shardings", allow_unconstrained_dims=True) - del final_shardings + shardings, "shardings", allow_unconstrained_dims=True) + del shardings user_shardings_flat = tuple( flatten_axes("with_sharding_constraint shardings", tree, user_shardings)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index c6e09ff06..1503a791d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2953,19 +2953,6 @@ class ArrayPjitTest(jtu.JaxTestCase): "Setting both out_shardings and out_axis_resources is not allowed"): pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x')) - def test_set_none_wsc_axis_resources_and_shardings(self): - with self.assertRaisesRegex( - ValueError, - "Not specifying shardings to `with_sharding_constraint` is not allowed."): - pjit(jax.lax.with_sharding_constraint(jnp.arange(8))) - - def test_set_both_wsc_axis_resources_and_shardings(self): - with self.assertRaisesRegex( - ValueError, - "Setting both axis_resources and shardings is not allowed"): - pjit(jax.lax.with_sharding_constraint( - jnp.arange(8), axis_resources=P('x'), shardings=P('x'))) - def test_with_sharding_constraint_spmd_axis_name(self): mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl')) shape = (8, 4, 2, 2)