diff --git a/CHANGELOG.md b/CHANGELOG.md index b347c37c1..e7e561586 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`. `jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023. + * The following `jax.Array` methods are deprecated and will be removed 3 months from + Feb 23 2023: + * `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead. + * `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead. + * `jax.Array.split`: use {func}`jax.numpy.split` instead. ## jaxlib 0.4.5 diff --git a/jax/_src/api.py b/jax/_src/api.py index 27872654c..6da1101de 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1592,7 +1592,7 @@ def _split(x, indices, axis): if isinstance(x, np.ndarray): return np.split(x, indices, axis) else: - return x.split(indices, axis) + return x._split(indices, axis) def vmap(fun: F, diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index cc0ba7137..0a401a0b0 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -117,9 +117,6 @@ class Array(abc.ABC): def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ... def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ... def astype(self, dtype) -> Array: ... - def broadcast(self, sizes: Sequence[int]) -> Array: ... - def broadcast_in_dim(self, shape: Sequence[Union[int, Any]], - broadcast_dimensions: Sequence[int]) -> Array: ... def choose(self, choices, out=None, mode='raise') -> Array: ... def clip(self, min=None, max=None, out=None) -> Array: ... def compress(self, condition, axis: Optional[int] = None, out=None) -> Array: ... @@ -158,7 +155,6 @@ class Array(abc.ABC): def round(self, decimals=0, out=None) -> Array: ... def searchsorted(self, v, side='left', sorter=None) -> Array: ... def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Array: ... - def split(self, indices_or_sections: ArrayLike, axis: int = 0) -> List[Array]: ... def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: ... def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f2cffd8f6..fac69a0b3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5127,6 +5127,36 @@ def _compress_method(a: ArrayLike, condition: ArrayLike, return compress(condition, a, axis, out) +@_wraps(lax.broadcast, lax_description=""" +Deprecated. Use :func:`jax.lax.broadcast` instead. +""") +def _deprecated_broadcast(*args, **kwargs): + warnings.warn( + "The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.", + category=FutureWarning) + return lax.broadcast(*args, **kwargs) + + +@_wraps(lax.broadcast, lax_description=""" +Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead. +""") +def _deprecated_broadcast_in_dim(*args, **kwargs): + warnings.warn( + "The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.", + category=FutureWarning) + return lax.broadcast_in_dim(*args, **kwargs) + + +@_wraps(lax.broadcast, lax_description=""" +Deprecated. Use :func:`jax.numpy.split` instead. +""") +def _deprecated_split(*args, **kwargs): + warnings.warn( + "The arr.split() method is deprecated. Use jax.numpy.split instead.", + category=FutureWarning) + return split(*args, **kwargs) + + @core.stash_axis_env() @partial(jit, static_argnums=(1,2,3)) def _multi_slice(arr: ArrayLike, @@ -5487,6 +5517,7 @@ _array_methods = { "clip": _clip, "conj": conj, "conjugate": conjugate, + "compress": _compress_method, "copy": copy, "cumprod": cumprod, "cumsum": cumsum, @@ -5516,13 +5547,15 @@ _array_methods = { "var": var, "view": _view, - # Extra methods handy for specializing dispatch - # TODO(jakevdp): find another mechanism for exposing these. - "broadcast": lax.broadcast, - "broadcast_in_dim": lax.broadcast_in_dim, - "split": split, - "compress": _compress_method, - "_multi_slice": _multi_slice, + # Methods exposed in order to avoid circular imports + "_split": split, # used in jacfwd/jacrev + "_multi_slice": _multi_slice, # used in pxla for sharding + + # Deprecated methods. + # TODO(jakevdp): remove these after June 2023 + "broadcast": _deprecated_broadcast, + "broadcast_in_dim": _deprecated_broadcast_in_dim, + "split": _deprecated_split, } _array_properties = {