mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
make PRNGKeyArray
abstract, separate from implementation
We expose the `PRNGKeyArray` symbol publicly, at least for use in annotations (especially by libraries). Separating interface from implementation helps ensure no instantiations. Also, should anyone try to inherit from the public type, they will not pick up all of the magic behavior of the implementing class (e.g. presence in pytype-aval mappings). This reflects what we do with `jax.Array` as well. Makes a few other annotation fixups in `jax._src.prng` along the way.
This commit is contained in:
parent
b4402185db
commit
cd5e2380d8
127
jax/_src/prng.py
127
jax/_src/prng.py
@ -17,7 +17,7 @@ import abc
|
||||
from functools import partial, reduce
|
||||
import math
|
||||
import operator as op
|
||||
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence, Union
|
||||
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -128,8 +128,66 @@ class PRNGKeyArrayMeta(abc.ABCMeta):
|
||||
return super().__instancecheck__(instance)
|
||||
|
||||
|
||||
class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
"""An array whose elements are PRNG keys.
|
||||
class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
|
||||
"""An array whose elements are PRNG keys"""
|
||||
|
||||
@abc.abstractmethod # TODO(frostig): rename
|
||||
def unsafe_raw_array(self) -> PRNGKeyArray: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def block_until_ready(self) -> PRNGKeyArray: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def shape(self) -> Tuple[int, ...]: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def ndim(self) -> int: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def dtype(self): ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def sharding(self): ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> int: ...
|
||||
@abc.abstractmethod
|
||||
def __iter__(self) -> Iterator[PRNGKeyArray]: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def reshape(self, newshape, order=None) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def broadcast_to(self, shape) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def expand_dims(self, dimensions: Sequence[int]) -> PRNGKeyArray: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def T(self) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def __getitem__(self, _) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def ravel(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def squeeze(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def swapaxes(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def take(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def transpose(self, *_, **__) -> PRNGKeyArray: ...
|
||||
@abc.abstractmethod
|
||||
def flatten(self, *_, **__) -> PRNGKeyArray: ...
|
||||
|
||||
|
||||
class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
"""An array of PRNG keys backed by an RNG implementation.
|
||||
|
||||
This class lifts the definition of a PRNG, provided in the form of a
|
||||
``PRNGImpl``, into an array-like pytree class. Instances of this
|
||||
@ -194,7 +252,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
raise TypeError('len() of unsized object')
|
||||
return len(self._base_array)
|
||||
|
||||
def __iter__(self) -> Iterator['PRNGKeyArray']:
|
||||
def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
|
||||
if self._is_scalar():
|
||||
raise TypeError('iteration over a 0-d key array')
|
||||
# TODO(frostig): we may want to avoid iteration by slicing because
|
||||
@ -206,7 +264,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
# * return iter over these unpacked slices
|
||||
# Whatever we do, we'll want to do it by overriding
|
||||
# ShapedArray._iter when the element type is KeyTy...
|
||||
return (PRNGKeyArray(self.impl, k) for k in iter(self._base_array))
|
||||
return (PRNGKeyArrayImpl(self.impl, k) for k in iter(self._base_array))
|
||||
|
||||
# TODO(frostig): are all of the stackable methods below (reshape,
|
||||
# concat, broadcast_to, expand_dims), and the stackable registration,
|
||||
@ -214,30 +272,30 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
# to remove stackables altogether? This may be the only application.
|
||||
|
||||
# TODO(frostig): Remove? Overwritten below in particular
|
||||
def reshape(self, newshape, order=None) -> 'PRNGKeyArray':
|
||||
def reshape(self, newshape, order=None) -> PRNGKeyArrayImpl:
|
||||
reshaped_base = jnp.reshape(self._base_array, (*newshape, -1), order=order)
|
||||
return PRNGKeyArray(self.impl, reshaped_base)
|
||||
return PRNGKeyArrayImpl(self.impl, reshaped_base)
|
||||
|
||||
def concatenate(self, key_arrs, axis, dtype=None):
|
||||
def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArrayImpl:
|
||||
if dtype is not None:
|
||||
raise ValueError(
|
||||
'dtype argument not supported for concatenating PRNGKeyArray')
|
||||
axis = canonicalize_axis(axis, self.ndim)
|
||||
arrs = [self._base_array, *[k._base_array for k in key_arrs]]
|
||||
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))
|
||||
return PRNGKeyArrayImpl(self.impl, jnp.concatenate(arrs, axis))
|
||||
|
||||
def broadcast_to(self, shape):
|
||||
def broadcast_to(self, shape) -> PRNGKeyArrayImpl:
|
||||
if jnp.ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
new_shape = (*shape, *self.impl.key_shape)
|
||||
return PRNGKeyArray(
|
||||
return PRNGKeyArrayImpl(
|
||||
self.impl, jnp.broadcast_to(self._base_array, new_shape))
|
||||
|
||||
def expand_dims(self, dimensions: Sequence[int]):
|
||||
def expand_dims(self, dimensions: Sequence[int]) -> PRNGKeyArrayImpl:
|
||||
# follows lax.expand_dims, not jnp.expand_dims, so dimensions is a sequence
|
||||
ndim_out = self.ndim + len(set(dimensions))
|
||||
dimensions = [canonicalize_axis(d, ndim_out) for d in dimensions]
|
||||
return PRNGKeyArray(
|
||||
return PRNGKeyArrayImpl(
|
||||
self.impl, lax.expand_dims(self._base_array, dimensions))
|
||||
|
||||
def __repr__(self):
|
||||
@ -251,11 +309,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
pp.text('PRNGKeyArray:') +
|
||||
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
|
||||
|
||||
# Hollow defs only for typing purposes, overwritten below
|
||||
#
|
||||
# TODO(frostig): there may be a better way to do this with
|
||||
# `typing.type_check_only`.
|
||||
|
||||
# Overwritten immediately below
|
||||
@property
|
||||
def T(self) -> PRNGKeyArray: assert False
|
||||
def __getitem__(self, _) -> PRNGKeyArray: assert False
|
||||
@ -266,16 +320,15 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
def transpose(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def flatten(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
|
||||
|
||||
_set_device_array_base_attributes(PRNGKeyArray, include=[
|
||||
_set_device_array_base_attributes(PRNGKeyArrayImpl, include=[
|
||||
'__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape',
|
||||
'transpose', 'flatten', 'T'])
|
||||
_register_stackable(PRNGKeyArray)
|
||||
basearray.Array.register(PRNGKeyArray)
|
||||
_register_stackable(PRNGKeyArrayImpl)
|
||||
basearray.Array.register(PRNGKeyArrayImpl)
|
||||
|
||||
|
||||
# TODO(frostig): remove, rerouting callers directly to random_seed
|
||||
def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArray:
|
||||
def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
|
||||
return random_seed(seed, impl=impl)
|
||||
|
||||
|
||||
@ -347,7 +400,7 @@ class KeyTyRules:
|
||||
def result_handler(sticky_device, aval):
|
||||
def handler(_, buf):
|
||||
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
|
||||
return PRNGKeyArray(aval.dtype.impl, buf)
|
||||
return PRNGKeyArrayImpl(aval.dtype.impl, buf)
|
||||
return handler
|
||||
|
||||
@staticmethod
|
||||
@ -372,7 +425,7 @@ class KeyTyRules:
|
||||
|
||||
# set up a handler that calls the physical one and wraps back up
|
||||
def handler(bufs):
|
||||
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
|
||||
return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
|
||||
|
||||
return handler
|
||||
|
||||
@ -387,7 +440,7 @@ class KeyTyRules:
|
||||
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
|
||||
is_out_sharding_from_xla)
|
||||
def handler(bufs):
|
||||
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
|
||||
return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
|
||||
return handler
|
||||
|
||||
# element-type-polymorphic primitive lowering rules
|
||||
@ -491,16 +544,16 @@ class KeyTy:
|
||||
core.opaque_dtypes.add(KeyTy)
|
||||
|
||||
|
||||
core.pytype_aval_mappings[PRNGKeyArray] = (
|
||||
core.pytype_aval_mappings[PRNGKeyArrayImpl] = (
|
||||
lambda x: keys_shaped_array(x.impl, x.shape))
|
||||
|
||||
xla.pytype_aval_mappings[PRNGKeyArray] = (
|
||||
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = (
|
||||
lambda x: keys_shaped_array(x.impl, x.shape))
|
||||
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
|
||||
|
||||
|
||||
def key_array_shard_arg_handler(x: PRNGKeyArray, devices, indices, sharding):
|
||||
def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
|
||||
# TODO(frostig): Remove the need for `core.get_aval`.
|
||||
aval = core.get_aval(x)
|
||||
key_shape = aval.dtype.impl.key_shape
|
||||
@ -517,13 +570,13 @@ def key_array_shard_arg_handler(x: PRNGKeyArray, devices, indices, sharding):
|
||||
)
|
||||
|
||||
|
||||
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler
|
||||
pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
|
||||
|
||||
|
||||
def key_array_constant_handler(x, canonicalize_dtypes):
|
||||
arr = x.unsafe_raw_array()
|
||||
return mlir.get_constant_handler(type(arr))(arr, canonicalize_dtypes)
|
||||
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
|
||||
mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
|
||||
|
||||
|
||||
# -- primitives
|
||||
@ -589,7 +642,7 @@ def random_seed_abstract_eval(seeds_aval, *, impl):
|
||||
@random_seed_p.def_impl
|
||||
def random_seed_impl(seeds, *, impl):
|
||||
base_arr = random_seed_impl_base(seeds, impl=impl)
|
||||
return PRNGKeyArray(impl, base_arr)
|
||||
return PRNGKeyArrayImpl(impl, base_arr)
|
||||
|
||||
def random_seed_impl_base(seeds, *, impl):
|
||||
seed = iterated_vmap_unary(seeds.ndim, impl.seed)
|
||||
@ -621,7 +674,7 @@ def random_split_abstract_eval(keys_aval, *, count):
|
||||
def random_split_impl(keys, *, count):
|
||||
base_arr = random_split_impl_base(
|
||||
keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
|
||||
return PRNGKeyArray(keys.impl, base_arr)
|
||||
return PRNGKeyArrayImpl(keys.impl, base_arr)
|
||||
|
||||
def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
|
||||
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
|
||||
@ -658,7 +711,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
|
||||
def random_fold_in_impl(keys, msgs):
|
||||
base_arr = random_fold_in_impl_base(
|
||||
keys.impl, keys.unsafe_raw_array(), msgs, keys.shape)
|
||||
return PRNGKeyArray(keys.impl, base_arr)
|
||||
return PRNGKeyArrayImpl(keys.impl, base_arr)
|
||||
|
||||
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
|
||||
fold_in = iterated_vmap_binary_bcast(
|
||||
@ -762,7 +815,7 @@ def random_wrap_abstract_eval(base_arr_aval, *, impl):
|
||||
|
||||
@random_wrap_p.def_impl
|
||||
def random_wrap_impl(base_arr, *, impl):
|
||||
return PRNGKeyArray(impl, base_arr)
|
||||
return PRNGKeyArrayImpl(impl, base_arr)
|
||||
|
||||
def random_wrap_lowering(ctx, base_arr, *, impl):
|
||||
return [base_arr]
|
||||
@ -778,7 +831,7 @@ batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule
|
||||
|
||||
|
||||
def random_unwrap(keys):
|
||||
if not isinstance(keys, PRNGKeyArray):
|
||||
if not isinstance(keys, PRNGKeyArrayImpl):
|
||||
raise TypeError(f'random_unwrap takes key array operand, got {type(keys)}')
|
||||
return random_unwrap_p.bind(keys)
|
||||
|
||||
|
@ -1605,6 +1605,16 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(g[0], k1.dtype)
|
||||
self.assertEqual(g[0], k2.dtype)
|
||||
|
||||
def test_isinstance(self):
|
||||
@jax.jit
|
||||
def f(k):
|
||||
self.assertIsInstance(k, random.KeyArray)
|
||||
return k
|
||||
|
||||
k1 = self.make_keys()
|
||||
k2 = f(k1)
|
||||
self.assertIsInstance(k1, random.KeyArray)
|
||||
self.assertIsInstance(k2, random.KeyArray)
|
||||
|
||||
# -- prng primitives
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user