Remove several deprecated jax.Array methods:

- arr.broadcast
- arr.broadcast_in_dim
- arr.split

These have been deprecated since JAX v0.4.5

PiperOrigin-RevId: 547228974
This commit is contained in:
Jake VanderPlas 2023-07-11 10:26:33 -07:00 committed by jax authors
parent f81a48a819
commit b581ad1f33
2 changed files with 5 additions and 37 deletions

View File

@ -21,6 +21,11 @@ Remember to align the itemized text with the first line of an item within a list
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
no longer supported, after being deprecated in JAX version 0.4.7.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* The following `jax.Array` methods have been removed, after being deprecated
in JAX v0.4.5:
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
* Breaking changes
* JAX now requires ml_dtypes version 0.2.0 or newer.

View File

@ -24,7 +24,6 @@ __all__ = ['register_jax_array_methods']
import abc
from functools import partial, wraps
from typing import Any, Optional, Union
import warnings
import numpy as np
import jax
@ -301,36 +300,6 @@ def _compress_method(a: ArrayLike, condition: ArrayLike,
return lax_numpy.jaxcompress(condition, a, axis, out)
@util._wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.lax.broadcast` instead.
""")
def _deprecated_broadcast(*args, **kwargs):
warnings.warn(
"The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.",
category=FutureWarning)
return lax.broadcast(*args, **kwargs)
@util._wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead.
""")
def _deprecated_broadcast_in_dim(*args, **kwargs):
warnings.warn(
"The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.",
category=FutureWarning)
return lax.broadcast_in_dim(*args, **kwargs)
@util._wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.numpy.split` instead.
""")
def _deprecated_split(*args, **kwargs):
warnings.warn(
"The arr.split() method is deprecated. Use jax.numpy.split instead.",
category=FutureWarning)
return lax_numpy.split(*args, **kwargs)
@core.stash_axis_env()
@partial(jax.jit, static_argnums=(1,2,3))
def _multi_slice(arr: ArrayLike,
@ -717,12 +686,6 @@ _array_methods = {
# Methods exposed in order to avoid circular imports
"_split": lax_numpy.split, # used in jacfwd/jacrev
"_multi_slice": _multi_slice, # used in pxla for sharding
# Deprecated methods.
# TODO(jakevdp): remove these after June 2023
"broadcast": _deprecated_broadcast,
"broadcast_in_dim": _deprecated_broadcast_in_dim,
"split": _deprecated_split,
}
_impl_only_array_methods = {