sketch: setup for new key array implementation based on eltypes

Co-authored-by: Matthew Johnson <mattjj@google.com>
Roy Frostig 2022-07-27 18:26:42 -07:00
parent 39d54bdbf6
commit acb5e491ab
4 changed files with 282 additions and 25 deletions
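A minimal usage sketch of the primitives this change introduces, assuming the module below resolves as jax._src.prng (the helpers are internal names from this diff, not a public API, and end-to-end eager execution is not guaranteed by this sketch alone):

import numpy as np
from jax._src import prng

keys = prng.random_seed(np.arange(3), impl=prng.threefry_prng_impl)  # key array, shape (3,)
pairs = prng.random_split(keys, count=2)                             # key array, shape (3, 2)
bits = prng.random_bits(keys, bit_width=32, shape=(4,))              # uint32 array, shape (3, 4)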

@@ -108,9 +108,14 @@ def apply_primitive(prim, *args, **params):
**params)
return compiled_fun(*args)
# TODO(phawkins): update code referring to xla.apply_primitive to point here.
# TODO(phawkins,frostig,mattjj): update code referring to
# xla.apply_primitive to point here, or use simple_impl if that's why
# it is using apply_primitive to begin with
xla.apply_primitive = apply_primitive
def simple_impl(prim):
prim.def_impl(partial(apply_primitive, prim))
RuntimeToken = Any
class RuntimeTokenSet(threading.local):
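For the new simple_impl helper above, a hedged usage sketch (assuming this first file is jax._src.dispatch, where apply_primitive lives; the primitive name is purely illustrative):

from jax import core
from jax._src import dispatch

my_prim = core.Primitive('my_prim')  # hypothetical primitive, for illustration only
dispatch.simple_impl(my_prim)        # eager evaluation of my_prim now routes through apply_primitive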

@@ -14,7 +14,7 @@
from functools import partial
from typing import Callable, Iterator, NamedTuple, Sequence
from typing import Callable, Hashable, Iterator, NamedTuple, Sequence
import warnings
import numpy as np
@@ -29,6 +29,9 @@ from jax.dtypes import float0
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import dispatch
from jax._src import dtypes
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
from jax._src.lib.mlir.dialects import mhlo
@@ -36,15 +39,18 @@ from jax._src.numpy.lax_numpy import (
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
_expand_bool_indices, _register_stackable)
import jax._src.pretty_printer as pp
from jax._src.util import canonicalize_axis, prod
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
from jax._src.lib import gpu_prng
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
UINT_DTYPES = {
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
# -- PRNG implementation interface --
# -- PRNG implementation interface
class PRNGImpl(NamedTuple):
"""Specifies PRNG key shape and operations.
@@ -68,15 +74,22 @@ class PRNGImpl(NamedTuple):
split: Callable
random_bits: Callable
fold_in: Callable
tag: str
def __hash__(self) -> int:
return hash(self.tag)
def __str__(self) -> str:
return self.tag
def pprint(self):
return (pp.text(f"{self.__class__.__name__}:") +
return (pp.text(f"{self.__class__.__name__} [{self.tag}]:") +
pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
pp.text(f"{k} = {v}") for k, v in self._asdict().items()
]))))
# -- PRNG key arrays --
# -- PRNG key arrays
def _check_prng_key_data(impl, key_data: jnp.ndarray):
ndim = len(impl.key_shape)
@@ -94,7 +107,6 @@ def _check_prng_key_data(impl, key_data: jnp.ndarray):
f"got dtype={key_data.dtype}")
@tree_util.register_pytree_node_class
class PRNGKeyArray:
"""An array whose elements are PRNG keys.
@@ -117,6 +129,7 @@ class PRNGKeyArray:
# instead of a jnp.ndarray due to tree_unflatten
if type(key_data) not in [object, bool]:
_check_prng_key_data(impl, key_data)
assert not isinstance(key_data, core.Tracer)
self.impl = impl
self._keys = key_data
@@ -148,6 +161,7 @@ class PRNGKeyArray:
@property
def shape(self):
return self._shape
# TODO(frostig): simplify once we always enable_custom_prng
if config.jax_enable_custom_prng:
return self._shape
@@ -191,12 +205,15 @@ class PRNGKeyArray:
return PRNGKeyArray(self.impl, self._keys[idx])
def _fold_in(self, data: int) -> 'PRNGKeyArray':
assert False
return PRNGKeyArray(self.impl, self.impl.fold_in(self._keys, data))
def _random_bits(self, bit_width, shape) -> jnp.ndarray:
assert False
return self.impl.random_bits(self._keys, bit_width, shape)
def _split(self, num: int) -> 'PRNGKeyArray':
assert False
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
def reshape(self, newshape, order=None):
@@ -223,6 +240,9 @@ class PRNGKeyArray:
return PRNGKeyArray(self.impl, lax.expand_dims(self._keys, dimensions))
def __repr__(self):
return f'{self.__class__.__name__}[{self.impl.tag}] {{ {self._keys} }}'
def pprint(self):
arr_shape = self._shape
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
pp_impl = pp.text('impl = ') + self.impl.pprint()
@@ -232,11 +252,242 @@ class PRNGKeyArray:
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
return PRNGKeyArray(impl, impl.seed(seed))
return random_seed(seed, impl=impl)
_register_stackable(PRNGKeyArray)
# -- threefry2x32 PRNG implementation --
class KeyTy:
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
def __init__(self, impl):
self.impl = impl
@property
def name(self) -> str:
return f'key<{self.impl.tag}>'
def __repr__(self) -> str:
return self.name
def __eq__(self, other):
return type(other) is KeyTy and self.impl is other.impl
def __hash__(self) -> int:
return hash((self.__class__, self.impl))
# handlers
@staticmethod
def aval_to_ir_types(aval):
phys_aval = core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),
jnp.dtype('uint32'))
return mlir.aval_to_ir_types(phys_aval)
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return PRNGKeyArray(aval.dtype.impl, buf)
return handler
# eltype-polymorphic primitive lowering rules
@staticmethod
def empty_mlir(ctx):
aval_out, = ctx.avals_out
return mlir.ir_constants(np.empty(aval_out.dtype.impl.key_shape,
dtype=np.dtype('uint32')))
@staticmethod
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
aval_out, = ctx.avals_out
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
key_shape = aval_out.dtype.impl.key_shape
trailing_zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
start_indices = (*start_indices, *trailing_zeros)
slice_sizes_ = mlir.dense_int_elements((*slice_sizes, *key_shape))
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results
@staticmethod
def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
key_shape = aval_out.dtype.impl.key_shape
zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
start_indices = (*start_indices, *zeros)
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results
@staticmethod
def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
if dyn_shape: raise NotImplementedError
aval_out, = ctx.avals_out
key_shape = aval_out.dtype.impl.key_shape
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
return mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(broadcast_dimensions)).results
@staticmethod
def transpose_mlir(ctx, x, *, permutation):
aval_out, = ctx.avals_out
key_shape = aval_out.dtype.impl.key_shape
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
perm = [*permutation, *trailing_dims]
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results
core.custom_eltypes.add(KeyTy)
def key_shaped_array(impl, shape):
return core.ShapedArray(shape, KeyTy(impl))
def key_aval_to_raw_aval(key_array_aval):
shape = (*key_array_aval.shape, *key_array_aval.dtype.impl.key_shape)
return core.ShapedArray(shape, np.dtype('uint32'))
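To illustrate the correspondence these two helpers set up (threefry keys carry key_shape (2,), i.e. two uint32 words), a small sketch in this module's namespace:

aval = key_shaped_array(threefry_prng_impl, (3, 4))
# -> ShapedArray((3, 4), key<fry>)
raw_aval = key_aval_to_raw_aval(aval)
# -> ShapedArray((3, 4, 2), uint32): the trailing (2,) is the impl's key_shape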
core.pytype_aval_mappings[PRNGKeyArray] = (
lambda x: key_shaped_array(x.impl, x._shape))
xla.pytype_aval_mappings[PRNGKeyArray] = (
lambda x: key_shaped_array(x.impl, x._shape))
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
def device_put_key_array(x: PRNGKeyArray, device):
return dispatch._device_put_array(x.unsafe_raw_array(), device)
dispatch.device_put_handlers[PRNGKeyArray] = device_put_key_array
def key_array_constant_handler(val, canonicalize_dtypes):
return mlir._device_array_constant_handler(
val.unsafe_raw_array(), canonicalize_dtypes)
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
# -- primitives
def vmap_n(n, f):
for _ in range(n):
f = jax.vmap(f)
return f
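vmap_n simply stacks vmap n times, mapping a per-key function over n leading batch axes; a small sketch (assuming threefry_prng_impl from further below):

batched_seed = vmap_n(2, threefry_prng_impl.seed)  # impl.seed mapped over two leading axes
# vmap_n(2, f) is equivalent to jax.vmap(jax.vmap(f))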
def random_seed(seeds, impl):
return random_seed_p.bind(seeds, impl=impl)
random_seed_p = core.Primitive('random_seed')
batching.defvectorized(random_seed_p)
@random_seed_p.def_abstract_eval
def random_seed_abstract_eval(seeds_aval, *, impl):
return key_shaped_array(impl, seeds_aval.shape)
@random_seed_p.def_impl
def random_seed_impl(seeds, *, impl):
seed = vmap_n(seeds.ndim, impl.seed)
return PRNGKeyArray(impl, seed(seeds))
def random_seed_lowering(ctx, seeds, *, impl):
aval, = ctx.avals_in
seed = vmap_n(aval.ndim, impl.seed)
seed_lowering = mlir.lower_fun(seed, multiple_results=False)
ctx_new = ctx.replace(avals_out=map(key_aval_to_raw_aval, ctx.avals_out))
out = seed_lowering(ctx_new, seeds)
ctx.set_tokens_out(ctx_new.tokens_out)
return out
mlir.register_lowering(random_seed_p, random_seed_lowering)
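A sketch of the shape rule random_seed_abstract_eval implements (shapes illustrative, evaluated in this module's namespace):

seeds_aval = core.ShapedArray((4,), np.dtype('int32'))
out_aval = random_seed_abstract_eval(seeds_aval, impl=threefry_prng_impl)
# out_aval == ShapedArray((4,), key<fry>): one key per seed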
def random_split(keys, count):
return random_split_p.bind(keys, count=count)
random_split_p = core.Primitive('random_split')
batching.defvectorized(random_split_p)
@random_split_p.def_abstract_eval
def random_split_abstract_eval(keys_aval, *, count):
return key_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, count))
@random_split_p.def_impl
def random_split_impl(keys, *, count):
impl = keys.impl
split = vmap_n(keys.ndim, impl.split)
return PRNGKeyArray(impl, split(keys.unsafe_raw_array(), count))
def random_split_lowering(ctx, keys, *, count):
aval, = ctx.avals_in
impl = aval.dtype.impl
split = vmap_n(aval.ndim, impl.split)
split_lowering = mlir.lower_fun(split, multiple_results=False)
ctx_new = ctx.replace(avals_in=[key_aval_to_raw_aval(aval)],
avals_out=map(key_aval_to_raw_aval, ctx.avals_out))
out = split_lowering(ctx_new, keys)
ctx.set_tokens_out(ctx_new.tokens_out)
return out
mlir.register_lowering(random_split_p, random_split_lowering)
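Likewise for random_split, a shape sketch: the count lands on a new trailing axis of the key array.

keys_aval = key_shaped_array(threefry_prng_impl, (3,))
out_aval = random_split_abstract_eval(keys_aval, count=2)
# out_aval == ShapedArray((3, 2), key<fry>)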
def random_fold_in(keys, msgs):
return random_fold_in_p.bind(keys, msgs)
random_fold_in_p = core.Primitive('random_fold_in')
batching.defvectorized(random_fold_in_p)
@random_fold_in_p.def_abstract_eval
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
return keys_aval
@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):
impl = keys.impl
fold_in = vmap_n(keys.ndim, impl.fold_in)
return PRNGKeyArray(impl, fold_in(keys.unsafe_raw_array(), msgs))
def random_fold_in_lowering(ctx, keys, msgs):
keys_aval, msgs_aval = ctx.avals_in
impl = keys_aval.dtype.impl
fold_in = vmap_n(keys_aval.ndim, impl.fold_in)
fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
ctx_new = ctx.replace(avals_in=[key_aval_to_raw_aval(keys_aval), msgs_aval],
avals_out=map(key_aval_to_raw_aval, ctx.avals_out))
out = fold_in_lowering(ctx_new, keys, msgs)
ctx.set_tokens_out(ctx_new.tokens_out)
return out
mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)
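A shape sketch for random_fold_in: folding data into keys leaves the key aval unchanged.

keys_aval = key_shaped_array(threefry_prng_impl, (3,))
msgs_aval = core.ShapedArray((3,), np.dtype('uint32'))
out_aval = random_fold_in_abstract_eval(keys_aval, msgs_aval)
# out_aval is keys_aval: same shape, same key dtype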
def random_bits(keys, bit_width, shape):
return random_bits_p.bind(keys, bit_width=bit_width, shape=shape)
random_bits_p = core.Primitive('random_bits')
batching.defvectorized(random_bits_p)
@random_bits_p.def_abstract_eval
def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
out_shape = (*keys_aval.shape, *shape)
out_dtype = dtypes.dtype(f'uint{bit_width}')
return core.ShapedArray(out_shape, out_dtype)
@random_bits_p.def_impl
def random_bits_impl(keys, *, bit_width, shape):
impl = keys.impl
bits = vmap_n(keys.ndim, lambda k: impl.random_bits(k, bit_width, shape))
return bits(keys.unsafe_raw_array())
def random_bits_lowering(ctx, keys, *, bit_width, shape):
aval, = ctx.avals_in
impl = aval.dtype.impl
bits = vmap_n(aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
bits_lowering = mlir.lower_fun(bits, multiple_results=False)
ctx_new = ctx.replace(avals_in=[key_aval_to_raw_aval(aval)])
out = bits_lowering(ctx_new, keys)
ctx.set_tokens_out(ctx_new.tokens_out)
return out
mlir.register_lowering(random_bits_p, random_bits_lowering)
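And for random_bits, a shape sketch: outputs are ordinary unsigned integers, one draw of the requested shape per key.

keys_aval = key_shaped_array(threefry_prng_impl, (3,))
out_aval = random_bits_abstract_eval(keys_aval, bit_width=32, shape=(5,))
# out_aval == ShapedArray((3, 5), uint32)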
# -- threefry2x32 PRNG implementation
def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
@@ -424,6 +675,7 @@ mlir.register_lowering(
platform='rocm')
# TODO(frostig): no longer need to jit?
@partial(jit, inline=True)
def threefry_2x32(keypair, count):
"""Apply the Threefry 2x32 hash.
@@ -461,6 +713,7 @@ def threefry_2x32(keypair, count):
def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
return _threefry_split(key, int(num)) # type: ignore
# TODO(frostig): no longer need to jit?
@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split(key, num) -> jnp.ndarray:
counts = lax.iota(np.uint32, num * 2)
@@ -470,11 +723,13 @@ def _threefry_split(key, num) -> jnp.ndarray:
def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
return _threefry_fold_in(key, jnp.uint32(data))
# TODO(frostig): no longer need to jit?
@partial(jit, inline=True)
def _threefry_fold_in(key, data):
return threefry_2x32(key, threefry_seed(data))
# TODO(frostig): no longer need to jit?
@partial(jit, static_argnums=(1, 2), inline=True)
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
@@ -537,10 +792,11 @@ threefry_prng_impl = PRNGImpl(
seed=threefry_seed,
split=threefry_split,
random_bits=threefry_random_bits,
fold_in=threefry_fold_in)
fold_in=threefry_fold_in,
tag='fry')
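With the new tag field, impls hash and print by their tag; a quick check sketch:

assert str(threefry_prng_impl) == 'fry'
assert hash(threefry_prng_impl) == hash('fry')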
# -- RngBitGenerator PRNG implementation --
# -- RngBitGenerator PRNG implementation
# This code is experimental!
# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
@@ -572,7 +828,8 @@ rbg_prng_impl = PRNGImpl(
seed=_rbg_seed,
split=_rbg_split,
random_bits=_rbg_random_bits,
fold_in=_rbg_fold_in)
fold_in=_rbg_fold_in,
tag='rbg')
def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
# treat 10 iterations of random bits as a 'hash function'
@@ -588,4 +845,5 @@ unsafe_rbg_prng_impl = PRNGImpl(
seed=_rbg_seed,
split=_unsafe_rbg_split,
random_bits=_rbg_random_bits,
fold_in=_unsafe_rbg_fold_in)
fold_in=_unsafe_rbg_fold_in,
tag='urbg')

@@ -61,20 +61,14 @@ def _isnan(x):
def _check_prng_key(key):
# TODO(frostig): remove once we always enable_custom_prng
if type(key) is prng.PRNGKeyArray:
if config.jax_enable_custom_prng:
return key, False
elif _arraylike(key):
if config.jax_enable_custom_prng:
warnings.warn(
'Raw arrays as random keys to jax.random functions are deprecated. '
'Assuming valid threefry2x32 key for now.',
FutureWarning)
return prng.PRNGKeyArray(default_prng_impl(), key), True
else:
raise TypeError(f'unexpected PRNG key type {type(key)}')
def _return_prng_keys(was_wrapped, key):
# TODO(frostig): remove once we always enable_custom_prng
return key
assert type(key) is prng.PRNGKeyArray, type(key)
if config.jax_enable_custom_prng:
return key
@@ -82,8 +76,7 @@ def _return_prng_keys(was_wrapped, key):
return key.unsafe_raw_array() if was_wrapped else key
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
key, _ = _check_prng_key(key)
return key._random_bits(bit_width, shape)
return prng.random_bits(key, bit_width=bit_width, shape=shape)
PRNG_IMPLS = {
@@ -158,7 +151,7 @@ def unsafe_rbg_key(seed: int) -> KeyArray:
def _fold_in(key: KeyArray, data: int) -> KeyArray:
# Alternative to fold_in() to use within random samplers.
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
return key._fold_in(jnp.uint32(data))
return prng.random_fold_in(key, jnp.uint32(data))
def fold_in(key: KeyArray, data: int) -> KeyArray:
"""Folds in data to a PRNG key to form a new PRNG key.
@@ -177,7 +170,7 @@ def fold_in(key: KeyArray, data: int) -> KeyArray:
def _split(key: KeyArray, num: int = 2) -> KeyArray:
# Alternative to split() to use within random samplers.
# TODO(frostig): remove and use split() once we always enable_custom_prng
return key._split(num)
return prng.random_split(key, count=num)
def split(key: KeyArray, num: int = 2) -> KeyArray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.

@@ -1096,6 +1096,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
else:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
out, tokens = jaxpr_subcomp(
ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts),