Fixed the return type of `jax.random.key_impl`

Closes #23363.
This commit is contained in:
Sergei Lebedev 2024-09-02 21:44:18 +01:00
parent c092cbfd9d
commit 59f825a23e

View File

@ -14,7 +14,7 @@
from __future__ import annotations
from collections.abc import Hashable, Sequence
from collections.abc import Sequence
from functools import partial
import math
from operator import index
@ -292,7 +292,7 @@ def _key_impl(keys: KeyArray) -> PRNGImpl:
keys_dtype = typing.cast(prng.KeyTy, keys.dtype)
return keys_dtype._impl
def key_impl(keys: KeyArrayLike) -> Hashable:
def key_impl(keys: KeyArrayLike) -> PRNGSpec:
typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True)
return PRNGSpec(_key_impl(typed_keys))