mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
This commit is contained in:
parent
25a9a978fb
commit
fe3fed3627
20
CHANGELOG.md
20
CHANGELOG.md
@ -11,16 +11,18 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Deprecations
|
||||
* The following APIs have been removed after a 3 month deprecation period, in
|
||||
accordance with the {ref}`api-compatibility` policy:
|
||||
- `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
- `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
|
||||
- `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
|
||||
- `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
- `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
|
||||
* `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
* `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
|
||||
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
|
||||
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
* `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
|
||||
as input and remove the optional `in_shardings` argument to `pjit`.
|
||||
- `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
- `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
|
||||
- `jax.interpreters.xla.Device`: use `jax.Device`.
|
||||
- `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead,
|
||||
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
|
||||
* `jax.interpreters.xla.Device`: use `jax.Device`.
|
||||
* `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead
|
||||
* `axis_resources` argument of `with_sharding_constraint` is removed. Please
|
||||
use `shardings` instead.
|
||||
|
||||
|
||||
## jaxlib 0.4.11
|
||||
|
@ -1786,32 +1786,7 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
|
||||
|
||||
# -------------------- with_sharding_constraint --------------------
|
||||
|
||||
def _resolve_wsc_args(axis_resources, shardings):
|
||||
if not is_unspecified(axis_resources) and not is_unspecified(shardings):
|
||||
raise ValueError(
|
||||
'Setting both axis_resources and shardings is not '
|
||||
'allowed. axis_resources is deprecated. Please use shardings.')
|
||||
if is_unspecified(axis_resources) and is_unspecified(shardings):
|
||||
raise ValueError(
|
||||
'Not specifying shardings to `with_sharding_constraint` is not allowed. '
|
||||
'Please specify the shardings argument with a concrete sharding. Note '
|
||||
'that axis_resources is deprecated, so use the shardings argument.')
|
||||
|
||||
if not is_unspecified(axis_resources):
|
||||
warnings.warn(
|
||||
'axis_resources is deprecated. Please use shardings argument instead.',
|
||||
DeprecationWarning)
|
||||
final_shardings = axis_resources
|
||||
else:
|
||||
final_shardings = shardings
|
||||
return final_shardings
|
||||
|
||||
|
||||
# TODO(yashkatariya): Remove the axis_resources argument and make the signature
|
||||
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
|
||||
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
|
||||
def with_sharding_constraint(x, shardings=UNSPECIFIED,
|
||||
axis_resources=UNSPECIFIED):
|
||||
def with_sharding_constraint(x, shardings):
|
||||
"""Mechanism to constrain the sharding of an Array inside a jitted computation
|
||||
|
||||
This is a strict constraint for the GSPMD partitioner and not a hint. For examples
|
||||
@ -1821,17 +1796,15 @@ def with_sharding_constraint(x, shardings=UNSPECIFIED,
|
||||
x: PyTree of jax.Arrays which will have their shardings constrainted
|
||||
shardings: PyTree of sharding specifications. Valid values are the same as for
|
||||
the ``in_shardings`` argument of :func:`jax.experimental.pjit`.
|
||||
axis_resources: (deprecated) use shardings instead.
|
||||
Returns:
|
||||
x_with_shardings: PyTree of jax.Arrays with specified sharding constraints.
|
||||
|
||||
.. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
|
||||
"""
|
||||
final_shardings = _resolve_wsc_args(axis_resources, shardings)
|
||||
x_flat, tree = tree_flatten(x)
|
||||
user_shardings, _, _ = prepare_axis_resources(
|
||||
final_shardings, "shardings", allow_unconstrained_dims=True)
|
||||
del final_shardings
|
||||
shardings, "shardings", allow_unconstrained_dims=True)
|
||||
del shardings
|
||||
|
||||
user_shardings_flat = tuple(
|
||||
flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
|
||||
|
@ -2953,19 +2953,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
"Setting both out_shardings and out_axis_resources is not allowed"):
|
||||
pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))
|
||||
|
||||
def test_set_none_wsc_axis_resources_and_shardings(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Not specifying shardings to `with_sharding_constraint` is not allowed."):
|
||||
pjit(jax.lax.with_sharding_constraint(jnp.arange(8)))
|
||||
|
||||
def test_set_both_wsc_axis_resources_and_shardings(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Setting both axis_resources and shardings is not allowed"):
|
||||
pjit(jax.lax.with_sharding_constraint(
|
||||
jnp.arange(8), axis_resources=P('x'), shardings=P('x')))
|
||||
|
||||
def test_with_sharding_constraint_spmd_axis_name(self):
|
||||
mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
|
||||
shape = (8, 4, 2, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user