mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead. PiperOrigin-RevId: 509837529
This commit is contained in:
parent
69b8a03400
commit
00d45feee6
@ -30,6 +30,12 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
from Feb 13, 2023.
|
||||
* added the {mod}`jax.typing` module, with tools for type annotations of JAX
|
||||
functions.
|
||||
* The following names have been deprecated:
|
||||
* `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
|
||||
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
|
||||
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
|
||||
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
* Breaking Changes
|
||||
* the `initial` argument to reduction functions like :func:`jax.numpy.sum`
|
||||
is now required to be a scalar, consistent with the corresponding NumPy API.
|
||||
|
@ -42,7 +42,7 @@ from jax._src.lax import control_flow as lcf
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding import Sharding, OpShardingSharding
|
||||
from jax._src.sharding import Sharding, OpShardingSharding, NamedSharding
|
||||
|
||||
# pytype: disable=import-error
|
||||
try:
|
||||
@ -311,7 +311,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
|
||||
devices, op_sharding))
|
||||
pspec = pjit.parse_flatten_op_sharding(
|
||||
op_sharding, mesh)[0].get_partition_spec()
|
||||
return callback(pjit.NamedSharding(mesh, pspec))
|
||||
return callback(NamedSharding(mesh, pspec))
|
||||
|
||||
if len(devices) == 1:
|
||||
# If we only have one device in our computation, we can construct a trivial
|
||||
@ -562,8 +562,8 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]):
|
||||
|
||||
>>> import jax
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from jax.experimental.maps import Mesh
|
||||
>>> from jax.experimental.pjit import PartitionSpec, pjit
|
||||
>>> from jax.experimental.pjit import pjit
|
||||
>>> from jax.sharding import Mesh, PartitionSpec
|
||||
>>>
|
||||
>>> x = jnp.arange(8, dtype=jnp.float32)
|
||||
>>> def f_(x):
|
||||
|
@ -389,7 +389,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
return OpShardingSharding(devices, op_sharding)
|
||||
pspec = pjit.parse_flatten_op_sharding(op_sharding,
|
||||
mesh)[0].get_partition_spec()
|
||||
return pjit.NamedSharding(mesh, pspec)
|
||||
return jax.sharding.NamedSharding(mesh, pspec)
|
||||
|
||||
sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
|
||||
to_mesh_pspec_sharding,
|
||||
|
@ -17,9 +17,7 @@
|
||||
from jax._src.pjit import (
|
||||
AUTO as AUTO,
|
||||
FROM_GDA as FROM_GDA,
|
||||
NamedSharding as NamedSharding,
|
||||
ParsedPartitionSpec as ParsedPartitionSpec,
|
||||
PartitionSpec as PartitionSpec,
|
||||
get_array_mapping as get_array_mapping,
|
||||
hashable_pytree as hashable_pytree,
|
||||
parse_flatten_op_sharding as parse_flatten_op_sharding,
|
||||
@ -35,3 +33,35 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
|
||||
_calc_is_global_sequence, _pjit_jaxpr,
|
||||
_create_mesh_pspec_sharding_from_parsed_pspec,
|
||||
_process_in_axis_resources)
|
||||
|
||||
|
||||
from jax._src.pjit import (
|
||||
NamedSharding as _deprecated_NamedSharding,
|
||||
PartitionSpec as _deprecated_PartitionSpec,
|
||||
)
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.pjit import (
|
||||
NamedSharding as NamedSharding,
|
||||
PartitionSpec as PartitionSpec,
|
||||
)
|
||||
del typing
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 13, 2023:
|
||||
"NamedSharding": (
|
||||
("jax.experimental.pjit.NamedSharding is deprecated. Use "
|
||||
"jax.sharding.NamedSharding."),
|
||||
_deprecated_NamedSharding,
|
||||
),
|
||||
"PartitionSpec": (
|
||||
("jax.experimental.pjit.PartitionSpec is deprecated. Use "
|
||||
"jax.sharding.PartitionSpec."),
|
||||
_deprecated_PartitionSpec,
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
@ -21,8 +21,8 @@ import numpy as np
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.experimental.pjit import PartitionSpec as P
|
||||
from jax.sharding import Mesh
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_bridge
|
||||
|
Loading…
x
Reference in New Issue
Block a user