From 69c9660aab9ab8cbab13ca93aa691298f5d072c4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 30 Mar 2023 14:50:07 -0700 Subject: [PATCH] Raise deprecation warnings for `{in|out}_axis_resources` for pjit and `axis_resources` for with_sharding_constraint PiperOrigin-RevId: 520748845 --- CHANGELOG.md | 4 ++++ jax/_src/pjit.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index abdf856e5..1f9cc5c1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.9 +* Deprecations + * The `in_axis_resources` and `out_axis_resources` arguments of pjit have been + deprecated. Please use `in_shardings` and `out_shardings` respectively. + ## jaxlib 0.4.9 ## jax 0.4.8 (March 29, 2023) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 864feb1e5..1accfe316 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -289,6 +289,12 @@ def _resolve_axis_resources_and_shardings_arg( raise ValueError( 'Setting both out_shardings and out_axis_resources is not ' 'allowed. out_axis_resources is deprecated. Please use out_shardings.') + if (not _is_unspecified(in_axis_resources) or + not _is_unspecified(out_axis_resources)): + warnings.warn( + 'in_axis_resources and out_axis_resources are deprecated. Please use ' + 'in_shardings and out_shardings as their replacement.', + DeprecationWarning) if not _is_unspecified(in_axis_resources): final_in_shardings = in_axis_resources @@ -1888,6 +1894,9 @@ def _resolve_wsc_args(axis_resources, shardings): 'that axis_resources is deprecated, so use the shardings argument.') if not _is_unspecified(axis_resources): + warnings.warn( + 'axis_resources is deprecated. Please use shardings argument instead.', + DeprecationWarning) final_shardings = axis_resources else: final_shardings = shardings @@ -1897,8 +1906,8 @@ def _resolve_wsc_args(axis_resources, shardings): # TODO(yashkatariya): Remove the axis_resources argument and make the signature # `with_sharding_constraint(x, shardings)` with no defaults after deprecation # period is finished. The deprecation period expires 3 months from Feb 13, 2023. -def with_sharding_constraint(x, axis_resources=_UNSPECIFIED, - shardings=_UNSPECIFIED): +def with_sharding_constraint(x, shardings=_UNSPECIFIED, + axis_resources=_UNSPECIFIED): final_shardings = _resolve_wsc_args(axis_resources, shardings) x_flat, tree = tree_flatten(x) user_shardings, _, _ = _prepare_axis_resources(