Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses

PiperOrigin-RevId: 518803946
This commit is contained in:
Etienne Pot 2023-03-23 02:31:27 -07:00 committed by jax authors
parent 6d1c849a53
commit 4cb32ba46f

View File

@ -116,12 +116,12 @@ def _check_prng_key_data(impl, key_data: jax.Array):
class PRNGKeyArrayMeta(abc.ABCMeta):
"""Metaclass for overriding PRNGKeyArray isinstance checks."""
def __instancecheck__(self, instance):
def __instancecheck__(cls, instance):
try:
return (isinstance(instance.aval, core.ShapedArray) and
type(instance.aval.dtype) is KeyTy)
except AttributeError:
super().__instancecheck__(instance)
return super().__instancecheck__(instance)
class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):