diff --git a/CHANGELOG.md b/CHANGELOG.md index bafe55110..10e3bca6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. for information on migrating to the new API. * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27. + * Calling `np.asarray` on typed PRNG keys (i.e. keys produced by :func:`jax.random.key`) + now raises an error. Previously, this returned a scalar object array. * The following deprecated methods and functions in {mod}`jax.export` have been removed: * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 8925a4342..8d1af46be 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -279,6 +279,11 @@ class PRNGKeyArray(jax.Array): __hash__ = None # type: ignore[assignment] __array_priority__ = 100 + def __array__(self, dtype: np.dtype | None = None, copy: bool | None = None) -> np.ndarray: + raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array." + " Use jax.random.key_data(arr) if you wish to extract the underlying" + " integer array.") + # Overwritten immediately below @property def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override] diff --git a/tests/random_test.py b/tests/random_test.py index fed12792d..e18100a63 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1160,6 +1160,12 @@ class KeyArrayTest(jtu.JaxTestCase): result = jax.grad(lambda theta: f(theta, state)[0])(3.0) self.assertEqual(result, 1.0) + def test_keyarray_array_conversion_fails(self): + key = jax.random.key(0) + msg = "JAX array with PRNGKey dtype cannot be converted to a NumPy array." + with self.assertRaisesRegex(TypeError, msg): + np.asarray(key) + # TODO(frostig,mattjj): more polymorphic primitives tests