mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
PiperOrigin-RevId: 519781715
This commit is contained in:
parent
3c3fa042e3
commit
e21aee18a8
@ -35,6 +35,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||||
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
|
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
|
||||||
* `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
|
* `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
|
||||||
|
* `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded
|
||||||
|
jax.Arrays as input and remove the `in_shardings` argument to pjit since
|
||||||
|
it is optional.
|
||||||
|
|
||||||
## jaxlib 0.4.7
|
## jaxlib 0.4.7
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ from jax._src import sharding_impls
|
|||||||
from jax._src.interpreters import pxla
|
from jax._src.interpreters import pxla
|
||||||
from jax.interpreters import xla
|
from jax.interpreters import xla
|
||||||
from jax._src import pjit as pjit_lib
|
from jax._src import pjit as pjit_lib
|
||||||
from jax.experimental.pjit import pjit, FROM_GDA
|
from jax.experimental.pjit import pjit
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
from jax._src import distributed
|
from jax._src import distributed
|
||||||
from jax._src import config as config_internal
|
from jax._src import config as config_internal
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
from jax._src.pjit import (
|
from jax._src.pjit import (
|
||||||
AUTO as AUTO,
|
AUTO as AUTO,
|
||||||
FROM_GDA as FROM_GDA,
|
|
||||||
ParsedPartitionSpec as ParsedPartitionSpec,
|
ParsedPartitionSpec as ParsedPartitionSpec,
|
||||||
get_array_mapping as get_array_mapping,
|
get_array_mapping as get_array_mapping,
|
||||||
hashable_pytree as hashable_pytree,
|
hashable_pytree as hashable_pytree,
|
||||||
@ -38,6 +37,7 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
|
|||||||
from jax._src.pjit import (
|
from jax._src.pjit import (
|
||||||
NamedSharding as _deprecated_NamedSharding,
|
NamedSharding as _deprecated_NamedSharding,
|
||||||
PartitionSpec as _deprecated_PartitionSpec,
|
PartitionSpec as _deprecated_PartitionSpec,
|
||||||
|
FROM_GDA as _deprecated_FROM_GDA,
|
||||||
)
|
)
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
@ -45,21 +45,34 @@ if typing.TYPE_CHECKING:
|
|||||||
from jax._src.pjit import (
|
from jax._src.pjit import (
|
||||||
NamedSharding as NamedSharding,
|
NamedSharding as NamedSharding,
|
||||||
PartitionSpec as PartitionSpec,
|
PartitionSpec as PartitionSpec,
|
||||||
|
FROM_GDA as FROM_GDA,
|
||||||
)
|
)
|
||||||
del typing
|
del typing
|
||||||
|
|
||||||
_deprecations = {
|
_deprecations = {
|
||||||
# Added Feb 13, 2023:
|
# Added Feb 13, 2023:
|
||||||
"NamedSharding": (
|
"NamedSharding": (
|
||||||
("jax.experimental.pjit.NamedSharding is deprecated. Use "
|
(
|
||||||
"jax.sharding.NamedSharding."),
|
"jax.experimental.pjit.NamedSharding is deprecated. Use "
|
||||||
_deprecated_NamedSharding,
|
"jax.sharding.NamedSharding."
|
||||||
),
|
),
|
||||||
"PartitionSpec": (
|
_deprecated_NamedSharding,
|
||||||
("jax.experimental.pjit.PartitionSpec is deprecated. Use "
|
),
|
||||||
"jax.sharding.PartitionSpec."),
|
"PartitionSpec": (
|
||||||
_deprecated_PartitionSpec,
|
(
|
||||||
),
|
"jax.experimental.pjit.PartitionSpec is deprecated. Use "
|
||||||
|
"jax.sharding.PartitionSpec."
|
||||||
|
),
|
||||||
|
_deprecated_PartitionSpec,
|
||||||
|
),
|
||||||
|
"FROM_GDA": (
|
||||||
|
(
|
||||||
|
"jax.experimental.pjit.FROM_GDA is deprecated. Please pass in"
|
||||||
|
" sharded jax.Arrays as input and remove the in_shardings argument"
|
||||||
|
" to pjit since it is optional."
|
||||||
|
),
|
||||||
|
_deprecated_FROM_GDA,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||||
|
@ -420,9 +420,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
|
with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
|
||||||
f = pjit.pjit(
|
f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
|
||||||
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
|
|
||||||
)
|
|
||||||
out = f(gda1)
|
out = f(gda1)
|
||||||
for s in out.addressable_shards:
|
for s in out.addressable_shards:
|
||||||
device_id = s.device.id
|
device_id = s.device.id
|
||||||
@ -471,9 +469,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
with global_mesh:
|
with global_mesh:
|
||||||
f = pjit.pjit(
|
f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
|
||||||
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
|
|
||||||
)
|
|
||||||
out = f(gda1)
|
out = f(gda1)
|
||||||
|
|
||||||
for s in out.addressable_shards:
|
for s in out.addressable_shards:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user