mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove several deprecated jax.Array methods:
- arr.broadcast - arr.broadcast_in_dim - arr.split These have been deprecated since JAX v0.4.5 PiperOrigin-RevId: 547228974
This commit is contained in:
parent
f81a48a819
commit
b581ad1f33
@ -21,6 +21,11 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
|
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
|
||||||
no longer supported, after being deprecated in JAX version 0.4.7.
|
no longer supported, after being deprecated in JAX version 0.4.7.
|
||||||
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||||
|
* The following `jax.Array` methods have been removed, after being deprecated
|
||||||
|
in JAX v0.4.5:
|
||||||
|
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
|
||||||
|
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
|
||||||
|
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
|
||||||
|
|
||||||
* Breaking changes
|
* Breaking changes
|
||||||
* JAX now requires ml_dtypes version 0.2.0 or newer.
|
* JAX now requires ml_dtypes version 0.2.0 or newer.
|
||||||
|
@ -24,7 +24,6 @@ __all__ = ['register_jax_array_methods']
|
|||||||
import abc
|
import abc
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import jax
|
import jax
|
||||||
@ -301,36 +300,6 @@ def _compress_method(a: ArrayLike, condition: ArrayLike,
|
|||||||
return lax_numpy.jaxcompress(condition, a, axis, out)
|
return lax_numpy.jaxcompress(condition, a, axis, out)
|
||||||
|
|
||||||
|
|
||||||
@util._wraps(lax.broadcast, lax_description="""
|
|
||||||
Deprecated. Use :func:`jax.lax.broadcast` instead.
|
|
||||||
""")
|
|
||||||
def _deprecated_broadcast(*args, **kwargs):
|
|
||||||
warnings.warn(
|
|
||||||
"The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.",
|
|
||||||
category=FutureWarning)
|
|
||||||
return lax.broadcast(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@util._wraps(lax.broadcast, lax_description="""
|
|
||||||
Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead.
|
|
||||||
""")
|
|
||||||
def _deprecated_broadcast_in_dim(*args, **kwargs):
|
|
||||||
warnings.warn(
|
|
||||||
"The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.",
|
|
||||||
category=FutureWarning)
|
|
||||||
return lax.broadcast_in_dim(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@util._wraps(lax.broadcast, lax_description="""
|
|
||||||
Deprecated. Use :func:`jax.numpy.split` instead.
|
|
||||||
""")
|
|
||||||
def _deprecated_split(*args, **kwargs):
|
|
||||||
warnings.warn(
|
|
||||||
"The arr.split() method is deprecated. Use jax.numpy.split instead.",
|
|
||||||
category=FutureWarning)
|
|
||||||
return lax_numpy.split(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@core.stash_axis_env()
|
@core.stash_axis_env()
|
||||||
@partial(jax.jit, static_argnums=(1,2,3))
|
@partial(jax.jit, static_argnums=(1,2,3))
|
||||||
def _multi_slice(arr: ArrayLike,
|
def _multi_slice(arr: ArrayLike,
|
||||||
@ -717,12 +686,6 @@ _array_methods = {
|
|||||||
# Methods exposed in order to avoid circular imports
|
# Methods exposed in order to avoid circular imports
|
||||||
"_split": lax_numpy.split, # used in jacfwd/jacrev
|
"_split": lax_numpy.split, # used in jacfwd/jacrev
|
||||||
"_multi_slice": _multi_slice, # used in pxla for sharding
|
"_multi_slice": _multi_slice, # used in pxla for sharding
|
||||||
|
|
||||||
# Deprecated methods.
|
|
||||||
# TODO(jakevdp): remove these after June 2023
|
|
||||||
"broadcast": _deprecated_broadcast,
|
|
||||||
"broadcast_in_dim": _deprecated_broadcast_in_dim,
|
|
||||||
"split": _deprecated_split,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_impl_only_array_methods = {
|
_impl_only_array_methods = {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user