Remove several deprecated jax.Array methods:

- arr.broadcast
- arr.broadcast_in_dim
- arr.split

These have been deprecated since JAX v0.4.5

PiperOrigin-RevId: 547228974
This commit is contained in:
Jake VanderPlas 2023-07-11 10:26:33 -07:00 committed by jax authors
parent f81a48a819
commit b581ad1f33
2 changed files with 5 additions and 37 deletions

View File

@ -21,6 +21,11 @@ Remember to align the itemized text with the first line of an item within a list
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
no longer supported, after being deprecated in JAX version 0.4.7. no longer supported, after being deprecated in JAX version 0.4.7.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* The following `jax.Array` methods have been removed, after being deprecated
in JAX v0.4.5:
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
* Breaking changes * Breaking changes
* JAX now requires ml_dtypes version 0.2.0 or newer. * JAX now requires ml_dtypes version 0.2.0 or newer.

View File

@ -24,7 +24,6 @@ __all__ = ['register_jax_array_methods']
import abc import abc
from functools import partial, wraps from functools import partial, wraps
from typing import Any, Optional, Union from typing import Any, Optional, Union
import warnings
import numpy as np import numpy as np
import jax import jax
@ -301,36 +300,6 @@ def _compress_method(a: ArrayLike, condition: ArrayLike,
return lax_numpy.jaxcompress(condition, a, axis, out) return lax_numpy.jaxcompress(condition, a, axis, out)
@util._wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.lax.broadcast` instead.
""")
def _deprecated_broadcast(*args, **kwargs):
warnings.warn(
"The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.",
category=FutureWarning)
return lax.broadcast(*args, **kwargs)
@util._wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead.
""")
def _deprecated_broadcast_in_dim(*args, **kwargs):
warnings.warn(
"The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.",
category=FutureWarning)
return lax.broadcast_in_dim(*args, **kwargs)
@util._wraps(lax.broadcast, lax_description="""
Deprecated. Use :func:`jax.numpy.split` instead.
""")
def _deprecated_split(*args, **kwargs):
warnings.warn(
"The arr.split() method is deprecated. Use jax.numpy.split instead.",
category=FutureWarning)
return lax_numpy.split(*args, **kwargs)
@core.stash_axis_env() @core.stash_axis_env()
@partial(jax.jit, static_argnums=(1,2,3)) @partial(jax.jit, static_argnums=(1,2,3))
def _multi_slice(arr: ArrayLike, def _multi_slice(arr: ArrayLike,
@ -717,12 +686,6 @@ _array_methods = {
# Methods exposed in order to avoid circular imports # Methods exposed in order to avoid circular imports
"_split": lax_numpy.split, # used in jacfwd/jacrev "_split": lax_numpy.split, # used in jacfwd/jacrev
"_multi_slice": _multi_slice, # used in pxla for sharding "_multi_slice": _multi_slice, # used in pxla for sharding
# Deprecated methods.
# TODO(jakevdp): remove these after June 2023
"broadcast": _deprecated_broadcast,
"broadcast_in_dim": _deprecated_broadcast_in_dim,
"split": _deprecated_split,
} }
_impl_only_array_methods = { _impl_only_array_methods = {