diff --git a/CHANGELOG.md b/CHANGELOG.md index ac82cff3d..d6692ff69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,12 @@ Remember to align the itemized text with the first line of an item within a list from Feb 13, 2023. * added the {mod}`jax.typing` module, with tools for type annotations of JAX functions. + * The following names have been deprecated: + * `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`. + * `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`. + * `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`. + * `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`. + * `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`. * Breaking Changes * the `initial` argument to reduction functions like :func:`jax.numpy.sum` is now required to be a scalar, consistent with the corresponding NumPy API. diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index df9ec4109..494e43d31 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -42,7 +42,7 @@ from jax._src.lax import control_flow as lcf from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.sharding import Sharding, OpShardingSharding +from jax._src.sharding import Sharding, OpShardingSharding, NamedSharding # pytype: disable=import-error try: @@ -311,7 +311,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, devices, op_sharding)) pspec = pjit.parse_flatten_op_sharding( op_sharding, mesh)[0].get_partition_spec() - return callback(pjit.NamedSharding(mesh, pspec)) + return callback(NamedSharding(mesh, pspec)) if len(devices) == 1: # If we only have one device in our computation, we can construct a trivial @@ -562,8 +562,8 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]): >>> import jax >>> import jax.numpy as jnp - >>> from jax.experimental.maps import Mesh - >>> from jax.experimental.pjit import PartitionSpec, pjit + >>> from jax.experimental.pjit import pjit + >>> from jax.sharding import Mesh, PartitionSpec >>> >>> x = jnp.arange(8, dtype=jnp.float32) >>> def f_(x): diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 029952fd4..dddaecf53 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -389,7 +389,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, return OpShardingSharding(devices, op_sharding) pspec = pjit.parse_flatten_op_sharding(op_sharding, mesh)[0].get_partition_spec() - return pjit.NamedSharding(mesh, pspec) + return jax.sharding.NamedSharding(mesh, pspec) sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition, to_mesh_pspec_sharding, diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 0289a161b..6565c70b0 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -17,9 +17,7 @@ from jax._src.pjit import ( AUTO as AUTO, FROM_GDA as FROM_GDA, - NamedSharding as NamedSharding, ParsedPartitionSpec as ParsedPartitionSpec, - PartitionSpec as PartitionSpec, get_array_mapping as get_array_mapping, hashable_pytree as hashable_pytree, parse_flatten_op_sharding as parse_flatten_op_sharding, @@ -35,3 +33,35 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources, _calc_is_global_sequence, _pjit_jaxpr, _create_mesh_pspec_sharding_from_parsed_pspec, _process_in_axis_resources) + + +from jax._src.pjit import ( + NamedSharding as _deprecated_NamedSharding, + PartitionSpec as _deprecated_PartitionSpec, +) + +import typing +if typing.TYPE_CHECKING: + from jax._src.pjit import ( + NamedSharding as NamedSharding, + PartitionSpec as PartitionSpec, + ) +del typing + +_deprecations = { + # Added Feb 13, 2023: + "NamedSharding": ( + ("jax.experimental.pjit.NamedSharding is deprecated. Use " + "jax.sharding.NamedSharding."), + _deprecated_NamedSharding, + ), + "PartitionSpec": ( + ("jax.experimental.pjit.PartitionSpec is deprecated. Use " + "jax.sharding.PartitionSpec."), + _deprecated_PartitionSpec, + ), +} + +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index b1e05fee5..eaccdf629 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -21,8 +21,8 @@ import numpy as np import jax from jax import lax from jax.config import config -from jax.experimental.maps import Mesh -from jax.experimental.pjit import PartitionSpec as P +from jax.sharding import Mesh +from jax.sharding import PartitionSpec as P from jax._src import core from jax._src import test_util as jtu from jax._src.lib import xla_bridge