Error on numpy array conversion of PRNG key array

This commit is contained in:
Jake VanderPlas 2024-11-05 09:26:42 -08:00
parent 1a544b6f36
commit 83383fc717
3 changed files with 13 additions and 0 deletions

View File

@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
for information on migrating to the new API.
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
has been removed, after being deprecated in v0.4.27.
* Calling `np.asarray` on typed PRNG keys (i.e. keys produced by :func:`jax.random.key`)
now raises an error. Previously, this returned a scalar object array.
* The following deprecated methods and functions in {mod}`jax.export` have
been removed:
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect

View File

@ -279,6 +279,11 @@ class PRNGKeyArray(jax.Array):
__hash__ = None # type: ignore[assignment]
__array_priority__ = 100
def __array__(self, dtype: np.dtype | None = None, copy: bool | None = None) -> np.ndarray:
raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array."
" Use jax.random.key_data(arr) if you wish to extract the underlying"
" integer array.")
# Overwritten immediately below
@property
def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override]

View File

@ -1160,6 +1160,12 @@ class KeyArrayTest(jtu.JaxTestCase):
result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
self.assertEqual(result, 1.0)
def test_keyarray_array_conversion_fails(self):
key = jax.random.key(0)
msg = "JAX array with PRNGKey dtype cannot be converted to a NumPy array."
with self.assertRaisesRegex(TypeError, msg):
np.asarray(key)
# TODO(frostig,mattjj): more polymorphic primitives tests