mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16159 from jakevdp:deprecations
PiperOrigin-RevId: 536003451
This commit is contained in:
commit
ae9160a4e9
@ -19,8 +19,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
as input and remove the optional `in_shardings` argument to `pjit`.
|
||||
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
|
||||
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
|
||||
* `jax.interpreters.xla.Buffer`: use `jax.Array`.
|
||||
* `jax.interpreters.xla.Device`: use `jax.Device`.
|
||||
* `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead
|
||||
* `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
|
||||
* `jax.interpreters.xla.device_put`: use `jax.device_put`.
|
||||
* `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
|
||||
* `axis_resources` argument of `with_sharding_constraint` is removed. Please
|
||||
use `shardings` instead.
|
||||
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.interpreters.xla import (
|
||||
DeviceArray as _deprecated_DeviceArray,
|
||||
TranslationContext as TranslationContext,
|
||||
TranslationRule as TranslationRule,
|
||||
abstractify as abstractify,
|
||||
@ -29,7 +28,6 @@ from jax._src.interpreters.xla import (
|
||||
register_translation as register_translation,
|
||||
sharding_to_proto as sharding_to_proto,
|
||||
translations as translations,
|
||||
xla_call_p as _deprecated_xla_call_p,
|
||||
xla_destructure as xla_destructure,
|
||||
xla_shape_handlers as xla_shape_handlers,
|
||||
)
|
||||
@ -52,45 +50,6 @@ from jax._src.sharding_impls import (
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc # type: ignore
|
||||
|
||||
from jax._src.api import device_put as _deprecated_device_put
|
||||
|
||||
_deprecated_Device = xc.Device
|
||||
XlaOp = xc.XlaOp
|
||||
xe = xc._xla
|
||||
Backend = xe.Client
|
||||
Buffer = _deprecated_DeviceArray
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 9, 2023:
|
||||
"Buffer": (
|
||||
(
|
||||
"jax.interpreters.xla.Buffer is deprecated. Use jax.Array"
|
||||
" instead."
|
||||
),
|
||||
_deprecated_DeviceArray,
|
||||
),
|
||||
"device_put": (
|
||||
(
|
||||
"jax.interpreters.xla.device_put is deprecated. Please use"
|
||||
" jax.device_put instead."
|
||||
),
|
||||
_deprecated_device_put,
|
||||
),
|
||||
"xla_call_p": (
|
||||
(
|
||||
"jax.interpreters.xla.xla_call_p is deprecated. Please use"
|
||||
" jax.experimental.pjit.pjit_p instead."
|
||||
),
|
||||
_deprecated_xla_call_p,
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.api import device_put as device_put
|
||||
from jax._src.interpreters.xla import xla_call_p as xla_call_p
|
||||
del typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user