mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate three jax.Array methods:
- jax.Array.broadcast: use lax.broadcast instead - jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead - jax.Array.split: use jnp.split instead These are removed because they are not part of the np.ndarray API.
This commit is contained in:
parent
5a8c12db9f
commit
a283aa0cc3
@ -11,6 +11,11 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Deprecations
|
* Deprecations
|
||||||
* `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
|
* `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
|
||||||
`jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
|
`jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
|
||||||
|
* The following `jax.Array` methods are deprecated and will be removed 3 months from
|
||||||
|
Feb 23 2023:
|
||||||
|
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
|
||||||
|
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
|
||||||
|
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
|
||||||
|
|
||||||
## jaxlib 0.4.5
|
## jaxlib 0.4.5
|
||||||
|
|
||||||
|
@ -1592,7 +1592,7 @@ def _split(x, indices, axis):
|
|||||||
if isinstance(x, np.ndarray):
|
if isinstance(x, np.ndarray):
|
||||||
return np.split(x, indices, axis)
|
return np.split(x, indices, axis)
|
||||||
else:
|
else:
|
||||||
return x.split(indices, axis)
|
return x._split(indices, axis)
|
||||||
|
|
||||||
|
|
||||||
def vmap(fun: F,
|
def vmap(fun: F,
|
||||||
|
@ -117,9 +117,6 @@ class Array(abc.ABC):
|
|||||||
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
|
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
|
||||||
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
||||||
def astype(self, dtype) -> Array: ...
|
def astype(self, dtype) -> Array: ...
|
||||||
def broadcast(self, sizes: Sequence[int]) -> Array: ...
|
|
||||||
def broadcast_in_dim(self, shape: Sequence[Union[int, Any]],
|
|
||||||
broadcast_dimensions: Sequence[int]) -> Array: ...
|
|
||||||
def choose(self, choices, out=None, mode='raise') -> Array: ...
|
def choose(self, choices, out=None, mode='raise') -> Array: ...
|
||||||
def clip(self, min=None, max=None, out=None) -> Array: ...
|
def clip(self, min=None, max=None, out=None) -> Array: ...
|
||||||
def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
|
def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ...
|
||||||
@ -158,7 +155,6 @@ class Array(abc.ABC):
|
|||||||
def round(self, decimals=0, out=None) -> Array: ...
|
def round(self, decimals=0, out=None) -> Array: ...
|
||||||
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
|
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
|
||||||
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ...
|
||||||
def split(self, indices_or_sections: ArrayLike, axis: int = 0) -> List[Array]: ...
|
|
||||||
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
|
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ...
|
||||||
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||||
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
|
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
|
||||||
|
@ -5127,6 +5127,36 @@ def _compress_method(a: ArrayLike, condition: ArrayLike,
|
|||||||
return compress(condition, a, axis, out)
|
return compress(condition, a, axis, out)
|
||||||
|
|
||||||
|
|
||||||
|
@_wraps(lax.broadcast, lax_description="""
|
||||||
|
Deprecated. Use :func:`jax.lax.broadcast` instead.
|
||||||
|
""")
|
||||||
|
def _deprecated_broadcast(*args, **kwargs):
|
||||||
|
warnings.warn(
|
||||||
|
"The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.",
|
||||||
|
category=FutureWarning)
|
||||||
|
return lax.broadcast(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@_wraps(lax.broadcast, lax_description="""
|
||||||
|
Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead.
|
||||||
|
""")
|
||||||
|
def _deprecated_broadcast_in_dim(*args, **kwargs):
|
||||||
|
warnings.warn(
|
||||||
|
"The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.",
|
||||||
|
category=FutureWarning)
|
||||||
|
return lax.broadcast_in_dim(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@_wraps(lax.broadcast, lax_description="""
|
||||||
|
Deprecated. Use :func:`jax.numpy.split` instead.
|
||||||
|
""")
|
||||||
|
def _deprecated_split(*args, **kwargs):
|
||||||
|
warnings.warn(
|
||||||
|
"The arr.split() method is deprecated. Use jax.numpy.split instead.",
|
||||||
|
category=FutureWarning)
|
||||||
|
return split(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@core.stash_axis_env()
|
@core.stash_axis_env()
|
||||||
@partial(jit, static_argnums=(1,2,3))
|
@partial(jit, static_argnums=(1,2,3))
|
||||||
def _multi_slice(arr: ArrayLike,
|
def _multi_slice(arr: ArrayLike,
|
||||||
@ -5487,6 +5517,7 @@ _array_methods = {
|
|||||||
"clip": _clip,
|
"clip": _clip,
|
||||||
"conj": conj,
|
"conj": conj,
|
||||||
"conjugate": conjugate,
|
"conjugate": conjugate,
|
||||||
|
"compress": _compress_method,
|
||||||
"copy": copy,
|
"copy": copy,
|
||||||
"cumprod": cumprod,
|
"cumprod": cumprod,
|
||||||
"cumsum": cumsum,
|
"cumsum": cumsum,
|
||||||
@ -5516,13 +5547,15 @@ _array_methods = {
|
|||||||
"var": var,
|
"var": var,
|
||||||
"view": _view,
|
"view": _view,
|
||||||
|
|
||||||
# Extra methods handy for specializing dispatch
|
# Methods exposed in order to avoid circular imports
|
||||||
# TODO(jakevdp): find another mechanism for exposing these.
|
"_split": split, # used in jacfwd/jacrev
|
||||||
"broadcast": lax.broadcast,
|
"_multi_slice": _multi_slice, # used in pxla for sharding
|
||||||
"broadcast_in_dim": lax.broadcast_in_dim,
|
|
||||||
"split": split,
|
# Deprecated methods.
|
||||||
"compress": _compress_method,
|
# TODO(jakevdp): remove these after June 2023
|
||||||
"_multi_slice": _multi_slice,
|
"broadcast": _deprecated_broadcast,
|
||||||
|
"broadcast_in_dim": _deprecated_broadcast_in_dim,
|
||||||
|
"split": _deprecated_split,
|
||||||
}
|
}
|
||||||
|
|
||||||
_array_properties = {
|
_array_properties = {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user