Add deprecation warning for FROM_GDA usage since that argument is not required anymore.

PiperOrigin-RevId: 519781715
This commit is contained in:
Yash Katariya 2023-03-27 11:32:30 -07:00 committed by jax authors
parent 3c3fa042e3
commit e21aee18a8
4 changed files with 31 additions and 19 deletions

View File

@ -35,6 +35,9 @@ Remember to align the itemized text with the first line of an item within a list
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`. * `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
* `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`. * `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
* `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded
jax.Arrays as input and remove the `in_shardings` argument to pjit since
it is optional.
## jaxlib 0.4.7 ## jaxlib 0.4.7

View File

@ -28,7 +28,7 @@ from jax._src import sharding_impls
from jax._src.interpreters import pxla from jax._src.interpreters import pxla
from jax.interpreters import xla from jax.interpreters import xla
from jax._src import pjit as pjit_lib from jax._src import pjit as pjit_lib
from jax.experimental.pjit import pjit, FROM_GDA from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jax._src import distributed from jax._src import distributed
from jax._src import config as config_internal from jax._src import config as config_internal

View File

@ -16,7 +16,6 @@
from jax._src.pjit import ( from jax._src.pjit import (
AUTO as AUTO, AUTO as AUTO,
FROM_GDA as FROM_GDA,
ParsedPartitionSpec as ParsedPartitionSpec, ParsedPartitionSpec as ParsedPartitionSpec,
get_array_mapping as get_array_mapping, get_array_mapping as get_array_mapping,
hashable_pytree as hashable_pytree, hashable_pytree as hashable_pytree,
@ -38,6 +37,7 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
from jax._src.pjit import ( from jax._src.pjit import (
NamedSharding as _deprecated_NamedSharding, NamedSharding as _deprecated_NamedSharding,
PartitionSpec as _deprecated_PartitionSpec, PartitionSpec as _deprecated_PartitionSpec,
FROM_GDA as _deprecated_FROM_GDA,
) )
import typing import typing
@ -45,21 +45,34 @@ if typing.TYPE_CHECKING:
from jax._src.pjit import ( from jax._src.pjit import (
NamedSharding as NamedSharding, NamedSharding as NamedSharding,
PartitionSpec as PartitionSpec, PartitionSpec as PartitionSpec,
FROM_GDA as FROM_GDA,
) )
del typing del typing
_deprecations = { _deprecations = {
# Added Feb 13, 2023: # Added Feb 13, 2023:
"NamedSharding": ( "NamedSharding": (
("jax.experimental.pjit.NamedSharding is deprecated. Use " (
"jax.sharding.NamedSharding."), "jax.experimental.pjit.NamedSharding is deprecated. Use "
_deprecated_NamedSharding, "jax.sharding.NamedSharding."
), ),
"PartitionSpec": ( _deprecated_NamedSharding,
("jax.experimental.pjit.PartitionSpec is deprecated. Use " ),
"jax.sharding.PartitionSpec."), "PartitionSpec": (
_deprecated_PartitionSpec, (
), "jax.experimental.pjit.PartitionSpec is deprecated. Use "
"jax.sharding.PartitionSpec."
),
_deprecated_PartitionSpec,
),
"FROM_GDA": (
(
"jax.experimental.pjit.FROM_GDA is deprecated. Please pass in"
" sharded jax.Arrays as input and remove the in_shardings argument"
" to pjit since it is optional."
),
_deprecated_FROM_GDA,
),
} }
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr from jax._src.deprecations import deprecation_getattr as _deprecation_getattr

View File

@ -420,9 +420,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
} }
with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names): with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
f = pjit.pjit( f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
)
out = f(gda1) out = f(gda1)
for s in out.addressable_shards: for s in out.addressable_shards:
device_id = s.device.id device_id = s.device.id
@ -471,9 +469,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
} }
with global_mesh: with global_mesh:
f = pjit.pjit( f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
lambda x: x, in_shardings=pjit.FROM_GDA, out_shardings=mesh_axes
)
out = f(gda1) out = f(gda1)
for s in out.addressable_shards: for s in out.addressable_shards: