Merge pull request #18553 from mattjj:ones-error-message

PiperOrigin-RevId: 582890009
This commit is contained in:
jax authors 2023-11-15 20:11:57 -08:00
commit 95de3d03b9
2 changed files with 22 additions and 1 deletions

View File

@ -2227,6 +2227,7 @@ def full_like(a: ArrayLike | DuckTypedArray,
def zeros(shape: Any, dtype: DTypeLike | None = None) -> Array:
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
dtypes.check_user_dtype_supported(dtype, "zeros")
shape = canonicalize_shape(shape)
return lax.full(shape, 0, _jnp_dtype(dtype))
@ -2235,18 +2236,27 @@ def zeros(shape: Any, dtype: DTypeLike | None = None) -> Array:
def ones(shape: Any, dtype: DTypeLike | None = None) -> Array:
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
shape = canonicalize_shape(shape)
dtypes.check_user_dtype_supported(dtype, "ones")
return lax.full(shape, 1, _jnp_dtype(dtype))
@util._wraps(np.empty, lax_description="""\
Because XLA cannot create uninitialized arrays, the JAX version will
return an array initialized with zeros.""")
def empty(shape: Any, dtype: DTypeLike | None = None) -> Array:
if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m)
dtypes.check_user_dtype_supported(dtype, "empty")
return zeros(shape, dtype)
def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
if isinstance(dtype, int) and isinstance(shape, int):
return (f"Cannot interpret '{dtype}' as a data type."
f"\n\nDid you accidentally write "
f"`jax.numpy.{name}({shape}, {dtype})` "
f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. "
"with a single tuple argument for the shape?")
@util._wraps(np.array_equal)
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:

View File

@ -5216,6 +5216,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
"but got first argument of"):
jnp.rot90(jnp.ones(2))
@parameterized.named_parameters(
('ones', jnp.ones),
('zeros', jnp.zeros),
('empty', jnp.empty))
def test_error_hint(self, fn):
with self.assertRaisesRegex(
TypeError,
r"Did you accidentally write `jax\.numpy\..*?\(2, 3\)` "
r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"):
fn(2, 3)
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.