mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Error on numpy array conversion of PRNG key array
This commit is contained in:
parent
1a544b6f36
commit
83383fc717
@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
for information on migrating to the new API.
|
for information on migrating to the new API.
|
||||||
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
|
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
|
||||||
has been removed, after being deprecated in v0.4.27.
|
has been removed, after being deprecated in v0.4.27.
|
||||||
|
* Calling `np.asarray` on typed PRNG keys (i.e. keys produced by :func:`jax.random.key`)
|
||||||
|
now raises an error. Previously, this returned a scalar object array.
|
||||||
* The following deprecated methods and functions in {mod}`jax.export` have
|
* The following deprecated methods and functions in {mod}`jax.export` have
|
||||||
been removed:
|
been removed:
|
||||||
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
|
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect
|
||||||
|
@ -279,6 +279,11 @@ class PRNGKeyArray(jax.Array):
|
|||||||
__hash__ = None # type: ignore[assignment]
|
__hash__ = None # type: ignore[assignment]
|
||||||
__array_priority__ = 100
|
__array_priority__ = 100
|
||||||
|
|
||||||
|
def __array__(self, dtype: np.dtype | None = None, copy: bool | None = None) -> np.ndarray:
|
||||||
|
raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array."
|
||||||
|
" Use jax.random.key_data(arr) if you wish to extract the underlying"
|
||||||
|
" integer array.")
|
||||||
|
|
||||||
# Overwritten immediately below
|
# Overwritten immediately below
|
||||||
@property
|
@property
|
||||||
def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override]
|
def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override]
|
||||||
|
@ -1160,6 +1160,12 @@ class KeyArrayTest(jtu.JaxTestCase):
|
|||||||
result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
|
result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
|
||||||
self.assertEqual(result, 1.0)
|
self.assertEqual(result, 1.0)
|
||||||
|
|
||||||
|
def test_keyarray_array_conversion_fails(self):
|
||||||
|
key = jax.random.key(0)
|
||||||
|
msg = "JAX array with PRNGKey dtype cannot be converted to a NumPy array."
|
||||||
|
with self.assertRaisesRegex(TypeError, msg):
|
||||||
|
np.asarray(key)
|
||||||
|
|
||||||
# TODO(frostig,mattjj): more polymorphic primitives tests
|
# TODO(frostig,mattjj): more polymorphic primitives tests
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user