mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Raise deprecation warnings for {in|out}_axis_resources
for pjit and axis_resources
for with_sharding_constraint
PiperOrigin-RevId: 520748845
This commit is contained in:
parent
36bf14b044
commit
69c9660aab
@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jax 0.4.9
|
||||
|
||||
* Deprecations
|
||||
* The `in_axis_resources` and `out_axis_resources` arguments of pjit have been
|
||||
deprecated. Please use `in_shardings` and `out_shardings` respectively.
|
||||
|
||||
## jaxlib 0.4.9
|
||||
|
||||
## jax 0.4.8 (March 29, 2023)
|
||||
|
@ -289,6 +289,12 @@ def _resolve_axis_resources_and_shardings_arg(
|
||||
raise ValueError(
|
||||
'Setting both out_shardings and out_axis_resources is not '
|
||||
'allowed. out_axis_resources is deprecated. Please use out_shardings.')
|
||||
if (not _is_unspecified(in_axis_resources) or
|
||||
not _is_unspecified(out_axis_resources)):
|
||||
warnings.warn(
|
||||
'in_axis_resources and out_axis_resources are deprecated. Please use '
|
||||
'in_shardings and out_shardings as their replacement.',
|
||||
DeprecationWarning)
|
||||
|
||||
if not _is_unspecified(in_axis_resources):
|
||||
final_in_shardings = in_axis_resources
|
||||
@ -1888,6 +1894,9 @@ def _resolve_wsc_args(axis_resources, shardings):
|
||||
'that axis_resources is deprecated, so use the shardings argument.')
|
||||
|
||||
if not _is_unspecified(axis_resources):
|
||||
warnings.warn(
|
||||
'axis_resources is deprecated. Please use shardings argument instead.',
|
||||
DeprecationWarning)
|
||||
final_shardings = axis_resources
|
||||
else:
|
||||
final_shardings = shardings
|
||||
@ -1897,8 +1906,8 @@ def _resolve_wsc_args(axis_resources, shardings):
|
||||
# TODO(yashkatariya): Remove the axis_resources argument and make the signature
|
||||
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
|
||||
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
|
||||
def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
|
||||
shardings=_UNSPECIFIED):
|
||||
def with_sharding_constraint(x, shardings=_UNSPECIFIED,
|
||||
axis_resources=_UNSPECIFIED):
|
||||
final_shardings = _resolve_wsc_args(axis_resources, shardings)
|
||||
x_flat, tree = tree_flatten(x)
|
||||
user_shardings, _, _ = _prepare_axis_resources(
|
||||
|
Loading…
x
Reference in New Issue
Block a user