jax.interpreters.pxla: remove deprecated functions:

- jax.interpreters.pxla.device_put
- jax.interpreters.pxla.make_sharded_device_array
This commit is contained in:
Jake VanderPlas 2023-06-23 00:28:49 -07:00
parent 14f32653a1
commit 3f47ad367d
2 changed files with 5 additions and 20 deletions

View File

@ -13,6 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
https://jax.readthedocs.io/en/latest/deprecation.html
* JAX now requires NumPy 1.22 or newer as per
https://jax.readthedocs.io/en/latest/deprecation.html
* `jax.interpreters.pxla.device_put` has been removed. This was deprecated in
JAX version 0.4.6: use `jax.device_put` instead.
* `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was
deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays`
instead.
## jaxlib 0.4.14

View File

@ -42,7 +42,6 @@ from jax._src.interpreters.pxla import (
_pmap_sharding_spec as _pmap_sharding_spec,
array_types as array_types,
custom_resource_typing_rules as custom_resource_typing_rules,
device_put as _deprecated_device_put,
find_replicas as find_replicas,
full_to_shard_p as full_to_shard_p,
global_aval_to_result_handler as global_aval_to_result_handler,
@ -116,7 +115,6 @@ from jax._src.sharding_specs import (
from jax._src.interpreters.pxla import (
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
make_sharded_device_array as _deprecated_make_sharded_device_array,
)
_deprecations = {
@ -128,30 +126,12 @@ _deprecations = {
),
_deprecated_ShardedDeviceArray,
),
# make_sharded_device_array is deprecated as of March 3, 2023. jax.Array
# is the default since November 2022.
"make_sharded_device_array": (
(
"jax.interpreters.pxla.make_sharded_device_array is deprecated as"
" of March 3, 2023. Use jax.make_array_from_single_device_arrays."
),
_deprecated_make_sharded_device_array,
),
"device_put": (
(
"jax.interpreters.pxla.device_put is deprecated. Please use"
" jax.device_put."
),
_deprecated_device_put,
),
}
import typing
if typing.TYPE_CHECKING:
from jax._src.interpreters.pxla import (
ShardedDeviceArray as ShardedDeviceArray,
device_put as device_put,
make_sharded_device_array as make_sharded_device_array,
)
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr