diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a61b1d67f..6a0e4059c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2143,20 +2143,11 @@ def reshape( __tracebackhide__ = True util.check_arraylike("reshape", a) - # TODO(micky774): deprecated 2024-5-9, remove after deprecation expires. + # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40. if not isinstance(newshape, DeprecatedArg): - if shape is not None: - raise ValueError( - "jnp.reshape received both `shape` and `newshape` arguments. Note that " - "using `newshape` is deprecated, please only use `shape` instead." - ) - deprecations.warn( - "jax-numpy-reshape-newshape", - ("The newshape argument of jax.numpy.reshape is deprecated. " - "Please use the shape argument instead."), stacklevel=2) - shape = newshape - del newshape - elif shape is None: + raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36." + " Use shape instead.") + if shape is None: raise TypeError( "jnp.shape requires passing a `shape` argument, but none was given." ) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ef80e368c..ef7faf9e3 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3428,13 +3428,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testReshapeDeprecatedArgs(self): - msg = "The newshape argument of jax.numpy.reshape is deprecated." - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-reshape-newshape"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors(msg): + msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36." + with self.assertRaisesRegex(TypeError, msg): jnp.reshape(jnp.arange(4), newshape=(2, 2)) @jtu.sample_product(