mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
PiperOrigin-RevId: 519781715
This commit is contained in:
parent
3c3fa042e3
commit
e21aee18a8
@ -35,6 +35,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
|
||||
* `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
|
||||
* `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded
|
||||
jax.Arrays as input and remove the `in_shardings` argument to pjit since
|
||||
it is optional.
|
||||
|
||||
## jaxlib 0.4.7
|
||||
|
||||
|
@ -28,7 +28,7 @@ from jax._src import sharding_impls
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax._src import pjit as pjit_lib
|
||||
from jax.experimental.pjit import pjit, FROM_GDA
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import distributed
|
||||
from jax._src import config as config_internal
|
||||
|
@ -16,7 +16,6 @@
|
||||
|
||||
from jax._src.pjit import (
|
||||
AUTO as AUTO,
|
||||
FROM_GDA as FROM_GDA,
|
||||
ParsedPartitionSpec as ParsedPartitionSpec,
|
||||
get_array_mapping as get_array_mapping,
|
||||
hashable_pytree as hashable_pytree,
|
||||
@ -38,6 +37,7 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
|
||||
from jax._src.pjit import (
|
||||
NamedSharding as _deprecated_NamedSharding,
|
||||
PartitionSpec as _deprecated_PartitionSpec,
|
||||
FROM_GDA as _deprecated_FROM_GDA,
|
||||
)
|
||||
|
||||
import typing
|
||||
@ -45,21 +45,34 @@ if typing.TYPE_CHECKING:
|
||||
from jax._src.pjit import (
|
||||
NamedSharding as NamedSharding,
|
||||
PartitionSpec as PartitionSpec,
|
||||
FROM_GDA as FROM_GDA,
|
||||
)
|
||||
del typing
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 13, 2023:
|
||||
"NamedSharding": (
|
||||
("jax.experimental.pjit.NamedSharding is deprecated. Use "
|
||||
"jax.sharding.NamedSharding."),
|
||||
_deprecated_NamedSharding,
|
||||
),
|
||||
"PartitionSpec": (
|
||||
("jax.experimental.pjit.PartitionSpec is deprecated. Use "
|
||||
"jax.sharding.PartitionSpec."),
|
||||
_deprecated_PartitionSpec,
|
||||
),
|
||||
# Added Feb 13, 2023:
|
||||
"NamedSharding": (
|
||||
(
|
||||
"jax.experimental.pjit.NamedSharding is deprecated. Use "
|
||||
"jax.sharding.NamedSharding."
|
||||
),
|
||||
_deprecated_NamedSharding,
|
||||
),
|
||||
"PartitionSpec": (
|
||||
(
|
||||
"jax.experimental.pjit.PartitionSpec is deprecated. Use "
|
||||
"jax.sharding.PartitionSpec."
|
||||
),
|
||||
_deprecated_PartitionSpec,
|
||||
),
|
||||
"FROM_GDA": (
|
||||
(
|
||||
"jax.experimental.pjit.FROM_GDA is deprecated. Please pass in"
|
||||
" sharded jax.Arrays as input and remove the in_shardings argument"
|
||||
" to pjit since it is optional."
|
||||
),
|
||||
_deprecated_FROM_GDA,
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
|
@ -420,9 +420,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
}
|
||||
|
||||
with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
|
||||
f = pjit.pjit(
|
||||
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
|
||||
)
|
||||
f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
|
||||
out = f(gda1)
|
||||
for s in out.addressable_shards:
|
||||
device_id = s.device.id
|
||||
@ -471,9 +469,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
}
|
||||
|
||||
with global_mesh:
|
||||
f = pjit.pjit(
|
||||
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
|
||||
)
|
||||
f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
|
||||
out = f(gda1)
|
||||
|
||||
for s in out.addressable_shards:
|
||||
|
Loading…
x
Reference in New Issue
Block a user