mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead. PiperOrigin-RevId: 509837529
This commit is contained in:
parent
69b8a03400
commit
00d45feee6
@ -30,6 +30,12 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
from Feb 13, 2023.
|
from Feb 13, 2023.
|
||||||
* added the {mod}`jax.typing` module, with tools for type annotations of JAX
|
* added the {mod}`jax.typing` module, with tools for type annotations of JAX
|
||||||
functions.
|
functions.
|
||||||
|
* The following names have been deprecated:
|
||||||
|
* `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
|
||||||
|
* `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
|
||||||
|
* `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||||
|
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
|
||||||
|
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||||
* Breaking Changes
|
* Breaking Changes
|
||||||
* the `initial` argument to reduction functions like :func:`jax.numpy.sum`
|
* the `initial` argument to reduction functions like :func:`jax.numpy.sum`
|
||||||
is now required to be a scalar, consistent with the corresponding NumPy API.
|
is now required to be a scalar, consistent with the corresponding NumPy API.
|
||||||
|
@ -42,7 +42,7 @@ from jax._src.lax import control_flow as lcf
|
|||||||
from jax._src.lib import xla_client as xc
|
from jax._src.lib import xla_client as xc
|
||||||
from jax._src.lib.mlir import ir
|
from jax._src.lib.mlir import ir
|
||||||
from jax._src.lib.mlir.dialects import hlo
|
from jax._src.lib.mlir.dialects import hlo
|
||||||
from jax._src.sharding import Sharding, OpShardingSharding
|
from jax._src.sharding import Sharding, OpShardingSharding, NamedSharding
|
||||||
|
|
||||||
# pytype: disable=import-error
|
# pytype: disable=import-error
|
||||||
try:
|
try:
|
||||||
@ -311,7 +311,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
|
|||||||
devices, op_sharding))
|
devices, op_sharding))
|
||||||
pspec = pjit.parse_flatten_op_sharding(
|
pspec = pjit.parse_flatten_op_sharding(
|
||||||
op_sharding, mesh)[0].get_partition_spec()
|
op_sharding, mesh)[0].get_partition_spec()
|
||||||
return callback(pjit.NamedSharding(mesh, pspec))
|
return callback(NamedSharding(mesh, pspec))
|
||||||
|
|
||||||
if len(devices) == 1:
|
if len(devices) == 1:
|
||||||
# If we only have one device in our computation, we can construct a trivial
|
# If we only have one device in our computation, we can construct a trivial
|
||||||
@ -562,8 +562,8 @@ def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]):
|
|||||||
|
|
||||||
>>> import jax
|
>>> import jax
|
||||||
>>> import jax.numpy as jnp
|
>>> import jax.numpy as jnp
|
||||||
>>> from jax.experimental.maps import Mesh
|
>>> from jax.experimental.pjit import pjit
|
||||||
>>> from jax.experimental.pjit import PartitionSpec, pjit
|
>>> from jax.sharding import Mesh, PartitionSpec
|
||||||
>>>
|
>>>
|
||||||
>>> x = jnp.arange(8, dtype=jnp.float32)
|
>>> x = jnp.arange(8, dtype=jnp.float32)
|
||||||
>>> def f_(x):
|
>>> def f_(x):
|
||||||
|
@ -389,7 +389,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
|||||||
return OpShardingSharding(devices, op_sharding)
|
return OpShardingSharding(devices, op_sharding)
|
||||||
pspec = pjit.parse_flatten_op_sharding(op_sharding,
|
pspec = pjit.parse_flatten_op_sharding(op_sharding,
|
||||||
mesh)[0].get_partition_spec()
|
mesh)[0].get_partition_spec()
|
||||||
return pjit.NamedSharding(mesh, pspec)
|
return jax.sharding.NamedSharding(mesh, pspec)
|
||||||
|
|
||||||
sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
|
sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
|
||||||
to_mesh_pspec_sharding,
|
to_mesh_pspec_sharding,
|
||||||
|
@ -17,9 +17,7 @@
|
|||||||
from jax._src.pjit import (
|
from jax._src.pjit import (
|
||||||
AUTO as AUTO,
|
AUTO as AUTO,
|
||||||
FROM_GDA as FROM_GDA,
|
FROM_GDA as FROM_GDA,
|
||||||
NamedSharding as NamedSharding,
|
|
||||||
ParsedPartitionSpec as ParsedPartitionSpec,
|
ParsedPartitionSpec as ParsedPartitionSpec,
|
||||||
PartitionSpec as PartitionSpec,
|
|
||||||
get_array_mapping as get_array_mapping,
|
get_array_mapping as get_array_mapping,
|
||||||
hashable_pytree as hashable_pytree,
|
hashable_pytree as hashable_pytree,
|
||||||
parse_flatten_op_sharding as parse_flatten_op_sharding,
|
parse_flatten_op_sharding as parse_flatten_op_sharding,
|
||||||
@ -35,3 +33,35 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
|
|||||||
_calc_is_global_sequence, _pjit_jaxpr,
|
_calc_is_global_sequence, _pjit_jaxpr,
|
||||||
_create_mesh_pspec_sharding_from_parsed_pspec,
|
_create_mesh_pspec_sharding_from_parsed_pspec,
|
||||||
_process_in_axis_resources)
|
_process_in_axis_resources)
|
||||||
|
|
||||||
|
|
||||||
|
from jax._src.pjit import (
|
||||||
|
NamedSharding as _deprecated_NamedSharding,
|
||||||
|
PartitionSpec as _deprecated_PartitionSpec,
|
||||||
|
)
|
||||||
|
|
||||||
|
import typing
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from jax._src.pjit import (
|
||||||
|
NamedSharding as NamedSharding,
|
||||||
|
PartitionSpec as PartitionSpec,
|
||||||
|
)
|
||||||
|
del typing
|
||||||
|
|
||||||
|
_deprecations = {
|
||||||
|
# Added Feb 13, 2023:
|
||||||
|
"NamedSharding": (
|
||||||
|
("jax.experimental.pjit.NamedSharding is deprecated. Use "
|
||||||
|
"jax.sharding.NamedSharding."),
|
||||||
|
_deprecated_NamedSharding,
|
||||||
|
),
|
||||||
|
"PartitionSpec": (
|
||||||
|
("jax.experimental.pjit.PartitionSpec is deprecated. Use "
|
||||||
|
"jax.sharding.PartitionSpec."),
|
||||||
|
_deprecated_PartitionSpec,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||||
|
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||||
|
del _deprecation_getattr
|
||||||
|
@ -21,8 +21,8 @@ import numpy as np
|
|||||||
import jax
|
import jax
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.config import config
|
from jax.config import config
|
||||||
from jax.experimental.maps import Mesh
|
from jax.sharding import Mesh
|
||||||
from jax.experimental.pjit import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src.lib import xla_bridge
|
from jax._src.lib import xla_bridge
|
||||||
|
Loading…
x
Reference in New Issue
Block a user