mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[NumPy] Remove references to deprecated NumPy type aliases.
This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str). NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy. PiperOrigin-RevId: 496691463
This commit is contained in:
parent
71b896893a
commit
843bc43790
@ -138,7 +138,7 @@ def _get_random_data(x: jnp.ndarray) -> np.ndarray:
|
||||
return np.random.randint(0, 100, size=x.shape, dtype=dtype)
|
||||
elif np.issubdtype(dtype, np.floating):
|
||||
return np.array(np.random.uniform(size=x.shape), dtype=dtype)
|
||||
elif dtype == np.bool:
|
||||
elif dtype == bool:
|
||||
return np.random.choice(a=[False, True], size=x.shape)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype for numerical comparison: {dtype}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user