mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Error on numpy array conversion of PRNG key array
This commit is contained in:
parent
1a544b6f36
commit
83383fc717
@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
for information on migrating to the new API.
|
||||
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
|
||||
has been removed, after being deprecated in v0.4.27.
|
||||
* Calling `np.asarray` on typed PRNG keys (i.e. keys produced by :func:`jax.random.key`)
|
||||
now raises an error. Previously, this returned a scalar object array.
|
||||
* The following deprecated methods and functions in {mod}`jax.export` have
|
||||
been removed:
|
||||
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
|
||||
|
@ -279,6 +279,11 @@ class PRNGKeyArray(jax.Array):
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
__array_priority__ = 100
|
||||
|
||||
def __array__(self, dtype: np.dtype | None = None, copy: bool | None = None) -> np.ndarray:
|
||||
raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array."
|
||||
" Use jax.random.key_data(arr) if you wish to extract the underlying"
|
||||
" integer array.")
|
||||
|
||||
# Overwritten immediately below
|
||||
@property
|
||||
def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override]
|
||||
|
@ -1160,6 +1160,12 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
|
||||
self.assertEqual(result, 1.0)
|
||||
|
||||
def test_keyarray_array_conversion_fails(self):
|
||||
key = jax.random.key(0)
|
||||
msg = "JAX array with PRNGKey dtype cannot be converted to a NumPy array."
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
np.asarray(key)
|
||||
|
||||
# TODO(frostig,mattjj): more polymorphic primitives tests
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user