From f0ce0d8c6a3e903e33443cab4bf1a67cd8a92d8d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 17 Jul 2023 06:35:13 -0700 Subject: [PATCH] Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit. PiperOrigin-RevId: 548673905 --- CHANGELOG.md | 9 +++++++++ jax/_src/pjit.py | 38 -------------------------------------- tests/pjit_test.py | 11 ----------- 3 files changed, 9 insertions(+), 49 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 299631235..4d6733a41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,15 @@ Remember to align the itemized text with the first line of an item within a list parameters listed in either donate_argnums or donate_argnames will be donated. +* Deletions + * `in_axis_resources` and `out_axis_resources` have been deleted from pjit since + it has been more than 3 months since their deprecation. Please use + `in_shardings` and `out_shardings` as the replacement. + This is a safe and trivial name replacement. It does not change any of the + current pjit semantics and doesn't break any code. + You can still pass in `PartitionSpecs` to in_shardings and out_shardings. + + * Deprecations * Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 05cd788b1..d004f4cdb 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -264,37 +264,6 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames, return cpp_pjitted_f -def _resolve_axis_resources_and_shardings_arg( - in_shardings, out_shardings, in_axis_resources, out_axis_resources): - if (in_shardings is not None and in_axis_resources is not None and - not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources)): - raise ValueError( - 'Setting both in_shardings and in_axis_resources is not ' - 'allowed. in_axis_resources is deprecated. Please use in_shardings.') - if (out_shardings is not None and out_axis_resources is not None and - not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources)): - raise ValueError( - 'Setting both out_shardings and out_axis_resources is not ' - 'allowed. out_axis_resources is deprecated. Please use out_shardings.') - if ((in_axis_resources is not None and not is_unspecified(in_axis_resources)) or - (out_axis_resources is not None and not is_unspecified(out_axis_resources))): - warnings.warn( - 'in_axis_resources and out_axis_resources are deprecated. Please use ' - 'in_shardings and out_shardings as their replacement.', - DeprecationWarning) - - if in_axis_resources is not None and not is_unspecified(in_axis_resources): - final_in_shardings = in_axis_resources - else: - final_in_shardings = in_shardings - - if out_axis_resources is not None and not is_unspecified(out_axis_resources): - final_out_shardings = out_axis_resources - else: - final_out_shardings = out_shardings - return final_in_shardings, final_out_shardings - - def pre_infer_params(fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, @@ -588,8 +557,6 @@ def pjit( fun: Callable, in_shardings=UNSPECIFIED, out_shardings=UNSPECIFIED, - in_axis_resources=UNSPECIFIED, - out_axis_resources=UNSPECIFIED, static_argnums: Union[int, Sequence[int], None] = None, static_argnames: Union[str, Iterable[str], None] = None, donate_argnums: Union[int, Sequence[int], None] = None, @@ -698,8 +665,6 @@ def pjit( assignment for function outputs. The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` will use GSPMD's sharding propagation to determine how to shard the outputs. - in_axis_resources: (Deprecated) Please use in_shardings. - out_axis_resources: (Deprecated) Please use out_shardings. static_argnums: An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in @@ -779,9 +744,6 @@ def pjit( ... print(f(x)) # doctest: +SKIP [ 0.5 2. 4. 6. 8. 10. 12. 10. ] """ - in_shardings, out_shardings = _resolve_axis_resources_and_shardings_arg( - in_shardings, out_shardings, in_axis_resources, out_axis_resources) - (in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames) = pre_infer_params( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9e89701de..e5244a1c7 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3010,17 +3010,6 @@ class ArrayPjitTest(jtu.JaxTestCase): pf1(jnp.arange(8.)) self.assertEqual(count[0], 1) - def test_set_both_axis_resources_and_shardings(self): - with self.assertRaisesRegex( - ValueError, - "Setting both in_shardings and in_axis_resources is not allowed"): - pjit(lambda x: x, in_shardings=P('x'), in_axis_resources=P('x')) - - with self.assertRaisesRegex( - ValueError, - "Setting both out_shardings and out_axis_resources is not allowed"): - pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x')) - def test_with_sharding_constraint_spmd_axis_name(self): mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl')) shape = (8, 4, 2, 2)