mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
c092cbfd9d
commit
59f825a23e
@ -14,7 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Hashable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
import math
|
||||
from operator import index
|
||||
@ -292,7 +292,7 @@ def _key_impl(keys: KeyArray) -> PRNGImpl:
|
||||
keys_dtype = typing.cast(prng.KeyTy, keys.dtype)
|
||||
return keys_dtype._impl
|
||||
|
||||
def key_impl(keys: KeyArrayLike) -> Hashable:
|
||||
def key_impl(keys: KeyArrayLike) -> PRNGSpec:
|
||||
typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True)
|
||||
return PRNGSpec(_key_impl(typed_keys))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user