Add deprecation warning for FROM_GDA usage since that argument is not required anymore.

PiperOrigin-RevId: 519781715
This commit is contained in:
Yash Katariya 2023-03-27 11:32:30 -07:00 committed by jax authors
parent 3c3fa042e3
commit e21aee18a8
4 changed files with 31 additions and 19 deletions

View File

@ -35,6 +35,9 @@ Remember to align the itemized text with the first line of an item within a list
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
* `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
* `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded
jax.Arrays as input and remove the `in_shardings` argument to pjit since
it is optional.
## jaxlib 0.4.7

View File

@ -28,7 +28,7 @@ from jax._src import sharding_impls
from jax._src.interpreters import pxla
from jax.interpreters import xla
from jax._src import pjit as pjit_lib
from jax.experimental.pjit import pjit, FROM_GDA
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax._src import distributed
from jax._src import config as config_internal

View File

@ -16,7 +16,6 @@
from jax._src.pjit import (
AUTO as AUTO,
FROM_GDA as FROM_GDA,
ParsedPartitionSpec as ParsedPartitionSpec,
get_array_mapping as get_array_mapping,
hashable_pytree as hashable_pytree,
@ -38,6 +37,7 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
from jax._src.pjit import (
NamedSharding as _deprecated_NamedSharding,
PartitionSpec as _deprecated_PartitionSpec,
FROM_GDA as _deprecated_FROM_GDA,
)
import typing
@ -45,21 +45,34 @@ if typing.TYPE_CHECKING:
from jax._src.pjit import (
NamedSharding as NamedSharding,
PartitionSpec as PartitionSpec,
FROM_GDA as FROM_GDA,
)
del typing
_deprecations = {
# Added Feb 13, 2023:
"NamedSharding": (
("jax.experimental.pjit.NamedSharding is deprecated. Use "
"jax.sharding.NamedSharding."),
(
"jax.experimental.pjit.NamedSharding is deprecated. Use "
"jax.sharding.NamedSharding."
),
_deprecated_NamedSharding,
),
"PartitionSpec": (
("jax.experimental.pjit.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."),
(
"jax.experimental.pjit.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."
),
_deprecated_PartitionSpec,
),
"FROM_GDA": (
(
"jax.experimental.pjit.FROM_GDA is deprecated. Please pass in"
" sharded jax.Arrays as input and remove the in_shardings argument"
" to pjit since it is optional."
),
_deprecated_FROM_GDA,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr

View File

@ -420,9 +420,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
}
with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
f = pjit.pjit(
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
)
f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
out = f(gda1)
for s in out.addressable_shards:
device_id = s.device.id
@ -471,9 +469,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
}
with global_mesh:
f = pjit.pjit(
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
)
f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
out = f(gda1)
for s in out.addressable_shards: