Deprecate three jax.Array methods:

- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
This commit is contained in:
Jake VanderPlas 2023-02-23 16:15:09 -08:00
parent 5a8c12db9f
commit a283aa0cc3
4 changed files with 46 additions and 12 deletions

View File

@ -11,6 +11,11 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
`jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
* The following `jax.Array` methods are deprecated and will be removed 3 months from
Feb 23 2023:
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
## jaxlib 0.4.5

View File

@ -1592,7 +1592,7 @@ def _split(x, indices, axis):
if isinstance(x, np.ndarray):
return np.split(x, indices, axis)
else:
return x.split(indices, axis)
return x._split(indices, axis)
def vmap(fun: F,

View File

@ -117,9 +117,6 @@ class Array(abc.ABC):
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
def astype(self, dtype) -> Array: ...
def broadcast(self, sizes: Sequence[int]) -> Array: ...
def broadcast_in_dim(self, shape: Sequence[Union[int, Any]],
broadcast_dimensions: Sequence[int]) -> Array: ...
def choose(self, choices, out=None, mode='raise') -> Array: ...
def clip(self, min=None, max=None, out=None) -> Array: ...
def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
@ -158,7 +155,6 @@ class Array(abc.ABC):
def round(self, decimals=0, out=None) -> Array: ...
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
def split(self, indices_or_sections: ArrayLike, axis: int = 0) -> List[Array]: ...
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...

View File

@ -5127,6 +5127,36 @@ def _compress_method(a: ArrayLike, condition: ArrayLike,
return compress(condition, a, axis, out)
@_wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.lax.broadcast` instead.
""")
def _deprecated_broadcast(*args, **kwargs):
warnings.warn(
"The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.",
category=FutureWarning)
return lax.broadcast(*args, **kwargs)
@_wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead.
""")
def _deprecated_broadcast_in_dim(*args, **kwargs):
warnings.warn(
"The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.",
category=FutureWarning)
return lax.broadcast_in_dim(*args, **kwargs)
@_wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.numpy.split` instead.
""")
def _deprecated_split(*args, **kwargs):
warnings.warn(
"The arr.split() method is deprecated. Use jax.numpy.split instead.",
category=FutureWarning)
return split(*args, **kwargs)
@core.stash_axis_env()
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr: ArrayLike,
@ -5487,6 +5517,7 @@ _array_methods = {
"clip": _clip,
"conj": conj,
"conjugate": conjugate,
"compress": _compress_method,
"copy": copy,
"cumprod": cumprod,
"cumsum": cumsum,
@ -5516,13 +5547,15 @@ _array_methods = {
"var": var,
"view": _view,
# Extra methods handy for specializing dispatch
# TODO(jakevdp): find another mechanism for exposing these.
"broadcast": lax.broadcast,
"broadcast_in_dim": lax.broadcast_in_dim,
"split": split,
"compress": _compress_method,
"_multi_slice": _multi_slice,
# Methods exposed in order to avoid circular imports
"_split": split, # used in jacfwd/jacrev
"_multi_slice": _multi_slice, # used in pxla for sharding
# Deprecated methods.
# TODO(jakevdp): remove these after June 2023
"broadcast": _deprecated_broadcast,
"broadcast_in_dim": _deprecated_broadcast_in_dim,
"split": _deprecated_split,
}
_array_properties = {