mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[custom prng] make PRNGKeyArray a subclass of jax.Array
This commit is contained in:
parent
b20b93e9c9
commit
ea5f126e85
@ -22,6 +22,7 @@ from typing import Any, Callable, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
@ -136,13 +137,12 @@ class PRNGKeyArrayMeta(abc.ABCMeta):
|
||||
return super().__instancecheck__(instance)
|
||||
|
||||
|
||||
class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
|
||||
class PRNGKeyArray(jax.Array, metaclass=PRNGKeyArrayMeta):
|
||||
"""An array whose elements are PRNG keys"""
|
||||
|
||||
@abc.abstractmethod # TODO(frostig): rename
|
||||
def unsafe_raw_array(self) -> PRNGKeyArray: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def unsafe_buffer_pointer(self) -> int: ...
|
||||
|
||||
@ -174,7 +174,7 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def at(self) -> _IndexUpdateHelper: ...
|
||||
def at(self) -> _IndexUpdateHelper: ... # type: ignore[override]
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> int: ...
|
||||
@ -182,7 +182,7 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
|
||||
def __iter__(self) -> Iterator[PRNGKeyArray]: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def reshape(self, newshape, order=None) -> PRNGKeyArray: ...
|
||||
def reshape(self, *args, order='C') -> PRNGKeyArray: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@ -380,7 +380,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
|
||||
# Overwritten immediately below
|
||||
@property
|
||||
def at(self) -> _IndexUpdateHelper: assert False
|
||||
def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override]
|
||||
@property
|
||||
def T(self) -> PRNGKeyArray: assert False
|
||||
def __getitem__(self, _) -> PRNGKeyArray: assert False
|
||||
|
Loading…
x
Reference in New Issue
Block a user