mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
remove `PRNGKeyArray` ABC
We no longer expose the `PRNGKeyArray` symbol publicly, and we only implement the interface in one place.

PiperOrigin-RevId: 602470550
This commit is contained in:
parent
37b6d22a82
commit
a04332504b
@ -49,7 +49,7 @@ from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
|
||||
Shape = tuple[int, ...]
|
||||
Device = xc.Device
|
||||
Index = tuple[slice, ...]
|
||||
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
|
||||
PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
|
||||
|
||||
def _get_device(a: ArrayImpl) -> Device:
|
||||
assert len(a.devices()) == 1
|
||||
@ -69,7 +69,7 @@ class Shard:
|
||||
"""
|
||||
|
||||
def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
|
||||
data: None | ArrayImpl | PRNGKeyArrayImpl = None):
|
||||
data: None | ArrayImpl | PRNGKeyArray = None):
|
||||
self._device = device
|
||||
self._sharding = sharding
|
||||
self._global_shape = global_shape
|
||||
|
154
jax/_src/prng.py
154
jax/_src/prng.py
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Iterator, Sequence
|
||||
from functools import partial, reduce
|
||||
import math
|
||||
@ -136,101 +135,6 @@ def _check_prng_key_data(impl, key_data: typing.Array):
|
||||
|
||||
|
||||
class PRNGKeyArray(jax.Array):
|
||||
"""An array whose elements are PRNG keys"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def unsafe_buffer_pointer(self) -> int: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def block_until_ready(self) -> PRNGKeyArray: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def copy_to_host_async(self) -> None: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def shape(self) -> tuple[int, ...]: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def ndim(self) -> int: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def size(self) -> int: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def dtype(self): ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def itemsize(self): ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def sharding(self): ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def at(self) -> _IndexUpdateHelper: ... # type: ignore[override]
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> int: ...
|
||||
@abc.abstractmethod
|
||||
def __iter__(self) -> Iterator[PRNGKeyArray]: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def reshape(self, *args, order='C') -> PRNGKeyArray: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def T(self) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def __getitem__(self, _) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def ravel(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def squeeze(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def swapaxes(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def take(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def transpose(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def flatten(self, *_, **__) -> PRNGKeyArray: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_fully_addressable(self) -> bool: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_fully_replicated(self) -> bool: ...
|
||||
@abc.abstractmethod
|
||||
def device(self) -> Device: ...
|
||||
@abc.abstractmethod
|
||||
def devices(self) -> set[Device]: ...
|
||||
@abc.abstractmethod
|
||||
def delete(self) -> None: ...
|
||||
@abc.abstractmethod
|
||||
def is_deleted(self) -> bool: ...
|
||||
@abc.abstractmethod
|
||||
def on_device_size_in_bytes(self) -> int: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def addressable_shards(self) -> list[Shard]: ...
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def global_shards(self) -> list[Shard]: ...
|
||||
@abc.abstractmethod
|
||||
def addressable_data(self, index: int) -> PRNGKeyArray: ...
|
||||
|
||||
# TODO(jakevdp): potentially add tolist(), tobytes(),
|
||||
# device_buffer, device_buffers, __cuda_interface__()
|
||||
|
||||
|
||||
class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
"""An array of PRNG keys backed by an RNG implementation.
|
||||
|
||||
This class lifts the definition of a PRNG, provided in the form of a
|
||||
@ -243,6 +147,8 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
wrapper methods around the PRNG implementation functions (``split``,
|
||||
``random_bits``, ``fold_in``).
|
||||
"""
|
||||
# TODO(jakevdp): potentially add tolist(), tobytes(),
|
||||
# device_buffer, device_buffers, __cuda_interface__()
|
||||
|
||||
_impl: PRNGImpl
|
||||
_base_array: typing.Array
|
||||
@ -295,8 +201,8 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
|
||||
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]
|
||||
|
||||
def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
|
||||
return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))
|
||||
def addressable_data(self, index: int) -> PRNGKeyArray:
|
||||
return PRNGKeyArray(self._impl, self._base_array.addressable_data(index))
|
||||
|
||||
@property
|
||||
def addressable_shards(self) -> list[Shard]:
|
||||
@ -305,7 +211,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
device=s._device,
|
||||
sharding=s._sharding,
|
||||
global_shape=s._global_shape,
|
||||
data=PRNGKeyArrayImpl(self._impl, s._data),
|
||||
data=PRNGKeyArray(self._impl, s._data),
|
||||
)
|
||||
for s in self._base_array.addressable_shards
|
||||
]
|
||||
@ -317,7 +223,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
device=s._device,
|
||||
sharding=s._sharding,
|
||||
global_shape=s._global_shape,
|
||||
data=PRNGKeyArrayImpl(self._impl, s._data),
|
||||
data=PRNGKeyArray(self._impl, s._data),
|
||||
)
|
||||
for s in self._base_array.global_shards
|
||||
]
|
||||
@ -336,7 +242,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
raise TypeError('len() of unsized object')
|
||||
return len(self._base_array)
|
||||
|
||||
def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
|
||||
def __iter__(self) -> Iterator[PRNGKeyArray]:
|
||||
if self._is_scalar():
|
||||
raise TypeError('iteration over a 0-d key array')
|
||||
# TODO(frostig): we may want to avoid iteration by slicing because
|
||||
@ -348,7 +254,7 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
# * return iter over these unpacked slices
|
||||
# Whatever we do, we'll want to do it by overriding
|
||||
# ShapedArray._iter when the element type is KeyTy...
|
||||
return (PRNGKeyArrayImpl(self._impl, k) for k in iter(self._base_array))
|
||||
return (PRNGKeyArray(self._impl, k) for k in iter(self._base_array))
|
||||
|
||||
def __repr__(self):
|
||||
return (f'Array({self.shape}, dtype={self.dtype.name}) overlaying:\n'
|
||||
@ -381,26 +287,26 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
def take(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def transpose(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
|
||||
_set_array_base_attributes(PRNGKeyArrayImpl, include=[
|
||||
_set_array_base_attributes(PRNGKeyArray, include=[
|
||||
*(f"__{op}__" for op in _array_operators),
|
||||
'at', 'flatten', 'ravel', 'reshape',
|
||||
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
|
||||
|
||||
api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')
|
||||
api_util._shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
|
||||
|
||||
def prngkeyarrayimpl_flatten(x):
|
||||
def prngkeyarray_flatten(x):
|
||||
return (x._base_array,), x._impl
|
||||
|
||||
def prngkeyarrayimpl_unflatten(impl, children):
|
||||
def prngkeyarray_unflatten(impl, children):
|
||||
base_array, = children
|
||||
return PRNGKeyArrayImpl(impl, base_array)
|
||||
return PRNGKeyArray(impl, base_array)
|
||||
|
||||
tree_util_internal.dispatch_registry.register_node(
|
||||
PRNGKeyArrayImpl, prngkeyarrayimpl_flatten, prngkeyarrayimpl_unflatten)
|
||||
PRNGKeyArray, prngkeyarray_flatten, prngkeyarray_unflatten)
|
||||
|
||||
|
||||
# TODO(frostig): remove, rerouting callers directly to random_seed
|
||||
def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArrayImpl:
|
||||
def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray:
|
||||
return random_seed(seed, impl=impl)
|
||||
|
||||
|
||||
@ -499,7 +405,7 @@ class KeyTyRules:
|
||||
def result_handler(sticky_device, aval):
|
||||
def handler(_, buf):
|
||||
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
|
||||
return PRNGKeyArrayImpl(aval.dtype._impl, buf)
|
||||
return PRNGKeyArray(aval.dtype._impl, buf)
|
||||
return handler
|
||||
|
||||
@staticmethod
|
||||
@ -524,7 +430,7 @@ class KeyTyRules:
|
||||
|
||||
# set up a handler that calls the physical one and wraps back up
|
||||
def handler(bufs):
|
||||
return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))
|
||||
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
|
||||
|
||||
return handler
|
||||
|
||||
@ -539,7 +445,7 @@ class KeyTyRules:
|
||||
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
|
||||
is_out_sharding_from_xla)
|
||||
def handler(bufs):
|
||||
return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))
|
||||
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
|
||||
return handler
|
||||
|
||||
@staticmethod
|
||||
@ -551,7 +457,7 @@ class KeyTyRules:
|
||||
phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
|
||||
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
|
||||
phys_result = phys_handler(phys_arrays)
|
||||
return PRNGKeyArrayImpl(aval.dtype._impl, phys_result)
|
||||
return PRNGKeyArray(aval.dtype._impl, phys_result)
|
||||
|
||||
@staticmethod
|
||||
def device_put_sharded(vals, aval, sharding, devices):
|
||||
@ -617,26 +523,26 @@ class KeyTy(dtypes.ExtendedDType):
|
||||
|
||||
|
||||
|
||||
core.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
|
||||
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
|
||||
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
|
||||
xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
|
||||
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
|
||||
|
||||
|
||||
def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, sharding):
|
||||
def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
|
||||
arr = x._base_array
|
||||
phys_sharding = make_key_array_phys_sharding(
|
||||
x.aval, sharding, is_sharding_from_xla=False)
|
||||
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
|
||||
|
||||
|
||||
pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
|
||||
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler
|
||||
|
||||
|
||||
def key_array_constant_handler(x):
|
||||
arr = x._base_array
|
||||
return mlir.get_constant_handler(type(arr))(arr)
|
||||
mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
|
||||
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
|
||||
|
||||
|
||||
# -- primitives
|
||||
@ -681,7 +587,7 @@ def iterated_vmap_binary_bcast(shape1, shape2, f):
|
||||
return f
|
||||
|
||||
|
||||
def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArrayImpl:
|
||||
def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray:
|
||||
# Avoid overflow error in X32 mode by first converting ints to int64.
|
||||
# This breaks JIT invariance for large ints, but supports the common
|
||||
# use-case of instantiating with Python hashes in X32 mode.
|
||||
@ -704,7 +610,7 @@ def random_seed_abstract_eval(seeds_aval, *, impl):
|
||||
@random_seed_p.def_impl
|
||||
def random_seed_impl(seeds, *, impl):
|
||||
base_arr = random_seed_impl_base(seeds, impl=impl)
|
||||
return PRNGKeyArrayImpl(impl, base_arr)
|
||||
return PRNGKeyArray(impl, base_arr)
|
||||
|
||||
def random_seed_impl_base(seeds, *, impl):
|
||||
seed = iterated_vmap_unary(np.ndim(seeds), impl.seed)
|
||||
@ -736,7 +642,7 @@ def random_split_abstract_eval(keys_aval, *, shape):
|
||||
def random_split_impl(keys, *, shape):
|
||||
base_arr = random_split_impl_base(
|
||||
keys._impl, keys._base_array, keys.ndim, shape=shape)
|
||||
return PRNGKeyArrayImpl(keys._impl, base_arr)
|
||||
return PRNGKeyArray(keys._impl, base_arr)
|
||||
|
||||
def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
|
||||
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape))
|
||||
@ -773,7 +679,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
|
||||
def random_fold_in_impl(keys, msgs):
|
||||
base_arr = random_fold_in_impl_base(
|
||||
keys._impl, keys._base_array, msgs, keys.shape)
|
||||
return PRNGKeyArrayImpl(keys._impl, base_arr)
|
||||
return PRNGKeyArray(keys._impl, base_arr)
|
||||
|
||||
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
|
||||
fold_in = iterated_vmap_binary_bcast(
|
||||
@ -877,7 +783,7 @@ def random_wrap_abstract_eval(base_arr_aval, *, impl):
|
||||
|
||||
@random_wrap_p.def_impl
|
||||
def random_wrap_impl(base_arr, *, impl):
|
||||
return PRNGKeyArrayImpl(impl, base_arr)
|
||||
return PRNGKeyArray(impl, base_arr)
|
||||
|
||||
def random_wrap_lowering(ctx, base_arr, *, impl):
|
||||
return [base_arr]
|
||||
|
Loading…
x
Reference in New Issue
Block a user