Raise deprecation warnings for {in|out}_axis_resources for pjit and axis_resources for with_sharding_constraint

PiperOrigin-RevId: 520748845
This commit is contained in:
Yash Katariya 2023-03-30 14:50:07 -07:00 committed by jax authors
parent 36bf14b044
commit 69c9660aab
2 changed files with 15 additions and 2 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.9
* Deprecations
* The `in_axis_resources` and `out_axis_resources` arguments of pjit have been
deprecated. Please use `in_shardings` and `out_shardings` respectively.
## jaxlib 0.4.9
## jax 0.4.8 (March 29, 2023)

View File

@ -289,6 +289,12 @@ def _resolve_axis_resources_and_shardings_arg(
raise ValueError(
'Setting both out_shardings and out_axis_resources is not '
'allowed. out_axis_resources is deprecated. Please use out_shardings.')
if (not _is_unspecified(in_axis_resources) or
not _is_unspecified(out_axis_resources)):
warnings.warn(
'in_axis_resources and out_axis_resources are deprecated. Please use '
'in_shardings and out_shardings as their replacement.',
DeprecationWarning)
if not _is_unspecified(in_axis_resources):
final_in_shardings = in_axis_resources
@ -1888,6 +1894,9 @@ def _resolve_wsc_args(axis_resources, shardings):
'that axis_resources is deprecated, so use the shardings argument.')
if not _is_unspecified(axis_resources):
warnings.warn(
'axis_resources is deprecated. Please use shardings argument instead.',
DeprecationWarning)
final_shardings = axis_resources
else:
final_shardings = shardings
@ -1897,8 +1906,8 @@ def _resolve_wsc_args(axis_resources, shardings):
# TODO(yashkatariya): Remove the axis_resources argument and make the signature
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
shardings=_UNSPECIFIED):
def with_sharding_constraint(x, shardings=_UNSPECIFIED,
axis_resources=_UNSPECIFIED):
final_shardings = _resolve_wsc_args(axis_resources, shardings)
x_flat, tree = tree_flatten(x)
user_shardings, _, _ = _prepare_axis_resources(