Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.

Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
This commit is contained in:
Peter Hawkins 2023-02-15 08:13:58 -08:00 committed by jax authors
parent 69b8a03400
commit 00d45feee6
5 changed files with 45 additions and 9 deletions

View File

@ -30,6 +30,12 @@ Remember to align the itemized text with the first line of an item within a list
from Feb 13, 2023. from Feb 13, 2023.
* added the {mod}`jax.typing` module, with tools for type annotations of JAX * added the {mod}`jax.typing` module, with tools for type annotations of JAX
functions. functions.
* The following names have been deprecated:
* `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* Breaking Changes * Breaking Changes
* the `initial` argument to reduction functions like :func:`jax.numpy.sum` * the `initial` argument to reduction functions like :func:`jax.numpy.sum`
is now required to be a scalar, consistent with the corresponding NumPy API. is now required to be a scalar, consistent with the corresponding NumPy API.

View File

@ -42,7 +42,7 @@ from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding, OpShardingSharding from jax._src.sharding import Sharding, OpShardingSharding, NamedSharding
# pytype: disable=import-error # pytype: disable=import-error
try: try:
@ -311,7 +311,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
devices, op_sharding)) devices, op_sharding))
pspec = pjit.parse_flatten_op_sharding( pspec = pjit.parse_flatten_op_sharding(
op_sharding, mesh)[0].get_partition_spec() op_sharding, mesh)[0].get_partition_spec()
return callback(pjit.NamedSharding(mesh, pspec)) return callback(NamedSharding(mesh, pspec))
if len(devices) == 1: if len(devices) == 1:
# If we only have one device in our computation, we can construct a trivial # If we only have one device in our computation, we can construct a trivial
@ -562,8 +562,8 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]):
>>> import jax >>> import jax
>>> import jax.numpy as jnp >>> import jax.numpy as jnp
>>> from jax.experimental.maps import Mesh >>> from jax.experimental.pjit import pjit
>>> from jax.experimental.pjit import PartitionSpec, pjit >>> from jax.sharding import Mesh, PartitionSpec
>>> >>>
>>> x = jnp.arange(8, dtype=jnp.float32) >>> x = jnp.arange(8, dtype=jnp.float32)
>>> def f_(x): >>> def f_(x):

View File

@ -389,7 +389,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
return OpShardingSharding(devices, op_sharding) return OpShardingSharding(devices, op_sharding)
pspec = pjit.parse_flatten_op_sharding(op_sharding, pspec = pjit.parse_flatten_op_sharding(op_sharding,
mesh)[0].get_partition_spec() mesh)[0].get_partition_spec()
return pjit.NamedSharding(mesh, pspec) return jax.sharding.NamedSharding(mesh, pspec)
sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition, sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
to_mesh_pspec_sharding, to_mesh_pspec_sharding,

View File

@ -17,9 +17,7 @@
from jax._src.pjit import ( from jax._src.pjit import (
AUTO as AUTO, AUTO as AUTO,
FROM_GDA as FROM_GDA, FROM_GDA as FROM_GDA,
NamedSharding as NamedSharding,
ParsedPartitionSpec as ParsedPartitionSpec, ParsedPartitionSpec as ParsedPartitionSpec,
PartitionSpec as PartitionSpec,
get_array_mapping as get_array_mapping, get_array_mapping as get_array_mapping,
hashable_pytree as hashable_pytree, hashable_pytree as hashable_pytree,
parse_flatten_op_sharding as parse_flatten_op_sharding, parse_flatten_op_sharding as parse_flatten_op_sharding,
@ -35,3 +33,35 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
_calc_is_global_sequence, _pjit_jaxpr, _calc_is_global_sequence, _pjit_jaxpr,
_create_mesh_pspec_sharding_from_parsed_pspec, _create_mesh_pspec_sharding_from_parsed_pspec,
_process_in_axis_resources) _process_in_axis_resources)
from jax._src.pjit import (
NamedSharding as _deprecated_NamedSharding,
PartitionSpec as _deprecated_PartitionSpec,
)
import typing
if typing.TYPE_CHECKING:
from jax._src.pjit import (
NamedSharding as NamedSharding,
PartitionSpec as PartitionSpec,
)
del typing
_deprecations = {
# Added Feb 13, 2023:
"NamedSharding": (
("jax.experimental.pjit.NamedSharding is deprecated. Use "
"jax.sharding.NamedSharding."),
_deprecated_NamedSharding,
),
"PartitionSpec": (
("jax.experimental.pjit.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."),
_deprecated_PartitionSpec,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -21,8 +21,8 @@ import numpy as np
import jax import jax
from jax import lax from jax import lax
from jax.config import config from jax.config import config
from jax.experimental.maps import Mesh from jax.sharding import Mesh
from jax.experimental.pjit import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jax._src import core from jax._src import core
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lib import xla_bridge from jax._src.lib import xla_bridge