mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18553 from mattjj:ones-error-message
PiperOrigin-RevId: 582890009
This commit is contained in:
commit
95de3d03b9
@ -2227,6 +2227,7 @@ def full_like(a: ArrayLike | DuckTypedArray,
|
||||
def zeros(shape: Any, dtype: DTypeLike | None = None) -> Array:
|
||||
if isinstance(shape, types.GeneratorType):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
|
||||
dtypes.check_user_dtype_supported(dtype, "zeros")
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full(shape, 0, _jnp_dtype(dtype))
|
||||
@ -2235,18 +2236,27 @@ def zeros(shape: Any, dtype: DTypeLike | None = None) -> Array:
|
||||
def ones(shape: Any, dtype: DTypeLike | None = None) -> Array:
|
||||
if isinstance(shape, types.GeneratorType):
|
||||
raise TypeError("expected sequence object with len >= 0 or a single integer")
|
||||
if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
|
||||
shape = canonicalize_shape(shape)
|
||||
dtypes.check_user_dtype_supported(dtype, "ones")
|
||||
return lax.full(shape, 1, _jnp_dtype(dtype))
|
||||
|
||||
|
||||
@util._wraps(np.empty, lax_description="""\
|
||||
Because XLA cannot create uninitialized arrays, the JAX version will
|
||||
return an array initialized with zeros.""")
|
||||
def empty(shape: Any, dtype: DTypeLike | None = None) -> Array:
|
||||
if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m)
|
||||
dtypes.check_user_dtype_supported(dtype, "empty")
|
||||
return zeros(shape, dtype)
|
||||
|
||||
def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore
|
||||
if isinstance(dtype, int) and isinstance(shape, int):
|
||||
return (f"Cannot interpret '{dtype}' as a data type."
|
||||
f"\n\nDid you accidentally write "
|
||||
f"`jax.numpy.{name}({shape}, {dtype})` "
|
||||
f"when you meant `jax.numpy.{name}(({shape}, {dtype}))`, i.e. "
|
||||
"with a single tuple argument for the shape?")
|
||||
|
||||
|
||||
@util._wraps(np.array_equal)
|
||||
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
|
||||
|
@ -5216,6 +5216,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
"but got first argument of"):
|
||||
jnp.rot90(jnp.ones(2))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('ones', jnp.ones),
|
||||
('zeros', jnp.zeros),
|
||||
('empty', jnp.empty))
|
||||
def test_error_hint(self, fn):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Did you accidentally write `jax\.numpy\..*?\(2, 3\)` "
|
||||
r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"):
|
||||
fn(2, 3)
|
||||
|
||||
|
||||
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
||||
# as needed for e.g. particular compound ops of interest.
|
||||
|
Loading…
x
Reference in New Issue
Block a user