Deprecate newshape argument of jnp.reshape

This commit is contained in:
Meekail Zain 2024-05-09 21:02:07 +00:00
parent 1cb69716fe
commit 79005c1e69
5 changed files with 34 additions and 10 deletions

View File

@ -17,6 +17,8 @@ Remember to align the itemized text with the first line of an item within a list
* Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu``
module. Use the ``compute_capability`` attribute of a GPU device, returned
by {func}`jax.devices` or {func}`jax.local_devices`, instead.
* The ``newshape`` argument to {func}`jax.numpy.reshape`is being deprecated
and will soon be removed. Use `shape` instead.
* Changes
* The minimum jaxlib version of this release is 0.4.27.

View File

@ -893,7 +893,9 @@ def isrealobj(x: Any) -> bool:
return not iscomplexobj(x)
def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
def reshape(
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array:
"""Return a reshaped copy of an array.
JAX implementation of :func:`numpy.reshape`, implemented in terms of
@ -901,7 +903,7 @@ def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
Args:
a: input array to reshape
newshape: integer or sequence of integers giving the new shape, which must match the
shape: integer or sequence of integers giving the new shape, which must match the
size of the input array. If any single dimension is given size ``-1``, it will be
replaced with a value such that the output has the correct size.
order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major
@ -961,12 +963,31 @@ def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
"""
__tracebackhide__ = True
util.check_arraylike("reshape", a)
# TODO(micky774): deprecated 2024-5-9, remove after deprecation expires.
if not isinstance(newshape, DeprecatedArg):
if shape is not None:
raise ValueError(
"jnp.reshape received both `shape` and `newshape` arguments. Note that "
"using `newshape` is deprecated, please only use `shape` instead."
)
warnings.warn(
"The newshape argument of jax.numpy.reshape is deprecated and setting it "
"will soon raise an error. To avoid an error in the future, and to "
"suppress this warning, please use the shape argument instead.",
DeprecationWarning, stacklevel=2)
shape = newshape
del newshape
elif shape is None:
raise TypeError(
"jnp.shape requires passing a `shape` argument, but none was given."
)
try:
# forward to method for ndarrays
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr]
except AttributeError:
pass
return asarray(a).reshape(newshape, order=order)
return asarray(a).reshape(shape, order=order)
@partial(jit, static_argnames=('order',), inline=True)

View File

@ -18,7 +18,7 @@ import jax
from jax import Array
# TODO(micky774): Deprecate newshape-->shape in for array API 2023.12
# TODO(micky774): Implement copy
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused

View File

@ -690,7 +690,8 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = ..., *,
total_repeat_length: Optional[int] = ...) -> Array: ...
def reshape(
a: ArrayLike, newshape: Union[DimSize, Shape], order: str = ...
a: ArrayLike, shape: Union[DimSize, Shape] = ...,
newshape: Union[DimSize, Shape] | None = ..., order: str = ...
) -> Array: ...
def resize(a: ArrayLike, new_shape: Shape) -> Array: ...

View File

@ -1221,9 +1221,9 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
key = random.key(123)
keys = random.split(key, 4)
newshape = (2, 2)
key_func = partial(jnp.reshape, newshape=newshape)
arr_func = partial(jnp.reshape, newshape=(*newshape, *key._impl.key_shape))
shape = (2, 2)
key_func = partial(jnp.reshape, shape=shape)
arr_func = partial(jnp.reshape, shape=(*shape, *key._impl.key_shape))
self.check_shape(key_func, keys)
self.check_against_reference(key_func, arr_func, keys)
@ -1291,7 +1291,7 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
keys = random.split(key, 4).reshape(2, 2)
key_func = jnp.ravel
arr_func = partial(jnp.reshape, newshape=(4, *key._impl.key_shape))
arr_func = partial(jnp.reshape, shape=(4, *key._impl.key_shape))
self.check_shape(key_func, keys)
self.check_against_reference(key_func, arr_func, keys)