mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Deprecate newshape argument of jnp.reshape
This commit is contained in:
parent
1cb69716fe
commit
79005c1e69
@ -17,6 +17,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu``
|
||||
module. Use the ``compute_capability`` attribute of a GPU device, returned
|
||||
by {func}`jax.devices` or {func}`jax.local_devices`, instead.
|
||||
* The ``newshape`` argument to {func}`jax.numpy.reshape`is being deprecated
|
||||
and will soon be removed. Use `shape` instead.
|
||||
|
||||
* Changes
|
||||
* The minimum jaxlib version of this release is 0.4.27.
|
||||
|
@ -893,7 +893,9 @@ def isrealobj(x: Any) -> bool:
|
||||
return not iscomplexobj(x)
|
||||
|
||||
|
||||
def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
|
||||
def reshape(
|
||||
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
|
||||
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array:
|
||||
"""Return a reshaped copy of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.reshape`, implemented in terms of
|
||||
@ -901,7 +903,7 @@ def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
|
||||
|
||||
Args:
|
||||
a: input array to reshape
|
||||
newshape: integer or sequence of integers giving the new shape, which must match the
|
||||
shape: integer or sequence of integers giving the new shape, which must match the
|
||||
size of the input array. If any single dimension is given size ``-1``, it will be
|
||||
replaced with a value such that the output has the correct size.
|
||||
order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major
|
||||
@ -961,12 +963,31 @@ def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
|
||||
"""
|
||||
__tracebackhide__ = True
|
||||
util.check_arraylike("reshape", a)
|
||||
|
||||
# TODO(micky774): deprecated 2024-5-9, remove after deprecation expires.
|
||||
if not isinstance(newshape, DeprecatedArg):
|
||||
if shape is not None:
|
||||
raise ValueError(
|
||||
"jnp.reshape received both `shape` and `newshape` arguments. Note that "
|
||||
"using `newshape` is deprecated, please only use `shape` instead."
|
||||
)
|
||||
warnings.warn(
|
||||
"The newshape argument of jax.numpy.reshape is deprecated and setting it "
|
||||
"will soon raise an error. To avoid an error in the future, and to "
|
||||
"suppress this warning, please use the shape argument instead.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
shape = newshape
|
||||
del newshape
|
||||
elif shape is None:
|
||||
raise TypeError(
|
||||
"jnp.shape requires passing a `shape` argument, but none was given."
|
||||
)
|
||||
try:
|
||||
# forward to method for ndarrays
|
||||
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
|
||||
return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr]
|
||||
except AttributeError:
|
||||
pass
|
||||
return asarray(a).reshape(newshape, order=order)
|
||||
return asarray(a).reshape(shape, order=order)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=('order',), inline=True)
|
||||
|
@ -18,7 +18,7 @@ import jax
|
||||
from jax import Array
|
||||
|
||||
|
||||
# TODO(micky774): Deprecate newshape-->shape in for array API 2023.12
|
||||
# TODO(micky774): Implement copy
|
||||
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
|
||||
"""Reshapes an array without changing its data."""
|
||||
del copy # unused
|
||||
|
@ -690,7 +690,8 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = ..., *,
|
||||
total_repeat_length: Optional[int] = ...) -> Array: ...
|
||||
def reshape(
|
||||
a: ArrayLike, newshape: Union[DimSize, Shape], order: str = ...
|
||||
a: ArrayLike, shape: Union[DimSize, Shape] = ...,
|
||||
newshape: Union[DimSize, Shape] | None = ..., order: str = ...
|
||||
) -> Array: ...
|
||||
|
||||
def resize(a: ArrayLike, new_shape: Shape) -> Array: ...
|
||||
|
@ -1221,9 +1221,9 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
key = random.key(123)
|
||||
keys = random.split(key, 4)
|
||||
|
||||
newshape = (2, 2)
|
||||
key_func = partial(jnp.reshape, newshape=newshape)
|
||||
arr_func = partial(jnp.reshape, newshape=(*newshape, *key._impl.key_shape))
|
||||
shape = (2, 2)
|
||||
key_func = partial(jnp.reshape, shape=shape)
|
||||
arr_func = partial(jnp.reshape, shape=(*shape, *key._impl.key_shape))
|
||||
|
||||
self.check_shape(key_func, keys)
|
||||
self.check_against_reference(key_func, arr_func, keys)
|
||||
@ -1291,7 +1291,7 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
keys = random.split(key, 4).reshape(2, 2)
|
||||
|
||||
key_func = jnp.ravel
|
||||
arr_func = partial(jnp.reshape, newshape=(4, *key._impl.key_shape))
|
||||
arr_func = partial(jnp.reshape, shape=(4, *key._impl.key_shape))
|
||||
|
||||
self.check_shape(key_func, keys)
|
||||
self.check_against_reference(key_func, arr_func, keys)
|
||||
|
Loading…
x
Reference in New Issue
Block a user