Deprecate jax.numpy.row_stack

This commit is contained in:
Jake VanderPlas 2023-08-22 13:12:49 -07:00
parent 3082109a59
commit 19a57e1a01
5 changed files with 10 additions and 5 deletions

View File

@ -41,6 +41,8 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
* `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
* `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.

View File

@ -340,7 +340,6 @@ namespace; they are listed below.
rot90
round
round_
row_stack
s_
save
savez

View File

@ -1868,7 +1868,6 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike],
else:
arrs = [atleast_2d(m) for m in tup]
return concatenate(arrs, axis=0, dtype=dtype)
row_stack = vstack
@util._wraps(np.hstack)

View File

@ -206,7 +206,6 @@ from jax._src.numpy.lax_numpy import (
rot90 as rot90,
round as round,
round_ as round_,
row_stack as row_stack,
save as save,
savez as savez,
searchsorted as searchsorted,
@ -464,7 +463,12 @@ _deprecations = {
"issubsctype": (
"jax.numpy.issubsctype is deprecated. In most cases, jax.numpy.issubdtype can be used instead.",
_numpy.core.numerictypes.issubsctype,
)
),
# Added Aug 22, 2023
"row_stack": (
"jax.numpy.row_stack is deprecated. Use jax.numpy.vstack instead.",
vstack,
),
}
import typing
@ -472,6 +476,7 @@ if typing.TYPE_CHECKING:
alltrue = all
cumproduct = cumprod
product = prod
row_stack = vstack
sometrue = any
NINF = -inf
NZERO = -0.0

View File

@ -318,7 +318,7 @@ roots: Any
rot90: Any
round: Any
round_: Any
row_stack: Any
row_stack: Any # TODO(jakevdp): remove this
s_: Any
save: Any
savez: Any