mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove several deprecated jax.Array methods:
- arr.broadcast - arr.broadcast_in_dim - arr.split These have been deprecated since JAX v0.4.5 PiperOrigin-RevId: 547228974
This commit is contained in:
parent
f81a48a819
commit
b581ad1f33
@ -21,6 +21,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
|
||||
no longer supported, after being deprecated in JAX version 0.4.7.
|
||||
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||
* The following `jax.Array` methods have been removed, after being deprecated
|
||||
in JAX v0.4.5:
|
||||
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
|
||||
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
|
||||
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
|
||||
|
||||
* Breaking changes
|
||||
* JAX now requires ml_dtypes version 0.2.0 or newer.
|
||||
|
@ -24,7 +24,6 @@ __all__ = ['register_jax_array_methods']
|
||||
import abc
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Optional, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
@ -301,36 +300,6 @@ def _compress_method(a: ArrayLike, condition: ArrayLike,
|
||||
return lax_numpy.jaxcompress(condition, a, axis, out)
|
||||
|
||||
|
||||
@util._wraps(lax.broadcast, lax_description="""
|
||||
Deprecated. Use :func:`jax.lax.broadcast` instead.
|
||||
""")
|
||||
def _deprecated_broadcast(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.",
|
||||
category=FutureWarning)
|
||||
return lax.broadcast(*args, **kwargs)
|
||||
|
||||
|
||||
@util._wraps(lax.broadcast, lax_description="""
|
||||
Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead.
|
||||
""")
|
||||
def _deprecated_broadcast_in_dim(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.",
|
||||
category=FutureWarning)
|
||||
return lax.broadcast_in_dim(*args, **kwargs)
|
||||
|
||||
|
||||
@util._wraps(lax.broadcast, lax_description="""
|
||||
Deprecated. Use :func:`jax.numpy.split` instead.
|
||||
""")
|
||||
def _deprecated_split(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"The arr.split() method is deprecated. Use jax.numpy.split instead.",
|
||||
category=FutureWarning)
|
||||
return lax_numpy.split(*args, **kwargs)
|
||||
|
||||
|
||||
@core.stash_axis_env()
|
||||
@partial(jax.jit, static_argnums=(1,2,3))
|
||||
def _multi_slice(arr: ArrayLike,
|
||||
@ -717,12 +686,6 @@ _array_methods = {
|
||||
# Methods exposed in order to avoid circular imports
|
||||
"_split": lax_numpy.split, # used in jacfwd/jacrev
|
||||
"_multi_slice": _multi_slice, # used in pxla for sharding
|
||||
|
||||
# Deprecated methods.
|
||||
# TODO(jakevdp): remove these after June 2023
|
||||
"broadcast": _deprecated_broadcast,
|
||||
"broadcast_in_dim": _deprecated_broadcast_in_dim,
|
||||
"split": _deprecated_split,
|
||||
}
|
||||
|
||||
_impl_only_array_methods = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user