jnp.reshape: raise TypeError when specifying newshape

This commit is contained in:
Jake VanderPlas 2024-12-02 10:20:34 -08:00
parent 2e0474a55d
commit a7039a275e
2 changed files with 6 additions and 20 deletions

View File

@ -2143,20 +2143,11 @@ def reshape(
__tracebackhide__ = True
util.check_arraylike("reshape", a)
# TODO(micky774): deprecated 2024-5-9, remove after deprecation expires.
# TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40.
if not isinstance(newshape, DeprecatedArg):
if shape is not None:
raise ValueError(
"jnp.reshape received both `shape` and `newshape` arguments. Note that "
"using `newshape` is deprecated, please only use `shape` instead."
)
deprecations.warn(
"jax-numpy-reshape-newshape",
("The newshape argument of jax.numpy.reshape is deprecated. "
"Please use the shape argument instead."), stacklevel=2)
shape = newshape
del newshape
elif shape is None:
raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36."
" Use shape instead.")
if shape is None:
raise TypeError(
"jnp.shape requires passing a `shape` argument, but none was given."
)

View File

@ -3428,13 +3428,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testReshapeDeprecatedArgs(self):
msg = "The newshape argument of jax.numpy.reshape is deprecated."
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-reshape-newshape"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(msg):
msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36."
with self.assertRaisesRegex(TypeError, msg):
jnp.reshape(jnp.arange(4), newshape=(2, 2))
@jtu.sample_product(