# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import abc
from functools import partial
import operator as op
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence

import numpy as np

import jax
from jax import lax
from jax import core
from jax import numpy as jnp
from jax.config import config
from jax.dtypes import float0
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src.sharding import (
    MeshPspecSharding, PmapSharding, OpShardingSharding)

from jax._src import dispatch
from jax._src import dtypes
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir.dialects import mhlo
from jax._src.numpy import lax_numpy
import jax._src.pretty_printer as pp
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
from jax._src.lib import gpu_prng

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip


UINT_DTYPES = {
    8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}  # type: ignore[has-type]


# -- PRNG implementation interface

class PRNGImpl(NamedTuple):
  """Specifies PRNG key shape and operations.

  A PRNG implementation is determined by a key type ``K`` and a
  collection of functions that operate on such keys. The key type
  ``K`` is an array type with element type uint32 and shape specified
  by ``key_shape``. The type signature of each operation is::

    seed :: int[] -> K
    fold_in :: K -> int[] -> K
    split[n] :: K -> K[n]
    random_bits[shape, bit_width] :: K -> uint<bit_width>[shape]

  A PRNG implementation is adapted to an array-like object of keys
  ``K`` by the ``PRNGKeyArray`` class, which should be created via the
  ``seed_with_impl`` function.
  """
  key_shape: core.Shape
  seed: Callable
  split: Callable
  random_bits: Callable
  fold_in: Callable
  tag: str = '?'

  def __hash__(self) -> int:
    return hash(self.tag)

  def __str__(self) -> str:
    return self.tag

  def pprint(self):
    return (pp.text(f"{self.__class__.__name__} [{self.tag}]:") +
            pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
              pp.text(f"{k} = {v}") for k, v in self._asdict().items()
            ]))))
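
# For example (illustrative), the threefry implementation defined later in
# this module provides these operations on raw uint32[2] keys:
#
#   raw_key = threefry_seed(jnp.uint32(0))          # uint32[2]
#   raw_keys = threefry_split(raw_key, 2)           # uint32[2, 2]
#   bits = threefry_random_bits(raw_key, 32, (4,))  # uint32[4]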


# -- PRNG key arrays

def _check_prng_key_data(impl, key_data: jnp.ndarray):
  ndim = len(impl.key_shape)
  if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
    raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
                    f"to have ndim, shape, and dtype attributes. Got {key_data}")
  if key_data.ndim < 1:
    raise TypeError("JAX encountered invalid PRNG key data: expected "
                    f"key_data.ndim >= 1; got ndim={key_data.ndim}")
  if key_data.shape[-ndim:] != impl.key_shape:
    raise TypeError("JAX encountered invalid PRNG key data: expected "
                    f"key_data.shape to end with {impl.key_shape}; "
                    f"got shape={key_data.shape} for impl={impl}")
  if key_data.dtype not in [np.uint32, float0]:
    raise TypeError("JAX encountered invalid PRNG key data: expected "
                    f"key_data.dtype = uint32; got dtype={key_data.dtype}")


class PRNGKeyArrayMeta(abc.ABCMeta):
  """Metaclass for overriding PRNGKeyArray isinstance checks."""

  def __instancecheck__(self, instance):
    try:
      return (isinstance(instance.aval, core.ShapedArray) and
              type(instance.aval.dtype) is KeyTy)
    except AttributeError:
      return super().__instancecheck__(instance)


class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
  """An array whose elements are PRNG keys.

  This class lifts the definition of a PRNG, provided in the form of a
  ``PRNGImpl``, into an array-like pytree class. Instances of this
  class behave like an array whose base elements are keys, hiding the
  fact that keys are typically arrays (of ``uint32`` dtype) themselves.

  PRNGKeyArrays are also restricted relative to JAX arrays in that
  they do not expose arithmetic operations. They instead expose
  wrapper methods around the PRNG implementation functions (``split``,
  ``random_bits``, ``fold_in``).
  """

  impl: PRNGImpl
  _base_array: jnp.ndarray

  def __init__(self, impl, key_data: Any):
    assert not isinstance(key_data, core.Tracer)
    _check_prng_key_data(impl, key_data)
    self.impl = impl
    self._base_array = key_data

  # TODO(frostig): rename to unsafe_base_array, or just offer base_array attr?
  def unsafe_raw_array(self):
    """Access the raw numerical array that carries underlying key data.

    Returns:
      A uint32 JAX array whose leading dimensions are ``self.shape``.
    """
    return self._base_array

  def block_until_ready(self):
    _ = self._base_array.block_until_ready()
    return self

  @property
  def shape(self):
    return base_arr_shape_to_keys_shape(self.impl, self._base_array.shape)

  @property
  def ndim(self):
    return len(self.shape)

  @property
  def dtype(self):
    return KeyTy(self.impl)

  _device = property(op.attrgetter('_base_array._device'))
  _committed = property(op.attrgetter('_base_array._committed'))
  sharding = property(op.attrgetter('_base_array.sharding'))

  def _is_scalar(self):
    base_ndim = len(self.impl.key_shape)
    return self._base_array.ndim == base_ndim

  def __len__(self):
    if self._is_scalar():
      raise TypeError('len() of unsized object')
    return len(self._base_array)

  def __iter__(self) -> Iterator['PRNGKeyArray']:
    if self._is_scalar():
      raise TypeError('iteration over a 0-d key array')
    # TODO(frostig): we may want to avoid iteration by slicing because
    # a very common use of iteration is `k1, k2 = split(key)`, and
    # slicing/indexing may be trickier to track for linearity checking
    # purposes. Maybe we can:
    # * introduce an unpack primitive+traceable (also allow direct use)
    # * unpack upfront into shape[0] many keyarray slices
    # * return iter over these unpacked slices
    # Whatever we do, we'll want to do it by overriding
    # ShapedArray._iter when the element type is KeyTy...
    return (PRNGKeyArray(self.impl, k) for k in iter(self._base_array))

  # TODO(frostig): are all of the stackable methods below (reshape,
  # concat, broadcast_to, expand_dims), and the stackable registration,
  # still needed? If, with some work, none are needed, then do we want
  # to remove stackables altogether? This may be the only application.

  # TODO(frostig): remove? This is overwritten below by the device-array
  # attribute patching (`_set_device_array_base_attributes`).
  def reshape(self, newshape, order=None) -> 'PRNGKeyArray':
    reshaped_base = jnp.reshape(self._base_array, (*newshape, -1), order=order)
    return PRNGKeyArray(self.impl, reshaped_base)

  def concatenate(self, key_arrs, axis, dtype=None):
    if dtype is not None:
      raise ValueError(
          'dtype argument not supported for concatenating PRNGKeyArray')
    axis = canonicalize_axis(axis, self.ndim)
    arrs = [self._base_array, *[k._base_array for k in key_arrs]]
    return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))

  def broadcast_to(self, shape):
    if jnp.ndim(shape) == 0:
      shape = (shape,)
    new_shape = (*shape, *self.impl.key_shape)
    return PRNGKeyArray(
        self.impl, jnp.broadcast_to(self._base_array, new_shape))

  def expand_dims(self, dimensions: Sequence[int]):
    # follows lax.expand_dims, not jnp.expand_dims, so dimensions is a sequence
    ndim_out = self.ndim + len(set(dimensions))
    dimensions = [canonicalize_axis(d, ndim_out) for d in dimensions]
    return PRNGKeyArray(
        self.impl, lax.expand_dims(self._base_array, dimensions))

  def __repr__(self):
    return (f'{self.__class__.__name__}[{self.impl.tag}]'
            f' {{ {self._base_array} }}')

  def pprint(self):
    pp_keys = pp.text('shape = ') + pp.text(str(self.shape))
    pp_impl = pp.text('impl = ') + self.impl.pprint()
    return str(pp.group(
      pp.text('PRNGKeyArray:') +
      pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))

  # Hollow defs only for typing purposes, overwritten below
  #
  # TODO(frostig): there may be a better way to do this with
  # `typing.type_check_only`.

  @property
  def T(self)                   -> 'PRNGKeyArray': assert False
  def __getitem__(self, _)      -> 'PRNGKeyArray': assert False
  def ravel(self, *_, **__)     -> 'PRNGKeyArray': assert False
  def squeeze(self, *_, **__)   -> 'PRNGKeyArray': assert False
  def swapaxes(self, *_, **__)  -> 'PRNGKeyArray': assert False
  def take(self, *_, **__)      -> 'PRNGKeyArray': assert False
  def transpose(self, *_, **__) -> 'PRNGKeyArray': assert False
  def flatten(self, *_, **__)   -> 'PRNGKeyArray': assert False


lax_numpy._set_device_array_base_attributes(PRNGKeyArray, include=[
    '__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape',
    'transpose', 'flatten', 'T'])
lax_numpy._register_stackable(PRNGKeyArray)


# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
  return random_seed(seed, impl=impl)
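
# For example (illustrative, using the threefry impl defined below):
#
#   key = seed_with_impl(threefry_prng_impl, 42)
#   assert key.shape == ()   # a scalar key array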


def keys_shaped_array(impl, shape):
  return core.ShapedArray(shape, KeyTy(impl))

def keys_aval_to_base_arr_aval(keys_aval):
  shape = (*keys_aval.shape, *keys_aval.dtype.impl.key_shape)
  return core.ShapedArray(shape, np.dtype('uint32'))
def base_arr_shape_to_keys_shape(impl, base_arr_shape):
  base_ndim = len(impl.key_shape)
  return base_arr_shape[:-base_ndim]
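
# For example, with the threefry key_shape of (2,), a base array of shape
# (3, 5, 2) corresponds to a key array of shape (3, 5).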


class KeyTyRules:

  @staticmethod
  def physical_avals(aval):  # TODO(frostig): rename to `grounded_avals`
    # TODO(frostig): dedup with `keys_aval_to_base_arr_aval`
    return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),
                             jnp.dtype('uint32'))]

  @staticmethod
  def aval_to_ir_types(aval):
    phys_aval, = KeyTyRules.physical_avals(aval)
    return mlir.aval_to_ir_types(phys_aval)

  @staticmethod
  def physical_op_sharding(aval, sharding):
    op_sharding = sharding._to_xla_op_sharding(aval.ndim)
    key_shape = aval.dtype.impl.key_shape

    new_op_sharding = op_sharding.clone()
    tad = list(new_op_sharding.tile_assignment_dimensions)
    tad.extend([1] * len(key_shape))
    new_op_sharding.tile_assignment_dimensions = tad
    return new_op_sharding

  @staticmethod
  def result_handler(sticky_device, aval):
    def handler(_, buf):
      buf.aval = core.ShapedArray(buf.shape, buf.dtype)
      return PRNGKeyArray(aval.dtype.impl, buf)
    return handler

  @staticmethod
  def local_sharded_result_handler(aval, sharding, indices):
    phys_aval, = KeyTyRules.physical_avals(aval)
    key_shape = aval.dtype.impl.key_shape

    # TODO(yashkatariya,frostig): remove this conditional and inline it when
    # the transient config ever settles
    if config.jax_array:
      output_type = pxla.OutputType.Array
    else:
      output_type = pxla.OutputType.ShardedDeviceArray
    phys_handler_maker = pxla.local_result_handlers[
        (core.ShapedArray, output_type)]

    # set up a grounded sharding (with a grounded sharding spec)
    if isinstance(sharding, PmapSharding):
      trailing_sharding = [pxla.NoSharding()] * len(key_shape)
      phys_sharding_spec = pxla.ShardingSpec(
          sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
          mesh_mapping=sharding.sharding_spec.mesh_mapping)
      phys_sharding = PmapSharding(devices=sharding.devices,
                                   sharding_spec=phys_sharding_spec)
    elif isinstance(sharding, MeshPspecSharding):
      trailing_spec = [None] * len(key_shape)
      phys_sharding = MeshPspecSharding(
          sharding.mesh,
          pxla.PartitionSpec(*sharding.spec, *trailing_spec))
    else:
      assert False, f'impossible sharding {sharding} in local sharded result handler'

    # set up grounded indices
    trailing_inds = [slice(None)] * len(key_shape)
    phys_indices = [(*inds, *trailing_inds) for inds in indices]

    # make a physical handler
    phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices)

    # set up a handler that calls the physical one and wraps back up
    def handler(bufs):
      return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))

    return handler

  @staticmethod
  def global_sharded_result_handler(aval, out_sharding, committed,
                                    is_out_sharding_from_xla):
    phys_aval, = KeyTyRules.physical_avals(aval)
    key_shape = aval.dtype.impl.key_shape

    # TODO(yashkatariya,frostig): remove this conditional and inline it when
    # the transient config ever settles
    if config.jax_array:
      output_type = pxla.OutputType.Array
    else:
      output_type = pxla.OutputType.GlobalDeviceArray

    phys_handler_maker = pxla.global_result_handlers[
        (core.ShapedArray, output_type)]

    if dispatch.is_single_device_sharding(out_sharding):
      phys_sharding = out_sharding
    elif isinstance(out_sharding, MeshPspecSharding):
      trailing_spec = [None] * len(key_shape)
      phys_sharding = MeshPspecSharding(
          out_sharding.mesh,
          pxla.PartitionSpec(*out_sharding.spec, *trailing_spec))
    else:
      if is_out_sharding_from_xla:
        phys_sharding = out_sharding
      else:
        phys_sharding = OpShardingSharding(
            out_sharding._device_assignment,
            KeyTyRules.physical_op_sharding(aval, out_sharding))

    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
                                      is_out_sharding_from_xla)
    def handler(bufs):
      return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
    return handler

  # element-type-polymorphic primitive lowering rules

  @staticmethod
  def empty_mlir(ctx):
    aval_out, = ctx.avals_out
    return mlir.ir_constants(np.zeros(aval_out.dtype.impl.key_shape,
                                      dtype=np.dtype('uint32')))

  @staticmethod
  def slice_mlir(ctx, x, start_indices, limit_indices, strides):
    aval_out, = ctx.avals_out
    key_shape = aval_out.dtype.impl.key_shape
    trailing_zeros = [0] * len(key_shape)
    trailing_ones  = [1] * len(key_shape)
    start_indices = (*start_indices, *trailing_zeros)
    limit_indices = (*limit_indices, *key_shape)
    strides = (*strides, *trailing_ones)
    return mhlo.SliceOp(x,
                        mlir.dense_int_elements(start_indices),
                        mlir.dense_int_elements(limit_indices),
                        mlir.dense_int_elements(strides)).results

  @staticmethod
  def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
    aval_out, = ctx.avals_out
    dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
    key_shape = aval_out.dtype.impl.key_shape
    trailing_zeros = [mlir.ir_constant(np.array(0, dtype))] * len(key_shape)
    start_indices = (*start_indices, *trailing_zeros)
    slice_sizes_ = mlir.dense_int_elements((*slice_sizes, *key_shape))
    return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results

  @staticmethod
  def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
    aval_out, = ctx.avals_out
    dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
    key_shape = aval_out.dtype.impl.key_shape
    zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
    start_indices = (*start_indices, *zeros)
    return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
                                     start_indices).results

  @staticmethod
  def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
    if dyn_shape: raise NotImplementedError
    aval_out, = ctx.avals_out
    key_shape = aval_out.dtype.impl.key_shape
    trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
    broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
    return mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(aval_out), x,
        mlir.dense_int_elements(broadcast_dimensions)).results

  @staticmethod
  def transpose_mlir(ctx, x, *, permutation):
    aval_out, = ctx.avals_out
    key_shape = aval_out.dtype.impl.key_shape
    trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
    perm = [*permutation, *trailing_dims]
    return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results

  @staticmethod
  def gather_mlir(ctx, x, indices, *,
                  dimension_numbers, slice_sizes, unique_indices,
                  indices_are_sorted, mode, fill_value):
    aval_x, aval_indices = ctx.avals_in
    aval_y, = ctx.avals_out
    key_shape = aval_x.dtype.impl.key_shape
    trailing_offset_dims = [aval_y.ndim + i for i in range(len(key_shape))]
    dimension_numbers = dimension_numbers._replace(
        offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims))
    slice_sizes = (*slice_sizes, *key_shape)
    gather_lower = partial(
        lax_internal.slicing._gather_lower, dimension_numbers=dimension_numbers,
        slice_sizes=slice_sizes, unique_indices=unique_indices,
        indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
    return mlir.delegate_lowering(
        ctx, gather_lower, x, indices,
        avals_in=[keys_aval_to_base_arr_aval(aval_x), aval_indices],
        avals_out=[keys_aval_to_base_arr_aval(aval_y)])
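
# Illustration of the rules above (a sketch, not executed): a key-typed
# array with aval key<fry>[3, 4] has physical aval uint32[3, 4, 2], and a
# transpose with permutation (1, 0) lowers to a physical transpose with
# permutation (1, 0, 2), leaving the trailing key dimensions in place.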


class KeyTy:
  impl: Hashable  # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
  _rules = KeyTyRules

  def __init__(self, impl):
    self.impl = impl

  @property
  def name(self) -> str:
    return f'key<{self.impl.tag}>'

  def __repr__(self) -> str:
    return self.name

  def __eq__(self, other):
    return type(other) is KeyTy and self.impl == other.impl

  def __hash__(self) -> int:
    return hash((self.__class__, self.impl))


core.opaque_dtypes.add(KeyTy)


core.pytype_aval_mappings[PRNGKeyArray] = (
    lambda x: keys_shaped_array(x.impl, x.shape))

xla.pytype_aval_mappings[PRNGKeyArray] = (
    lambda x: keys_shaped_array(x.impl, x.shape))

xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

def device_put_key_array(x: PRNGKeyArray, device):
  return dispatch.device_put(x.unsafe_raw_array(), device)
dispatch.device_put_handlers[PRNGKeyArray] = device_put_key_array

def key_array_shard_arg_handler(x: PRNGKeyArray, devices, indices, mode):
  # TODO(frostig): Remove the need for `core.get_aval`.
  key_shape = core.get_aval(x).dtype.impl.key_shape
  arr = x.unsafe_raw_array()

  # TODO(yashkatariya,frostig): This assumes that the last dimensions are not
  # sharded. This is only true when enable_custom_prng is True.
  trailing_inds = [slice(None)] * len(key_shape)
  phys_indices = [(*inds, *trailing_inds) for inds in indices]
  return pxla.shard_arg_handlers[type(arr)](arr, devices, phys_indices, mode)
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler


def key_array_constant_handler(x, canonicalize_dtypes):
  arr = x.unsafe_raw_array()
  return mlir.get_constant_handler(type(arr))(arr, canonicalize_dtypes)
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)


# -- primitives

def iterated_vmap_unary(n, f):
  for _ in range(n):
    f = jax.vmap(f)
  return f
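
# For example (illustrative): iterated_vmap_unary(2, impl.seed) lifts a
# per-scalar seed function to one that maps over a rank-2 array of seeds,
# applying one vmap per leading axis.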

# TODO(frostig): Revise the following two functions? These basically
# undo the singleton dimensions added by `batching.defbroadcasting`.
# It works, but introduces some possibly-redundant squeezes. Can we
# borrow from other broadcasting primitives instead?

def squeeze_vmap(f, left):
  def squeeze_vmap_f(x, y):
    if left:
      x = jnp.squeeze(x, axis=0)
      axes = (None, 0)
    else:
      y = jnp.squeeze(y, axis=0)
      axes = (0, None)
    return jax.vmap(f, in_axes=axes, out_axes=0)(x, y)
  return squeeze_vmap_f

def iterated_vmap_binary_bcast(shape1, shape2, f):
  ndim1, ndim2 = len(shape1), len(shape2)
  if ndim1 == ndim2 == 0:
    return f
  if 0 in [ndim1, ndim2]:
    if ndim1 == 0:
      return lambda x, y: iterated_vmap_unary(ndim2, lambda y: f(x, y))(y)
    else:
      return lambda x, y: iterated_vmap_unary(ndim1, lambda x: f(x, y))(x)
  assert len(shape1) == len(shape2)
  for sz1, sz2 in reversed(zip(shape1, shape2)):
    if sz1 == sz2:
      f = jax.vmap(f, out_axes=0)
    else:
      assert sz1 == 1 or sz2 == 1, (sz1, sz2)
      f = squeeze_vmap(f, sz1 == 1)
  return f
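
# For example (illustrative): with shape1=(3, 1) and shape2=(3, 4), the
# trailing axis pair (1, 4) is handled by squeeze_vmap (broadcasting the
# left operand) and the leading pair (3, 3) by a plain vmap. Note that
# `zip` here is `safe_zip` (rebound at the top of this module), which
# returns a list, so `reversed` applies to it directly.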


def random_seed(seeds, impl):
  # Avoid overflow error in X32 mode by first converting ints to int64.
  # This breaks JIT invariance for large ints, but supports the common
  # use-case of instantiating with Python hashes in X32 mode.
  if isinstance(seeds, int):
    seeds_arr = jnp.asarray(np.int64(seeds))
  else:
    seeds_arr = jnp.asarray(seeds)
  return random_seed_p.bind(seeds_arr, impl=impl)

random_seed_p = core.Primitive('random_seed')
ad.defjvp_zero(random_seed_p)
batching.defvectorized(random_seed_p)

@random_seed_p.def_abstract_eval
def random_seed_abstract_eval(seeds_aval, *, impl):
  return keys_shaped_array(impl, seeds_aval.shape)

@random_seed_p.def_impl
def random_seed_impl(seeds, *, impl):
  base_arr = random_seed_impl_base(seeds, impl=impl)
  return PRNGKeyArray(impl, base_arr)

def random_seed_impl_base(seeds, *, impl):
  seed = iterated_vmap_unary(seeds.ndim, impl.seed)
  return seed(seeds)

def random_seed_lowering(ctx, seeds, *, impl):
  aval, = ctx.avals_in
  seed = iterated_vmap_unary(aval.ndim, impl.seed)
  seed_lowering = mlir.lower_fun(seed, multiple_results=False)
  return mlir.delegate_lowering(
      ctx, seed_lowering, seeds,
      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))

mlir.register_lowering(random_seed_p, random_seed_lowering)
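
# For example (illustrative, for some PRNGImpl `impl`):
# random_seed(jnp.arange(4), impl=impl) returns a key array of shape (4,),
# seeding each element independently via the vmapped impl.seed.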


def random_split(keys, count):
  return random_split_p.bind(keys, count=count)

random_split_p = core.Primitive('random_split')
ad.defjvp_zero(random_split_p)
batching.defvectorized(random_split_p)

@random_split_p.def_abstract_eval
def random_split_abstract_eval(keys_aval, *, count):
  return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, count))

@random_split_p.def_impl
def random_split_impl(keys, *, count):
  base_arr = random_split_impl_base(
      keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
  return PRNGKeyArray(keys.impl, base_arr)

def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
  split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
  return split(base_arr)

def random_split_lowering(ctx, keys, *, count):
  aval, = ctx.avals_in
  impl = aval.dtype.impl
  split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, count))
  split_lowering = mlir.lower_fun(split, multiple_results=False)
  return mlir.delegate_lowering(
      ctx, split_lowering, keys,
      avals_in=[keys_aval_to_base_arr_aval(aval)],
      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))

mlir.register_lowering(random_split_p, random_split_lowering)


def random_fold_in(keys, msgs):
  return random_fold_in_p.bind(keys, jnp.asarray(msgs))

random_fold_in_p = core.Primitive('random_fold_in')
ad.defjvp_zero(random_fold_in_p)
batching.defbroadcasting(random_fold_in_p)

@random_fold_in_p.def_abstract_eval
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
  shape = lax_internal.broadcasting_shape_rule(
      'random_fold_in', keys_aval, msgs_aval)
  named_shape = lax_utils.standard_named_shape_rule(keys_aval, msgs_aval)
  return core.ShapedArray(shape, keys_aval.dtype, named_shape=named_shape)

@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):
  base_arr = random_fold_in_impl_base(
      keys.impl, keys.unsafe_raw_array(), msgs, keys.shape)
  return PRNGKeyArray(keys.impl, base_arr)

def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
  fold_in = iterated_vmap_binary_bcast(
      keys_shape, np.shape(msgs), impl.fold_in)
  return fold_in(base_arr, msgs)

def random_fold_in_lowering(ctx, keys, msgs):
  keys_aval, msgs_aval = ctx.avals_in
  impl = keys_aval.dtype.impl
  fold_in = iterated_vmap_binary_bcast(
      keys_aval.shape, msgs_aval.shape, impl.fold_in)
  fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
  return mlir.delegate_lowering(
      ctx, fold_in_lowering, keys, msgs,
      avals_in=[keys_aval_to_base_arr_aval(keys_aval), msgs_aval],
      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))

mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)


def random_bits(keys, bit_width, shape):
  shape = core.as_named_shape(shape)
  for name, size in shape.named_items:
    # TODO(frostig,mattjj,apaszke): Is this real_size check necessary,
    # and is it meant to raise a user-facing ValueError? Should it be
    # an `assert` (or RuntimeError) instead? Why do we check it in
    # calls to `random_bits` instead of a more common parallelism path?
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(f"The shape of axis {name} was specified as {size}, "
                       f"but it really is {real_size}")
    axis_index = lax.axis_index(name)
    keys = random_fold_in(keys, axis_index)
  return random_bits_p.bind(keys, bit_width=bit_width, shape=shape.positional)

random_bits_p = core.Primitive('random_bits')
ad.defjvp_zero(random_bits_p)
batching.defvectorized(random_bits_p)

@random_bits_p.def_abstract_eval
def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
  out_shape = (*keys_aval.shape, *shape)
  out_dtype = dtypes.dtype(f'uint{bit_width}')
  return core.ShapedArray(out_shape, out_dtype)

@random_bits_p.def_impl
def random_bits_impl(keys, *, bit_width, shape):
  return random_bits_impl_base(keys.impl, keys.unsafe_raw_array(), keys.ndim,
                               bit_width=bit_width, shape=shape)

def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):
  bits = iterated_vmap_unary(
      keys_ndim, lambda k: impl.random_bits(k, bit_width, shape))
  return bits(base_arr)

def random_bits_lowering(ctx, keys, *, bit_width, shape):
  aval, = ctx.avals_in
  impl = aval.dtype.impl
  bits = iterated_vmap_unary(
      aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
  bits_lowering = mlir.lower_fun(bits, multiple_results=False)
  ctx_new = ctx.replace(avals_in=[keys_aval_to_base_arr_aval(aval)])
  out = bits_lowering(ctx_new, keys)
  ctx.set_tokens_out(ctx_new.tokens_out)
  return out

mlir.register_lowering(random_bits_p, random_bits_lowering)


# The following wrap/unwrap primitives are at least a stopgap for
# backwards compatibility, namely when `config.jax_enable_custom_prng`
# is False. We need to convert key arrays to and from the underlying
# uint32 base array, and we may need to do so under a jit. For
# example, we want to support:
#
#   keys = jax.jit(random.split)(key)
#
# where `key` and `keys` are both acceptably old-style uint32 arrays
# so long as enable_custom_prng is False. The way we handle this is
# that `random.split` adapts the input/output by converting to/from
# key arrays across its call to `random_split`. So we rely on these
# wrap/unwrap casting primitives to allow that conversion under jit.
#
# We may want to keep both around for testing and debugging escape
# hatches. We can rename them `unsafe` for emphasis, and/or issue a
# warning on entry to the traceable.
#
# TODO(frostig): Consider removal once we always enable_custom_prng.

def random_wrap(base_arr, *, impl):
  _check_prng_key_data(impl, base_arr)
  return random_wrap_p.bind(base_arr, impl=impl)

random_wrap_p = core.Primitive('random_wrap')
ad.defjvp_zero(random_wrap_p)

@random_wrap_p.def_abstract_eval
def random_wrap_abstract_eval(base_arr_aval, *, impl):
  shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape)
  return keys_shaped_array(impl, shape)

@random_wrap_p.def_impl
def random_wrap_impl(base_arr, *, impl):
  return PRNGKeyArray(impl, base_arr)

def random_wrap_lowering(ctx, base_arr, *, impl):
  return [base_arr]

def random_wrap_batch_rule(batched_args, batch_dims, *, impl):
  x, = batched_args
  d, = batch_dims
  x = batching.bdim_at_front(x, d, 1)
  return random_wrap(x, impl=impl), 0

mlir.register_lowering(random_wrap_p, random_wrap_lowering)
batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule


def random_unwrap(keys):
  if not isinstance(keys, PRNGKeyArray):
    raise TypeError(f'random_unwrap takes key array operand, got {type(keys)}')
  return random_unwrap_p.bind(keys)

random_unwrap_p = core.Primitive('random_unwrap')
ad.defjvp_zero(random_unwrap_p)
batching.defvectorized(random_unwrap_p)

@random_unwrap_p.def_abstract_eval
def random_unwrap_abstract_eval(keys_aval):
  return keys_aval_to_base_arr_aval(keys_aval)

@random_unwrap_p.def_impl
def random_unwrap_impl(keys):
  return keys.unsafe_raw_array()

def random_unwrap_lowering(ctx, keys):
  return [keys]

mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)
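
# For example (an illustrative round trip through the casting primitives):
#
#   key = random_wrap(jnp.zeros((2,), jnp.uint32), impl=threefry_prng_impl)
#   raw = random_unwrap(key)   # uint32[2] again, and usable under jit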


# -- threefry2x32 PRNG implementation


def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
  try:
    return key.shape == (2,) and key.dtype == np.uint32
  except AttributeError:
    return False


def threefry_seed(seed: jnp.ndarray) -> jnp.ndarray:
  """Create a single raw threefry PRNG key from an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by
    first padding out with zeros).
  """
  if seed.shape:
    raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
  if not np.issubdtype(seed.dtype, np.integer):
    raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
  convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
  k1 = convert(
      lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  with jax.numpy_dtype_promotion('standard'):
    # TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
    # inputs. We should avoid this.
    k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
  return lax.concatenate([k1, k2], 0)
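
# For example (illustrative): a 64-bit seed 0x123456789 maps to the raw key
# [0x1, 0x23456789] as uint32[2]: the high word becomes k1 and the low word
# becomes k2.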


def _make_rotate_left(dtype):
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError("_rotate_left only accepts integer dtypes.")
  nbits = np.array(jnp.iinfo(dtype).bits, dtype)

  def _rotate_left(x, d):
    if lax.dtype(d) != dtype:
      d = lax.convert_element_type(d, dtype)
    if lax.dtype(x) != dtype:
      x = lax.convert_element_type(x, dtype)
    return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)
  return _rotate_left


def _bit_stats(bits):
  """This is a debugging function to compute the statistics of bit fields."""
  return np.array([list(map(int, np.binary_repr(x, 64))) for x in bits]).mean(0)


### hash function and split

def _threefry2x32_abstract_eval(*args):
  if any(a.dtype != jnp.uint32 for a in args):
    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                    .format(args))
  if all(isinstance(arg, core.ShapedArray) for arg in args):
    shape = lax_internal.broadcasting_shape_rule(*args)
    named_shape = core.join_named_shapes(*(a.named_shape for a in args))
    aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
  else:
    aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
  return (aval,) * 2


rotate_left = _make_rotate_left(np.uint32)


def apply_round(v, rot):
  v = v[:]
  v[0] = v[0] + v[1]
  v[1] = rotate_left(v[1], rot)
  v[1] = v[0] ^ v[1]
  return v


def rotate_list(xs):
  return xs[1:] + xs[:1]


def rolled_loop_step(i, state):
  x, ks, rotations = state
  for r in rotations[0]:
    x = apply_round(x, r)
  new_x = [x[0] + ks[0], x[1] + ks[1] + jnp.asarray(i + 1, dtype=np.uint32)]
  return new_x, rotate_list(ks), rotate_list(rotations)


def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
  """Apply the Threefry 2x32 hash.

  Args:
    key1, key2: uint32 arrays holding the two halves of the key.
    x1, x2: uint32 arrays holding the two halves of the counts.
    use_rolled_loops: whether to apply the rounds via a `fori_loop`
      rather than as unrolled straight-line code.

  Returns:
    A pair of uint32 arrays holding the two halves of the hashed output.
  """
  x = [x1, x2]

  rotations = [np.array([13, 15, 26, 6], dtype=np.uint32),
               np.array([17, 29, 16, 24], dtype=np.uint32)]
  ks = [key1, key2, key1 ^ key2 ^ np.uint32(0x1BD11BDA)]

  x[0] = x[0] + ks[0]
  x[1] = x[1] + ks[1]

  if use_rolled_loops:
    x, _, _ = lax.fori_loop(0, 5, rolled_loop_step, (x, rotate_list(ks), rotations))

  else:
    for r in rotations[0]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[1]
    x[1] = x[1] + ks[2] + np.uint32(1)

    for r in rotations[1]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[2]
    x[1] = x[1] + ks[0] + np.uint32(2)

    for r in rotations[0]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[0]
    x[1] = x[1] + ks[1] + np.uint32(3)

    for r in rotations[1]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[1]
    x[1] = x[1] + ks[2] + np.uint32(4)

    for r in rotations[0]:
      x = apply_round(x, r)
    x[0] = x[0] + ks[2]
    x[1] = x[1] + ks[0] + np.uint32(5)

  return tuple(x)


def _threefry2x32_gpu_lowering(threefry2x32_lowering, ctx, k1, k2, x1, x2):
  aval_out, _ = ctx.avals_out
  k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
  rank = len(aval_out.shape)
  if 0 in aval_out.shape:
    zeros = mlir.full_like_aval(0, aval_out)
    return [zeros, zeros]
  def _broadcast(x, aval):
    return mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(aval_out), x,
        mlir.dense_int_elements(range(rank - len(aval.shape), rank))).result
  return threefry2x32_lowering(
      (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
      (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))


threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=False),
    multiple_results=True))
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
    partial(_threefry2x32_lowering, use_rolled_loops=True),
    multiple_results=True), platform='cpu')
mlir.register_lowering(
    threefry2x32_p,
    partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32),
    platform='cuda')
mlir.register_lowering(
    threefry2x32_p,
    partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32),
    platform='rocm')


@partial(jit, inline=True)
def threefry_2x32(keypair, count):
  """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.

  Returns:
    An array of dtype uint32 with the same shape as `count`.
  """
  key1, key2 = keypair
  if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == np.uint32:
    msg = "threefry_2x32 requires uint32 arguments, got {}"
    raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))

  try:
    odd_size = count.size % 2
  except core.InconclusiveDimensionOperation as e:
    msg = ("jax.random functions have limited support for shape polymorphism. "
           "In particular, the product of the known dimensions must be even.")
    raise core.InconclusiveDimensionOperation(msg) from e

  if odd_size:
    x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
  else:
    x = list(jnp.split(count.ravel(), 2))

  x = threefry2x32_p.bind(key1, key2, x[0], x[1])
  out = jnp.concatenate(x)
  assert out.dtype == np.uint32
  return lax.reshape(out[:-1] if odd_size else out, count.shape)


def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
  return _threefry_split(key, int(num))  # type: ignore

@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split(key, num) -> jnp.ndarray:
  counts = lax.iota(np.uint32, num * 2)
  return lax.reshape(threefry_2x32(key, counts), (num, 2))
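
# For example (illustrative): _threefry_split(key, 2) hashes the counter
# values 0..3 under `key` and reshapes the resulting uint32[4] into two raw
# keys of shape (2, 2).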


def threefry_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
  assert not data.shape
  return _threefry_fold_in(key, jnp.uint32(data))

@partial(jit, inline=True)
def _threefry_fold_in(key, data):
  return threefry_2x32(key, threefry_seed(data))


def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key."""
  if not _is_threefry_prng_key(key):
    raise TypeError("threefry_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")

  if (config.jax_threefry_partitionable and bit_width == 32 and
      not any(core.is_special_dim_size(d) for d in shape)):
    return _threefry_random_bits_partitionable(key, bit_width, shape)
  else:
    return _threefry_random_bits_original(key, bit_width, shape)

def _threefry_random_bits_partitionable(key: jnp.ndarray, bit_width, shape):
  if all(core.is_constant_dim(d) for d in shape) and prod(shape) > 2 ** 64:
    raise NotImplementedError('random bits array of size exceeding 2 ** 64')

  size = prod(shape)
  n, r = divmod(bit_width * size, 32)
  if r > 0:
    n += 1
  even_size = n + (n % 2)

  if not shape:
    counts = jnp.arange(n, dtype=jnp.uint32).reshape(shape)
  else:
    iotas = [lax.broadcasted_iota(jnp.dtype('uint32'), shape, i)
             for i in range(len(shape))]
    strides = (*map(int, np.cumprod(shape[1:][::-1])[::-1]), 1)
    counts = sum(s * i for i, s in zip(iotas, strides))  # type: ignore
  circ0 = counts % (even_size // 2)
  circ1 = (circ0 + even_size // 2) % n
  k1, k2 = key
  bits_xx, bits_yy = threefry2x32_p.bind(k1, k2, circ0, circ1)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    assert n == even_size
    assert False  # broken...
    bits_x, bits_y = bits_xx[:size // 2], bits_yy[:size // 2]
    bits_x = lax.convert_element_type(bits_x, dtype)
    bits_y = lax.convert_element_type(bits_y, dtype)
    bits = lax.shift_left(bits_x, dtype(32)) | bits_y
  else:
    bits = jnp.where(counts < even_size // 2, bits_xx, bits_yy)
    if bit_width != 32:
      assert False  # broken...
      bits = bits.view(dtype)

  return bits

@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits_original(key: jnp.ndarray, bit_width, shape):
  size = prod(shape)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    bits = lax.bitwise_and(
        np.uint32(np.iinfo(dtype).max),
        lax.shift_right_logical(
            lax.broadcast(bits, (1,)),
            lax.mul(
                np.uint32(bit_width),
                lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0)
            )
        )
    )
    bits = lax.reshape(bits, ((max_count * 32 // bit_width),), (1, 0))
  bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)


threefry_prng_impl = PRNGImpl(
    key_shape=(2,),
    seed=threefry_seed,
    split=threefry_split,
    random_bits=threefry_random_bits,
    fold_in=threefry_fold_in,
    tag='fry')


# -- RngBitGenerator PRNG implementation

# This code is experimental!
# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
# Notice that the RngBitGenerator operations are not guaranteed to be
# stable/deterministic across backends or compiler versions. Correspondingly,
# we reserve the right to change any of these implementations at any time!

def _rbg_seed(seed: jnp.ndarray) -> jnp.ndarray:
  assert not seed.shape
  halfkey = threefry_seed(seed)
  return jnp.concatenate([halfkey, halfkey])


def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
  return vmap(_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)

def _rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
  assert not data.shape
  return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)

def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
                     ) -> jnp.ndarray:
  if not (key.shape == (4,) and key.dtype == jnp.dtype('uint32')):
    raise TypeError("_rbg_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  _, bits = lax.rng_bit_generator(key, shape, dtype=UINT_DTYPES[bit_width])
  return bits

rbg_prng_impl = PRNGImpl(
    key_shape=(4,),
    seed=_rbg_seed,
    split=_rbg_split,
    random_bits=_rbg_random_bits,
    fold_in=_rbg_fold_in,
    tag='rbg')


def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
  # treat 10 iterations of random bits as a 'hash function'
  _, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
  return keys[::10]

def _unsafe_rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
  assert not data.shape
  _, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
  return key ^ random_bits[-1]

unsafe_rbg_prng_impl = PRNGImpl(
    key_shape=(4,),
    seed=_rbg_seed,
    split=_unsafe_rbg_split,
    random_bits=_rbg_random_bits,
    fold_in=_unsafe_rbg_fold_in,
    tag='urbg')