[custom prng] make PRNGKeyArray a subclass of jax.Array

This commit is contained in:
Jake VanderPlas 2023-09-12 13:46:22 -07:00
parent b20b93e9c9
commit ea5f126e85

View File

@ -22,6 +22,7 @@ from typing import Any, Callable, NamedTuple
import numpy as np
import jax
from jax import lax
from jax import numpy as jnp
from jax import tree_util
@ -136,13 +137,12 @@ class PRNGKeyArrayMeta(abc.ABCMeta):
return super().__instancecheck__(instance)
class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
class PRNGKeyArray(jax.Array, metaclass=PRNGKeyArrayMeta):
"""An array whose elements are PRNG keys"""
@abc.abstractmethod # TODO(frostig): rename
def unsafe_raw_array(self) -> PRNGKeyArray: ...
@property
@abc.abstractmethod
def unsafe_buffer_pointer(self) -> int: ...
@ -174,7 +174,7 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
@property
@abc.abstractmethod
def at(self) -> _IndexUpdateHelper: ...
def at(self) -> _IndexUpdateHelper: ... # type: ignore[override]
@abc.abstractmethod
def __len__(self) -> int: ...
@ -182,7 +182,7 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
def __iter__(self) -> Iterator[PRNGKeyArray]: ...
@abc.abstractmethod
def reshape(self, newshape, order=None) -> PRNGKeyArray: ...
def reshape(self, *args, order='C') -> PRNGKeyArray: ...
@property
@abc.abstractmethod
@ -380,7 +380,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
# Overwritten immediately below
@property
def at(self) -> _IndexUpdateHelper: assert False
def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override]
@property
def T(self) -> PRNGKeyArray: assert False
def __getitem__(self, _) -> PRNGKeyArray: assert False