mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses
PiperOrigin-RevId: 518803946
This commit is contained in:
parent
6d1c849a53
commit
4cb32ba46f
@ -116,12 +116,12 @@ def _check_prng_key_data(impl, key_data: jax.Array):
|
||||
class PRNGKeyArrayMeta(abc.ABCMeta):
|
||||
"""Metaclass for overriding PRNGKeyArray isinstance checks."""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
def __instancecheck__(cls, instance):
|
||||
try:
|
||||
return (isinstance(instance.aval, core.ShapedArray) and
|
||||
type(instance.aval.dtype) is KeyTy)
|
||||
except AttributeError:
|
||||
super().__instancecheck__(instance)
|
||||
return super().__instancecheck__(instance)
|
||||
|
||||
|
||||
class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
|
Loading…
x
Reference in New Issue
Block a user