mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jax.interpreters.pxla: remove deprecated functions:
- jax.interpreters.pxla.device_put - jax.interpreters.pxla.make_sharded_device_array
This commit is contained in:
parent
14f32653a1
commit
3f47ad367d
@ -13,6 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
https://jax.readthedocs.io/en/latest/deprecation.html
|
||||
* JAX now requires NumPy 1.22 or newer as per
|
||||
https://jax.readthedocs.io/en/latest/deprecation.html
|
||||
* `jax.interpreters.pxla.device_put` has been removed. This was deprecated in
|
||||
JAX version 0.4.6: use `jax.device_put` instead.
|
||||
* `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was
|
||||
deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays`
|
||||
instead.
|
||||
|
||||
## jaxlib 0.4.14
|
||||
|
||||
|
@ -42,7 +42,6 @@ from jax._src.interpreters.pxla import (
|
||||
_pmap_sharding_spec as _pmap_sharding_spec,
|
||||
array_types as array_types,
|
||||
custom_resource_typing_rules as custom_resource_typing_rules,
|
||||
device_put as _deprecated_device_put,
|
||||
find_replicas as find_replicas,
|
||||
full_to_shard_p as full_to_shard_p,
|
||||
global_aval_to_result_handler as global_aval_to_result_handler,
|
||||
@ -116,7 +115,6 @@ from jax._src.sharding_specs import (
|
||||
|
||||
from jax._src.interpreters.pxla import (
|
||||
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
|
||||
make_sharded_device_array as _deprecated_make_sharded_device_array,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
@ -128,30 +126,12 @@ _deprecations = {
|
||||
),
|
||||
_deprecated_ShardedDeviceArray,
|
||||
),
|
||||
# make_sharded_device_array is deprecated as of March 3, 2023. jax.Array
|
||||
# is the default since November 2022.
|
||||
"make_sharded_device_array": (
|
||||
(
|
||||
"jax.interpreters.pxla.make_sharded_device_array is deprecated as"
|
||||
" of March 3, 2023. Use jax.make_array_from_single_device_arrays."
|
||||
),
|
||||
_deprecated_make_sharded_device_array,
|
||||
),
|
||||
"device_put": (
|
||||
(
|
||||
"jax.interpreters.pxla.device_put is deprecated. Please use"
|
||||
" jax.device_put."
|
||||
),
|
||||
_deprecated_device_put,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.interpreters.pxla import (
|
||||
ShardedDeviceArray as ShardedDeviceArray,
|
||||
device_put as device_put,
|
||||
make_sharded_device_array as make_sharded_device_array,
|
||||
)
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
|
Loading…
x
Reference in New Issue
Block a user