From b581ad1f33e4a7e63838e542a2ee1c50ee648072 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 11 Jul 2023 10:26:33 -0700 Subject: [PATCH] Remove several deprecated jax.Array methods: - arr.broadcast - arr.broadcast_in_dim - arr.split These have been deprecated since JAX v0.4.5 PiperOrigin-RevId: 547228974 --- CHANGELOG.md | 5 +++++ jax/_src/numpy/array_methods.py | 37 --------------------------------- 2 files changed, 5 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a2262ba5..1220be0bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,11 @@ Remember to align the itemized text with the first line of an item within a list * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` + * The following `jax.Array` methods have been removed, after being deprecated + in JAX v0.4.5: + * `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead. + * `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead. + * `jax.Array.split`: use {func}`jax.numpy.split` instead. * Breaking changes * JAX now requires ml_dtypes version 0.2.0 or newer. diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index cff7e53c2..ae0abce8a 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -24,7 +24,6 @@ __all__ = ['register_jax_array_methods'] import abc from functools import partial, wraps from typing import Any, Optional, Union -import warnings import numpy as np import jax @@ -301,36 +300,6 @@ def _compress_method(a: ArrayLike, condition: ArrayLike, return lax_numpy.jaxcompress(condition, a, axis, out) -@util._wraps(lax.broadcast, lax_description=""" -Deprecated. Use :func:`jax.lax.broadcast` instead. -""") -def _deprecated_broadcast(*args, **kwargs): - warnings.warn( - "The arr.broadcast() method is deprecated. Use jax.lax.broadcast instead.", - category=FutureWarning) - return lax.broadcast(*args, **kwargs) - - -@util._wraps(lax.broadcast, lax_description=""" -Deprecated. Use :func:`jax.lax.broadcast_in_dim` instead. -""") -def _deprecated_broadcast_in_dim(*args, **kwargs): - warnings.warn( - "The arr.broadcast_in_dim() method is deprecated. Use jax.lax.broadcast_in_dim instead.", - category=FutureWarning) - return lax.broadcast_in_dim(*args, **kwargs) - - -@util._wraps(lax.broadcast, lax_description=""" -Deprecated. Use :func:`jax.numpy.split` instead. -""") -def _deprecated_split(*args, **kwargs): - warnings.warn( - "The arr.split() method is deprecated. Use jax.numpy.split instead.", - category=FutureWarning) - return lax_numpy.split(*args, **kwargs) - - @core.stash_axis_env() @partial(jax.jit, static_argnums=(1,2,3)) def _multi_slice(arr: ArrayLike, @@ -717,12 +686,6 @@ _array_methods = { # Methods exposed in order to avoid circular imports "_split": lax_numpy.split, # used in jacfwd/jacrev "_multi_slice": _multi_slice, # used in pxla for sharding - - # Deprecated methods. - # TODO(jakevdp): remove these after June 2023 - "broadcast": _deprecated_broadcast, - "broadcast_in_dim": _deprecated_broadcast_in_dim, - "split": _deprecated_split, } _impl_only_array_methods = {