Merge pull request #16159 from jakevdp:deprecations

PiperOrigin-RevId: 536003451
This commit is contained in:
jax authors 2023-05-28 07:27:45 -07:00
commit ae9160a4e9
2 changed files with 4 additions and 42 deletions

View File

@ -19,8 +19,11 @@ Remember to align the itemized text with the first line of an item within a list
as input and remove the optional `in_shardings` argument to `pjit`.
* `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
* `jax.interpreters.xla.Buffer`: use `jax.Array`.
* `jax.interpreters.xla.Device`: use `jax.Device`.
* `jax.interpreters.xla.DeviceArray`: use `jax.Array` instead
* `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
* `jax.interpreters.xla.device_put`: use `jax.device_put`.
* `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
* `axis_resources` argument of `with_sharding_constraint` is removed. Please
use `shardings` instead.

View File

@ -13,7 +13,6 @@
# limitations under the License.
from jax._src.interpreters.xla import (
DeviceArray as _deprecated_DeviceArray,
TranslationContext as TranslationContext,
TranslationRule as TranslationRule,
abstractify as abstractify,
@ -29,7 +28,6 @@ from jax._src.interpreters.xla import (
register_translation as register_translation,
sharding_to_proto as sharding_to_proto,
translations as translations,
xla_call_p as _deprecated_xla_call_p,
xla_destructure as xla_destructure,
xla_shape_handlers as xla_shape_handlers,
)
@ -52,45 +50,6 @@ from jax._src.sharding_impls import (
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc # type: ignore
from jax._src.api import device_put as _deprecated_device_put
_deprecated_Device = xc.Device
XlaOp = xc.XlaOp
xe = xc._xla
Backend = xe.Client
Buffer = _deprecated_DeviceArray
_deprecations = {
# Added Feb 9, 2023:
"Buffer": (
(
"jax.interpreters.xla.Buffer is deprecated. Use jax.Array"
" instead."
),
_deprecated_DeviceArray,
),
"device_put": (
(
"jax.interpreters.xla.device_put is deprecated. Please use"
" jax.device_put instead."
),
_deprecated_device_put,
),
"xla_call_p": (
(
"jax.interpreters.xla.xla_call_p is deprecated. Please use"
" jax.experimental.pjit.pjit_p instead."
),
_deprecated_xla_call_p,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
import typing
if typing.TYPE_CHECKING:
from jax._src.api import device_put as device_put
from jax._src.interpreters.xla import xla_call_p as xla_call_p
del typing