remove PRNGKeyArray ABC

We don't expose the `PRNGKeyArray` symbol publicly any longer and we only implement the interface in one place.

PiperOrigin-RevId: 602470550
This commit is contained in:
Roy Frostig 2024-01-29 12:40:47 -08:00 committed by jax authors
parent 37b6d22a82
commit a04332504b
2 changed files with 32 additions and 126 deletions

View File

@ -49,7 +49,7 @@ from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
def _get_device(a: ArrayImpl) -> Device:
assert len(a.devices()) == 1
@ -69,7 +69,7 @@ class Shard:
"""
def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
data: None | ArrayImpl | PRNGKeyArrayImpl = None):
data: None | ArrayImpl | PRNGKeyArray = None):
self._device = device
self._sharding = sharding
self._global_shape = global_shape

View File

@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import abc
from collections.abc import Iterator, Sequence
from functools import partial, reduce
import math
@ -136,101 +135,6 @@ def _check_prng_key_data(impl, key_data: typing.Array):
class PRNGKeyArray(jax.Array):
"""An array whose elements are PRNG keys"""
@abc.abstractmethod
def unsafe_buffer_pointer(self) -> int: ...
@abc.abstractmethod
def block_until_ready(self) -> PRNGKeyArray: ...
@abc.abstractmethod
def copy_to_host_async(self) -> None: ...
@property
@abc.abstractmethod
def shape(self) -> tuple[int, ...]: ...
@property
@abc.abstractmethod
def ndim(self) -> int: ...
@property
@abc.abstractmethod
def size(self) -> int: ...
@property
@abc.abstractmethod
def dtype(self): ...
@property
@abc.abstractmethod
def itemsize(self): ...
@property
@abc.abstractmethod
def sharding(self): ...
@property
@abc.abstractmethod
def at(self) -> _IndexUpdateHelper: ... # type: ignore[override]
@abc.abstractmethod
def __len__(self) -> int: ...
@abc.abstractmethod
def __iter__(self) -> Iterator[PRNGKeyArray]: ...
@abc.abstractmethod
def reshape(self, *args, order='C') -> PRNGKeyArray: ...
@property
@abc.abstractmethod
def T(self) -> PRNGKeyArray: ...
@abc.abstractmethod
def __getitem__(self, _) -> PRNGKeyArray: ...
@abc.abstractmethod
def ravel(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def squeeze(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def swapaxes(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def take(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def transpose(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def flatten(self, *_, **__) -> PRNGKeyArray: ...
@property
@abc.abstractmethod
def is_fully_addressable(self) -> bool: ...
@property
@abc.abstractmethod
def is_fully_replicated(self) -> bool: ...
@abc.abstractmethod
def device(self) -> Device: ...
@abc.abstractmethod
def devices(self) -> set[Device]: ...
@abc.abstractmethod
def delete(self) -> None: ...
@abc.abstractmethod
def is_deleted(self) -> bool: ...
@abc.abstractmethod
def on_device_size_in_bytes(self) -> int: ...
@property
@abc.abstractmethod
def addressable_shards(self) -> list[Shard]: ...
@property
@abc.abstractmethod
def global_shards(self) -> list[Shard]: ...
@abc.abstractmethod
def addressable_data(self, index: int) -> PRNGKeyArray: ...
# TODO(jakevdp): potentially add tolist(), tobytes(),
# device_buffer, device_buffers, __cuda_interface__()
class PRNGKeyArrayImpl(PRNGKeyArray):
"""An array of PRNG keys backed by an RNG implementation.
This class lifts the definition of a PRNG, provided in the form of a
@ -243,6 +147,8 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
wrapper methods around the PRNG implementation functions (``split``,
``random_bits``, ``fold_in``).
"""
# TODO(jakevdp): potentially add tolist(), tobytes(),
# device_buffer, device_buffers, __cuda_interface__()
_impl: PRNGImpl
_base_array: typing.Array
@ -295,8 +201,8 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]
def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))
def addressable_data(self, index: int) -> PRNGKeyArray:
return PRNGKeyArray(self._impl, self._base_array.addressable_data(index))
@property
def addressable_shards(self) -> list[Shard]:
@ -305,7 +211,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
data=PRNGKeyArrayImpl(self._impl, s._data),
data=PRNGKeyArray(self._impl, s._data),
)
for s in self._base_array.addressable_shards
]
@ -317,7 +223,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
data=PRNGKeyArrayImpl(self._impl, s._data),
data=PRNGKeyArray(self._impl, s._data),
)
for s in self._base_array.global_shards
]
@ -336,7 +242,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
raise TypeError('len() of unsized object')
return len(self._base_array)
def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
def __iter__(self) -> Iterator[PRNGKeyArray]:
if self._is_scalar():
raise TypeError('iteration over a 0-d key array')
# TODO(frostig): we may want to avoid iteration by slicing because
@ -348,7 +254,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
# * return iter over these unpacked slices
# Whatever we do, we'll want to do it by overriding
# ShapedArray._iter when the element type is KeyTy...
return (PRNGKeyArrayImpl(self._impl, k) for k in iter(self._base_array))
return (PRNGKeyArray(self._impl, k) for k in iter(self._base_array))
def __repr__(self):
return (f'Array({self.shape}, dtype={self.dtype.name}) overlaying:\n'
@ -381,26 +287,26 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
def take(self, *_, **__) -> PRNGKeyArray: assert False
def transpose(self, *_, **__) -> PRNGKeyArray: assert False
_set_array_base_attributes(PRNGKeyArrayImpl, include=[
_set_array_base_attributes(PRNGKeyArray, include=[
*(f"__{op}__" for op in _array_operators),
'at', 'flatten', 'ravel', 'reshape',
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')
api_util._shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
def prngkeyarrayimpl_flatten(x):
def prngkeyarray_flatten(x):
return (x._base_array,), x._impl
def prngkeyarrayimpl_unflatten(impl, children):
def prngkeyarray_unflatten(impl, children):
base_array, = children
return PRNGKeyArrayImpl(impl, base_array)
return PRNGKeyArray(impl, base_array)
tree_util_internal.dispatch_registry.register_node(
PRNGKeyArrayImpl, prngkeyarrayimpl_flatten, prngkeyarrayimpl_unflatten)
PRNGKeyArray, prngkeyarray_flatten, prngkeyarray_unflatten)
# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArrayImpl:
def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray:
return random_seed(seed, impl=impl)
@ -499,7 +405,7 @@ class KeyTyRules:
def result_handler(sticky_device, aval):
def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return PRNGKeyArrayImpl(aval.dtype._impl, buf)
return PRNGKeyArray(aval.dtype._impl, buf)
return handler
@staticmethod
@ -524,7 +430,7 @@ class KeyTyRules:
# set up a handler that calls the physical one and wraps back up
def handler(bufs):
return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
return handler
@ -539,7 +445,7 @@ class KeyTyRules:
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
is_out_sharding_from_xla)
def handler(bufs):
return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
return handler
@staticmethod
@ -551,7 +457,7 @@ class KeyTyRules:
phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
phys_result = phys_handler(phys_arrays)
return PRNGKeyArrayImpl(aval.dtype._impl, phys_result)
return PRNGKeyArray(aval.dtype._impl, phys_result)
@staticmethod
def device_put_sharded(vals, aval, sharding, devices):
@ -617,26 +523,26 @@ class KeyTy(dtypes.ExtendedDType):
core.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, sharding):
def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
arr = x._base_array
phys_sharding = make_key_array_phys_sharding(
x.aval, sharding, is_sharding_from_xla=False)
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler
def key_array_constant_handler(x):
arr = x._base_array
return mlir.get_constant_handler(type(arr))(arr)
mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
# -- primitives
@ -681,7 +587,7 @@ def iterated_vmap_binary_bcast(shape1, shape2, f):
return f
def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArrayImpl:
def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray:
# Avoid overflow error in X32 mode by first converting ints to int64.
# This breaks JIT invariance for large ints, but supports the common
# use-case of instantiating with Python hashes in X32 mode.
@ -704,7 +610,7 @@ def random_seed_abstract_eval(seeds_aval, *, impl):
@random_seed_p.def_impl
def random_seed_impl(seeds, *, impl):
base_arr = random_seed_impl_base(seeds, impl=impl)
return PRNGKeyArrayImpl(impl, base_arr)
return PRNGKeyArray(impl, base_arr)
def random_seed_impl_base(seeds, *, impl):
seed = iterated_vmap_unary(np.ndim(seeds), impl.seed)
@ -736,7 +642,7 @@ def random_split_abstract_eval(keys_aval, *, shape):
def random_split_impl(keys, *, shape):
base_arr = random_split_impl_base(
keys._impl, keys._base_array, keys.ndim, shape=shape)
return PRNGKeyArrayImpl(keys._impl, base_arr)
return PRNGKeyArray(keys._impl, base_arr)
def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape))
@ -773,7 +679,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
def random_fold_in_impl(keys, msgs):
base_arr = random_fold_in_impl_base(
keys._impl, keys._base_array, msgs, keys.shape)
return PRNGKeyArrayImpl(keys._impl, base_arr)
return PRNGKeyArray(keys._impl, base_arr)
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
fold_in = iterated_vmap_binary_bcast(
@ -877,7 +783,7 @@ def random_wrap_abstract_eval(base_arr_aval, *, impl):
@random_wrap_p.def_impl
def random_wrap_impl(base_arr, *, impl):
return PRNGKeyArrayImpl(impl, base_arr)
return PRNGKeyArray(impl, base_arr)
def random_wrap_lowering(ctx, base_arr, *, impl):
return [base_arr]