Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.

Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
This commit is contained in:
Peter Hawkins 2023-02-15 08:13:58 -08:00 committed by jax authors
parent 69b8a03400
commit 00d45feee6
5 changed files with 45 additions and 9 deletions

View File

@ -30,6 +30,12 @@ Remember to align the itemized text with the first line of an item within a list
from Feb 13, 2023.
* added the {mod}`jax.typing` module, with tools for type annotations of JAX
functions.
* The following names have been deprecated:
* `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* Breaking Changes
* the `initial` argument to reduction functions like :func:`jax.numpy.sum`
is now required to be a scalar, consistent with the corresponding NumPy API.

View File

@ -42,7 +42,7 @@ from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding, OpShardingSharding
from jax._src.sharding import Sharding, OpShardingSharding, NamedSharding
# pytype: disable=import-error
try:
@ -311,7 +311,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
devices, op_sharding))
pspec = pjit.parse_flatten_op_sharding(
op_sharding, mesh)[0].get_partition_spec()
return callback(pjit.NamedSharding(mesh, pspec))
return callback(NamedSharding(mesh, pspec))
if len(devices) == 1:
# If we only have one device in our computation, we can construct a trivial
@ -562,8 +562,8 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]):
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental.pjit import PartitionSpec, pjit
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh, PartitionSpec
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> def f_(x):

View File

@ -389,7 +389,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
return OpShardingSharding(devices, op_sharding)
pspec = pjit.parse_flatten_op_sharding(op_sharding,
mesh)[0].get_partition_spec()
return pjit.NamedSharding(mesh, pspec)
return jax.sharding.NamedSharding(mesh, pspec)
sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
to_mesh_pspec_sharding,

View File

@ -17,9 +17,7 @@
from jax._src.pjit import (
AUTO as AUTO,
FROM_GDA as FROM_GDA,
NamedSharding as NamedSharding,
ParsedPartitionSpec as ParsedPartitionSpec,
PartitionSpec as PartitionSpec,
get_array_mapping as get_array_mapping,
hashable_pytree as hashable_pytree,
parse_flatten_op_sharding as parse_flatten_op_sharding,
@ -35,3 +33,35 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
_calc_is_global_sequence, _pjit_jaxpr,
_create_mesh_pspec_sharding_from_parsed_pspec,
_process_in_axis_resources)
from jax._src.pjit import (
NamedSharding as _deprecated_NamedSharding,
PartitionSpec as _deprecated_PartitionSpec,
)
import typing
if typing.TYPE_CHECKING:
from jax._src.pjit import (
NamedSharding as NamedSharding,
PartitionSpec as PartitionSpec,
)
del typing
_deprecations = {
# Added Feb 13, 2023:
"NamedSharding": (
("jax.experimental.pjit.NamedSharding is deprecated. Use "
"jax.sharding.NamedSharding."),
_deprecated_NamedSharding,
),
"PartitionSpec": (
("jax.experimental.pjit.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."),
_deprecated_PartitionSpec,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -21,8 +21,8 @@ import numpy as np
import jax
from jax import lax
from jax.config import config
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec as P
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax._src import core
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge