make PRNGKeyArray abstract, separate from implementation

We expose the `PRNGKeyArray` symbol publicly, at least for use in
annotations (especially by libraries). Separating the interface from
the implementation helps ensure that the public type is never
instantiated directly. Also, should anyone try to inherit from the
public type, they will not pick up all of the magic behavior of the
implementing class (e.g. presence in pytype-aval mappings).

This reflects what we do with `jax.Array` as well.

Makes a few other annotation fixups in `jax._src.prng` along the way.
This commit is contained in:
Roy Frostig 2023-04-06 13:27:31 -07:00
parent b4402185db
commit cd5e2380d8
2 changed files with 100 additions and 37 deletions

View File

@ -17,7 +17,7 @@ import abc
from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence, Union
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence, Tuple, Union
import numpy as np
@ -128,8 +128,66 @@ class PRNGKeyArrayMeta(abc.ABCMeta):
return super().__instancecheck__(instance)
class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
"""An array whose elements are PRNG keys.
class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
  """An array whose elements are PRNG keys.

  This is the abstract, public-facing interface. Every member declared here
  is abstract, so the class is meant for annotations and ``isinstance``
  checks rather than for direct instantiation; ``PRNGKeyArrayImpl`` is the
  concrete subclass defined in this module.
  """

  # -- access to the underlying data

  @abc.abstractmethod  # TODO(frostig): rename
  def unsafe_raw_array(self) -> Array:
    # Callers in this module treat the result as the raw base array (not a
    # key array) and re-wrap it into a key array themselves, hence `Array`.
    ...

  @abc.abstractmethod
  def block_until_ready(self) -> PRNGKeyArray: ...

  # -- array metadata

  @property
  @abc.abstractmethod
  def shape(self) -> Tuple[int, ...]: ...

  @property
  @abc.abstractmethod
  def ndim(self) -> int: ...

  @property
  @abc.abstractmethod
  def dtype(self): ...

  @property
  @abc.abstractmethod
  def sharding(self): ...

  # -- container protocol

  @abc.abstractmethod
  def __len__(self) -> int: ...

  @abc.abstractmethod
  def __iter__(self) -> Iterator[PRNGKeyArray]: ...

  # -- "stackable" methods (reshape, concatenate, broadcast_to, expand_dims)

  @abc.abstractmethod
  def reshape(self, newshape, order=None) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def broadcast_to(self, shape) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def expand_dims(self, dimensions: Sequence[int]) -> PRNGKeyArray: ...

  # -- array-style attributes and methods; on the concrete class these are
  # installed by `_set_device_array_base_attributes`.

  @property
  @abc.abstractmethod
  def T(self) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def __getitem__(self, _) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def ravel(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def squeeze(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def swapaxes(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def take(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def transpose(self, *_, **__) -> PRNGKeyArray: ...

  @abc.abstractmethod
  def flatten(self, *_, **__) -> PRNGKeyArray: ...
class PRNGKeyArrayImpl(PRNGKeyArray):
"""An array of PRNG keys backed by an RNG implementation.
This class lifts the definition of a PRNG, provided in the form of a
``PRNGImpl``, into an array-like pytree class. Instances of this
@ -194,7 +252,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
raise TypeError('len() of unsized object')
return len(self._base_array)
def __iter__(self) -> Iterator['PRNGKeyArray']:
def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
if self._is_scalar():
raise TypeError('iteration over a 0-d key array')
# TODO(frostig): we may want to avoid iteration by slicing because
@ -206,7 +264,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
# * return iter over these unpacked slices
# Whatever we do, we'll want to do it by overriding
# ShapedArray._iter when the element type is KeyTy...
return (PRNGKeyArray(self.impl, k) for k in iter(self._base_array))
return (PRNGKeyArrayImpl(self.impl, k) for k in iter(self._base_array))
# TODO(frostig): are all of the stackable methods below (reshape,
# concat, broadcast_to, expand_dims), and the stackable registration,
@ -214,30 +272,30 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
# to remove stackables altogether? This may be the only application.
# TODO(frostig): Remove? Overwritten below in particular
def reshape(self, newshape, order=None) -> 'PRNGKeyArray':
def reshape(self, newshape, order=None) -> PRNGKeyArrayImpl:
reshaped_base = jnp.reshape(self._base_array, (*newshape, -1), order=order)
return PRNGKeyArray(self.impl, reshaped_base)
return PRNGKeyArrayImpl(self.impl, reshaped_base)
def concatenate(self, key_arrs, axis, dtype=None):
def concatenate(self, key_arrs, axis, dtype=None) -> PRNGKeyArrayImpl:
if dtype is not None:
raise ValueError(
'dtype argument not supported for concatenating PRNGKeyArray')
axis = canonicalize_axis(axis, self.ndim)
arrs = [self._base_array, *[k._base_array for k in key_arrs]]
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))
return PRNGKeyArrayImpl(self.impl, jnp.concatenate(arrs, axis))
def broadcast_to(self, shape):
def broadcast_to(self, shape) -> PRNGKeyArrayImpl:
if jnp.ndim(shape) == 0:
shape = (shape,)
new_shape = (*shape, *self.impl.key_shape)
return PRNGKeyArray(
return PRNGKeyArrayImpl(
self.impl, jnp.broadcast_to(self._base_array, new_shape))
def expand_dims(self, dimensions: Sequence[int]):
def expand_dims(self, dimensions: Sequence[int]) -> PRNGKeyArrayImpl:
# follows lax.expand_dims, not jnp.expand_dims, so dimensions is a sequence
ndim_out = self.ndim + len(set(dimensions))
dimensions = [canonicalize_axis(d, ndim_out) for d in dimensions]
return PRNGKeyArray(
return PRNGKeyArrayImpl(
self.impl, lax.expand_dims(self._base_array, dimensions))
def __repr__(self):
@ -251,11 +309,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
pp.text('PRNGKeyArray:') +
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
# Hollow defs only for typing purposes, overwritten below
#
# TODO(frostig): there may be a better way to do this with
# `typing.type_check_only`.
# Overwritten immediately below
@property
def T(self) -> PRNGKeyArray: assert False
def __getitem__(self, _) -> PRNGKeyArray: assert False
@ -266,16 +320,15 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
def transpose(self, *_, **__) -> PRNGKeyArray: assert False
def flatten(self, *_, **__) -> PRNGKeyArray: assert False
_set_device_array_base_attributes(PRNGKeyArray, include=[
_set_device_array_base_attributes(PRNGKeyArrayImpl, include=[
'__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape',
'transpose', 'flatten', 'T'])
_register_stackable(PRNGKeyArray)
basearray.Array.register(PRNGKeyArray)
_register_stackable(PRNGKeyArrayImpl)
basearray.Array.register(PRNGKeyArrayImpl)
# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArray:
def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
return random_seed(seed, impl=impl)
@ -347,7 +400,7 @@ class KeyTyRules:
def result_handler(sticky_device, aval):
def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return PRNGKeyArray(aval.dtype.impl, buf)
return PRNGKeyArrayImpl(aval.dtype.impl, buf)
return handler
@staticmethod
@ -372,7 +425,7 @@ class KeyTyRules:
# set up a handler that calls the physical one and wraps back up
def handler(bufs):
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
return handler
@ -387,7 +440,7 @@ class KeyTyRules:
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
is_out_sharding_from_xla)
def handler(bufs):
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
return handler
# element-type-polymorphic primitive lowering rules
@ -491,16 +544,16 @@ class KeyTy:
core.opaque_dtypes.add(KeyTy)
core.pytype_aval_mappings[PRNGKeyArray] = (
core.pytype_aval_mappings[PRNGKeyArrayImpl] = (
lambda x: keys_shaped_array(x.impl, x.shape))
xla.pytype_aval_mappings[PRNGKeyArray] = (
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = (
lambda x: keys_shaped_array(x.impl, x.shape))
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
def key_array_shard_arg_handler(x: PRNGKeyArray, devices, indices, sharding):
def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
# TODO(frostig): Remove the need for `core.get_aval`.
aval = core.get_aval(x)
key_shape = aval.dtype.impl.key_shape
@ -517,13 +570,13 @@ def key_array_shard_arg_handler(x: PRNGKeyArray, devices, indices, sharding):
)
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler
pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
def key_array_constant_handler(x, canonicalize_dtypes):
  """Handle a key-array constant by delegating to its raw array's handler.

  Args:
    x: object exposing ``unsafe_raw_array()``; this handler is registered
      for ``PRNGKeyArrayImpl`` right after its definition.
    canonicalize_dtypes: forwarded unchanged to the delegated handler.

  Returns:
    Whatever the MLIR constant handler registered for the raw array's type
    returns.
  """
  raw = x.unsafe_raw_array()
  # Look up the handler by the raw array's concrete type, then hand over.
  raw_handler = mlir.get_constant_handler(type(raw))
  return raw_handler(raw, canonicalize_dtypes)
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
# -- primitives
@ -589,7 +642,7 @@ def random_seed_abstract_eval(seeds_aval, *, impl):
@random_seed_p.def_impl
def random_seed_impl(seeds, *, impl):
base_arr = random_seed_impl_base(seeds, impl=impl)
return PRNGKeyArray(impl, base_arr)
return PRNGKeyArrayImpl(impl, base_arr)
def random_seed_impl_base(seeds, *, impl):
seed = iterated_vmap_unary(seeds.ndim, impl.seed)
@ -621,7 +674,7 @@ def random_split_abstract_eval(keys_aval, *, count):
def random_split_impl(keys, *, count):
base_arr = random_split_impl_base(
keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
return PRNGKeyArray(keys.impl, base_arr)
return PRNGKeyArrayImpl(keys.impl, base_arr)
def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
@ -658,7 +711,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
def random_fold_in_impl(keys, msgs):
base_arr = random_fold_in_impl_base(
keys.impl, keys.unsafe_raw_array(), msgs, keys.shape)
return PRNGKeyArray(keys.impl, base_arr)
return PRNGKeyArrayImpl(keys.impl, base_arr)
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
fold_in = iterated_vmap_binary_bcast(
@ -762,7 +815,7 @@ def random_wrap_abstract_eval(base_arr_aval, *, impl):
@random_wrap_p.def_impl
def random_wrap_impl(base_arr, *, impl):
return PRNGKeyArray(impl, base_arr)
return PRNGKeyArrayImpl(impl, base_arr)
def random_wrap_lowering(ctx, base_arr, *, impl):
  """Lowering rule for ``random_wrap``: pass the base array through as-is.

  Args:
    ctx: lowering context (unused).
    base_arr: the value being wrapped.
    impl: PRNG implementation (unused).

  Returns:
    A one-element list containing ``base_arr`` itself.
  """
  del ctx, impl  # Neither is consulted by this rule.
  lowered = [base_arr]
  return lowered
@ -778,7 +831,7 @@ batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule
def random_unwrap(keys):
if not isinstance(keys, PRNGKeyArray):
if not isinstance(keys, PRNGKeyArrayImpl):
raise TypeError(f'random_unwrap takes key array operand, got {type(keys)}')
return random_unwrap_p.bind(keys)

View File

@ -1605,6 +1605,16 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(g[0], k1.dtype)
self.assertEqual(g[0], k2.dtype)
def test_isinstance(self):
  """Key arrays pass `isinstance(_, random.KeyArray)` in and out of `jit`."""
  keys_in = self.make_keys()

  @jax.jit
  def identity(k):
    # Checked on the value seen inside the jitted function as well.
    self.assertIsInstance(k, random.KeyArray)
    return k

  keys_out = identity(keys_in)

  # Both the original keys and the keys returned from `jit` must qualify.
  for keys in (keys_in, keys_out):
    self.assertIsInstance(keys, random.KeyArray)
# -- prng primitives