mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Deprecate three jax.Array methods:
- jax.Array.broadcast: use lax.broadcast instead - jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead - jax.Array.split: use jnp.split instead These are removed because they are not part of the np.ndarray API.
This commit is contained in:
parent
5a8c12db9f
commit
a283aa0cc3
@ -11,6 +11,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Deprecations
|
||||
* `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
|
||||
`jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
|
||||
* The following `jax.Array` methods are deprecated and will be removed 3 months from
|
||||
Feb 23 2023:
|
||||
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
|
||||
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
|
||||
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
|
||||
|
||||
## jaxlib 0.4.5
|
||||
|
||||
|
@ -1592,7 +1592,7 @@ def _split(x, indices, axis):
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.split(x, indices, axis)
|
||||
else:
|
||||
return x.split(indices, axis)
|
||||
return x._split(indices, axis)
|
||||
|
||||
|
||||
def vmap(fun: F,
|
||||
|
@ -117,9 +117,6 @@ class Array(abc.ABC):
|
||||
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
|
||||
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
||||
def astype(self, dtype) -> Array: ...
|
||||
def broadcast(self, sizes: Sequence[int]) -> Array: ...
|
||||
def broadcast_in_dim(self, shape: Sequence[Union[int, Any]],
|
||||
broadcast_dimensions: Sequence[int]) -> Array: ...
|
||||
def choose(self, choices, out=None, mode='raise') -> Array: ...
|
||||
def clip(self, min=None, max=None, out=None) -> Array: ...
|
||||
def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
|
||||
@ -158,7 +155,6 @@ class Array(abc.ABC):
|
||||
def round(self, decimals=0, out=None) -> Array: ...
|
||||
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
|
||||
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
||||
def split(self, indices_or_sections: ArrayLike, axis: int = 0) -> List[Array]: ...
|
||||
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
|
||||
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
|
||||
|
@ -5127,6 +5127,36 @@ def _compress_method(a: ArrayLike, condition: ArrayLike,
|
||||
return compress(condition, a, axis, out)
|
||||
|
||||
|
||||
@_wraps(lax.broadcast, lax_description="""
|
||||
Deprecated. Use :func:`jax.lax.broadcast` instead.
|
||||
""")
|
||||
def _deprecated_broadcast(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.",
|
||||
category=FutureWarning)
|
||||
return lax.broadcast(*args, **kwargs)
|
||||
|
||||
|
||||
@_wraps(lax.broadcast, lax_description="""
|
||||
Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead.
|
||||
""")
|
||||
def _deprecated_broadcast_in_dim(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.",
|
||||
category=FutureWarning)
|
||||
return lax.broadcast_in_dim(*args, **kwargs)
|
||||
|
||||
|
||||
@_wraps(lax.broadcast, lax_description="""
|
||||
Deprecated. Use :func:`jax.numpy.split` instead.
|
||||
""")
|
||||
def _deprecated_split(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"The arr.split() method is deprecated. Use jax.numpy.split instead.",
|
||||
category=FutureWarning)
|
||||
return split(*args, **kwargs)
|
||||
|
||||
|
||||
@core.stash_axis_env()
|
||||
@partial(jit, static_argnums=(1,2,3))
|
||||
def _multi_slice(arr: ArrayLike,
|
||||
@ -5487,6 +5517,7 @@ _array_methods = {
|
||||
"clip": _clip,
|
||||
"conj": conj,
|
||||
"conjugate": conjugate,
|
||||
"compress": _compress_method,
|
||||
"copy": copy,
|
||||
"cumprod": cumprod,
|
||||
"cumsum": cumsum,
|
||||
@ -5516,13 +5547,15 @@ _array_methods = {
|
||||
"var": var,
|
||||
"view": _view,
|
||||
|
||||
# Extra methods handy for specializing dispatch
|
||||
# TODO(jakevdp): find another mechanism for exposing these.
|
||||
"broadcast": lax.broadcast,
|
||||
"broadcast_in_dim": lax.broadcast_in_dim,
|
||||
"split": split,
|
||||
"compress": _compress_method,
|
||||
"_multi_slice": _multi_slice,
|
||||
# Methods exposed in order to avoid circular imports
|
||||
"_split": split, # used in jacfwd/jacrev
|
||||
"_multi_slice": _multi_slice, # used in pxla for sharding
|
||||
|
||||
# Deprecated methods.
|
||||
# TODO(jakevdp): remove these after June 2023
|
||||
"broadcast": _deprecated_broadcast,
|
||||
"broadcast_in_dim": _deprecated_broadcast_in_dim,
|
||||
"split": _deprecated_split,
|
||||
}
|
||||
|
||||
_array_properties = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user