mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 12:36:07 +00:00

Prefer jnp.tile over concatenate. jnp.tile generates a jaxpr like the following: ``` { lambda ; a:i32[720192]. let b:i32[1,720192] = reshape[dimensions=None new_sizes=(1, 720192)] a c:i32[720192] = squeeze[dimensions=(0,)] b d:i32[2,720192] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(2, 720192) ] c e:i32[1440384] = reshape[dimensions=None new_sizes=(1440384,)] d in (e,) } ``` whereas lax.concatenate generates the following jaxpr: ``` { lambda ; a:i32[720192]. let b:i32[1440384] = concatenate[dimension=0] a a in (b,) } ``` It seems the TPU compiler isn't doing as good a job with laying out memory for the formulation with `jnp.tile`. `reshape` in particular can be difficult for it to handle well, and it's best to avoid it when possible. Since the benefit was marginal (a simpler jaxpr... but is it? Really?) and the cost is real (a user's model broke), we should revert this change. PiperOrigin-RevId: 444287005
587 lines
20 KiB
Python
587 lines
20 KiB
Python
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
from functools import partial
|
|
from typing import Callable, Iterator, NamedTuple, Sequence
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
from jax import lax
|
|
from jax import core
|
|
from jax import numpy as jnp
|
|
from jax import tree_util
|
|
from jax.config import config
|
|
from jax.dtypes import float0
|
|
from jax.interpreters import batching
|
|
from jax.interpreters import mlir
|
|
from jax.interpreters import xla
|
|
from jax._src.api import jit, vmap
|
|
from jax._src.lax import lax as lax_internal
|
|
from jax._src.lib import cuda_prng
|
|
from jax._src.lib.mlir.dialects import mhlo
|
|
from jax._src.numpy.lax_numpy import (
|
|
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
|
|
_expand_bool_indices, _register_stackable)
|
|
from jax._src.lib import hip_prng
|
|
import jax._src.pretty_printer as pp
|
|
from jax._src.util import canonicalize_axis, prod
|
|
|
|
|
|
UINT_DTYPES = {
|
|
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
|
|
|
|
# -- PRNG implementation interface --
|
|
|
|
class PRNGImpl(NamedTuple):
|
|
"""Specifies PRNG key shape and operations.
|
|
|
|
A PRNG implementation is determined by a key type ``K`` and a
|
|
collection of functions that operate on such keys. The key type
|
|
``K`` is an array type with element type uint32 and shape specified
|
|
by ``key_shape``. The type signature of each operations is::
|
|
|
|
seed :: int[] -> K
|
|
fold_in :: K -> int[] -> K
|
|
split[n] :: K -> K[n]
|
|
random_bits[shape, bit_width] :: K -> uint<bit_width>[shape]
|
|
|
|
A PRNG implementation is adapted to an array-like object of keys
|
|
``K`` by the ``PRNGKeyArray`` class, which should be created via the
|
|
``seed_with_impl`` function.
|
|
"""
|
|
key_shape: core.Shape
|
|
seed: Callable
|
|
split: Callable
|
|
random_bits: Callable
|
|
fold_in: Callable
|
|
|
|
def pprint(self):
|
|
return (pp.text(f"{self.__class__.__name__}:") +
|
|
pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
|
|
pp.text(f"{k} = {v}") for k, v in self._asdict().items()
|
|
]))))
|
|
|
|
|
|
# -- PRNG key arrays --
|
|
|
|
def _check_prng_key_data(impl, key_data: jnp.ndarray):
|
|
ndim = len(impl.key_shape)
|
|
if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
|
|
raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
|
|
f"to have ndim, shape, and dtype attributes. Got {key_data}")
|
|
if key_data.ndim < 1:
|
|
raise TypeError("JAX encountered invalid PRNG key data: expected "
|
|
f"key_data.ndim >= 1; got ndim={key_data.ndim}")
|
|
if key_data.shape[-ndim:] != impl.key_shape:
|
|
raise TypeError("JAX encountered invalid PRNG key data: expected key_data.shape to "
|
|
f"end with {impl.key_shape}; got shape={key_data.shape} for impl={impl}")
|
|
if key_data.dtype not in [np.uint32, float0]:
|
|
raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
|
|
f"got dtype={key_data.dtype}")
|
|
|
|
|
|
@tree_util.register_pytree_node_class
|
|
class PRNGKeyArray:
|
|
"""An array whose elements are PRNG keys.
|
|
|
|
This class lifts the definition of a PRNG, provided in the form of a
|
|
``PRNGImpl``, into an array-like pytree class. Instances of this
|
|
class behave like an array whose base elements are keys, hiding the
|
|
fact that keys are typically arrays (of ``uint32`` dtype) themselves.
|
|
|
|
PRNGKeyArrays are also restricted relative to JAX arrays in that
|
|
they do not expose arithmetic operations. They instead expose
|
|
wrapper methods around the PRNG implementation functions (``split``,
|
|
``random_bits``, ``fold_in``).
|
|
"""
|
|
|
|
impl: PRNGImpl
|
|
_keys: jnp.ndarray
|
|
|
|
def __init__(self, impl, key_data: jnp.ndarray):
|
|
# key_data might be a placeholder python `object` or `bool`
|
|
# instead of a jnp.ndarray due to tree_unflatten
|
|
if type(key_data) not in [object, bool]:
|
|
_check_prng_key_data(impl, key_data)
|
|
self.impl = impl
|
|
self._keys = key_data
|
|
|
|
def tree_flatten(self):
|
|
return (self._keys,), self.impl
|
|
|
|
def unsafe_raw_array(self):
|
|
"""Access the raw numerical array that carries underlying key data.
|
|
|
|
Returns:
|
|
A uint32 JAX array whose leading dimensions are ``self.shape``.
|
|
"""
|
|
return self._keys
|
|
|
|
@classmethod
|
|
def tree_unflatten(cls, impl, keys):
|
|
keys, = keys
|
|
return cls(impl, keys)
|
|
|
|
@property
|
|
def dtype(self):
|
|
# TODO(frostig): remove after deprecation window
|
|
if config.jax_enable_custom_prng:
|
|
raise AttributeError("'PRNGKeyArray' has no attribute 'dtype'")
|
|
else:
|
|
warnings.warn(
|
|
'deprecated `dtype` attribute of PRNG key arrays', FutureWarning)
|
|
return np.uint32
|
|
|
|
@property
|
|
def shape(self):
|
|
# TODO(frostig): simplify once we always enable_custom_prng
|
|
if config.jax_enable_custom_prng:
|
|
return self._shape
|
|
else:
|
|
warnings.warn(
|
|
'deprecated `shape` attribute of PRNG key arrays. In a future version '
|
|
'of JAX this attribute will be removed or its value may change.',
|
|
FutureWarning)
|
|
return self._keys.shape
|
|
|
|
@property
|
|
def _shape(self):
|
|
base_ndim = len(self.impl.key_shape)
|
|
return self._keys.shape[:-base_ndim]
|
|
|
|
@property
|
|
def ndim(self):
|
|
return len(self.shape)
|
|
|
|
def _is_scalar(self):
|
|
base_ndim = len(self.impl.key_shape)
|
|
return self._keys.ndim == base_ndim
|
|
|
|
def __len__(self):
|
|
if self._is_scalar():
|
|
raise TypeError('len() of unsized object')
|
|
return len(self._keys)
|
|
|
|
def __iter__(self) -> Iterator['PRNGKeyArray']:
|
|
if self._is_scalar():
|
|
raise TypeError('iteration over a 0-d single PRNG key')
|
|
return (PRNGKeyArray(self.impl, k) for k in iter(self._keys))
|
|
|
|
def __getitem__(self, idx) -> 'PRNGKeyArray':
|
|
base_ndim = len(self.impl.key_shape)
|
|
ndim = self._keys.ndim - base_ndim
|
|
indexable_shape = self.impl.key_shape[:ndim]
|
|
idx = _eliminate_deprecated_list_indexing(idx)
|
|
idx = _expand_bool_indices(idx, indexable_shape)
|
|
idx = _canonicalize_tuple_index(ndim, idx, array_name='PRNGKeyArray')
|
|
return PRNGKeyArray(self.impl, self._keys[idx])
|
|
|
|
def _fold_in(self, data: int) -> 'PRNGKeyArray':
|
|
return PRNGKeyArray(self.impl, self.impl.fold_in(self._keys, data))
|
|
|
|
def _random_bits(self, bit_width, shape) -> jnp.ndarray:
|
|
return self.impl.random_bits(self._keys, bit_width, shape)
|
|
|
|
def _split(self, num: int) -> 'PRNGKeyArray':
|
|
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
|
|
|
|
def reshape(self, newshape, order=None):
|
|
reshaped_keys = jnp.reshape(self._keys, (*newshape, -1), order=order)
|
|
return PRNGKeyArray(self.impl, reshaped_keys)
|
|
|
|
def concatenate(self, key_arrs, axis):
|
|
axis = canonicalize_axis(axis, self.ndim)
|
|
arrs = [self._keys, *[k._keys for k in key_arrs]]
|
|
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))
|
|
|
|
def broadcast_to(self, shape):
|
|
if jnp.ndim(shape) == 0:
|
|
shape = (shape,)
|
|
new_shape = (*shape, *self.impl.key_shape)
|
|
return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape))
|
|
|
|
def expand_dims(self, dimensions: Sequence[int]):
|
|
# follows lax.expand_dims, not jnp.expand_dims, so dimensions is a sequence
|
|
ndim_out = self.ndim + len(set(dimensions))
|
|
dimensions = [canonicalize_axis(d, ndim_out) for d in dimensions]
|
|
return PRNGKeyArray(self.impl, lax.expand_dims(self._keys, dimensions))
|
|
|
|
def __repr__(self):
|
|
arr_shape = self._shape
|
|
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
|
|
pp_impl = pp.text('impl = ') + self.impl.pprint()
|
|
return str(pp.group(
|
|
pp.text('PRNGKeyArray:') +
|
|
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
|
|
|
|
|
|
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
|
|
return PRNGKeyArray(impl, impl.seed(seed))
|
|
|
|
_register_stackable(PRNGKeyArray)
|
|
|
|
# -- threefry2x32 PRNG implementation --
|
|
|
|
|
|
def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
|
|
try:
|
|
return key.shape == (2,) and key.dtype == np.uint32
|
|
except AttributeError:
|
|
return False
|
|
|
|
|
|
def threefry_seed(seed: int) -> jnp.ndarray:
|
|
"""Create a single raw threefry PRNG key given an integer seed.
|
|
|
|
Args:
|
|
seed: a 64- or 32-bit integer used as the value of the key.
|
|
|
|
Returns:
|
|
The PRNG key contents, modeled as an array of shape (2,) and dtype
|
|
uint32. The key is constructed from a 64-bit seed by effectively
|
|
bit-casting to a pair of uint32 values (or from a 32-bit seed by
|
|
first padding out with zeros).
|
|
"""
|
|
# Avoid overflowerror in X32 mode by first converting ints to int64.
|
|
# This breaks JIT invariance for large ints, but supports the common
|
|
# use-case of instantiating with Python hashes in X32 mode.
|
|
if isinstance(seed, int):
|
|
seed_arr = jnp.asarray(np.int64(seed))
|
|
else:
|
|
seed_arr = jnp.asarray(seed)
|
|
if seed_arr.shape:
|
|
raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
|
|
if not np.issubdtype(seed_arr.dtype, np.integer):
|
|
raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
|
|
|
|
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
|
|
k1 = convert(
|
|
lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
|
|
k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
|
|
return lax.concatenate([k1, k2], 0)
|
|
|
|
|
|
def _make_rotate_left(dtype):
|
|
if not jnp.issubdtype(dtype, np.integer):
|
|
raise TypeError("_rotate_left only accepts integer dtypes.")
|
|
nbits = np.array(jnp.iinfo(dtype).bits, dtype)
|
|
|
|
def _rotate_left(x, d):
|
|
if lax.dtype(d) != dtype:
|
|
d = lax.convert_element_type(d, dtype)
|
|
if lax.dtype(x) != dtype:
|
|
x = lax.convert_element_type(x, dtype)
|
|
return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)
|
|
return _rotate_left
|
|
|
|
|
|
def _bit_stats(bits):
|
|
"""This is a debugging function to compute the statistics of bit fields."""
|
|
return np.array([list(map(int, np.binary_repr(x, 64))) for x in bits]).mean(0)
|
|
|
|
|
|
### hash function and split
|
|
|
|
def _threefry2x32_abstract_eval(*args):
|
|
if any(a.dtype != jnp.uint32 for a in args):
|
|
raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
|
|
.format(args))
|
|
if all(isinstance(arg, core.ShapedArray) for arg in args):
|
|
shape = lax_internal._broadcasting_shape_rule(*args)
|
|
named_shape = core.join_named_shapes(*(a.named_shape for a in args))
|
|
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
|
|
else:
|
|
aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
|
|
return (aval,) * 2
|
|
|
|
|
|
rotate_left = _make_rotate_left(np.uint32)
|
|
|
|
|
|
def apply_round(v, rot):
|
|
v = v[:]
|
|
v[0] = v[0] + v[1]
|
|
v[1] = rotate_left(v[1], rot)
|
|
v[1] = v[0] ^ v[1]
|
|
return v
|
|
|
|
|
|
def rotate_list(xs):
|
|
return xs[1:] + xs[:1]
|
|
|
|
|
|
def rolled_loop_step(i, state):
|
|
x, ks, rotations = state
|
|
for r in rotations[0]:
|
|
x = apply_round(x, r)
|
|
new_x = [x[0] + ks[0], x[1] + ks[1] + jnp.asarray(i + 1, dtype=np.uint32)]
|
|
return new_x, rotate_list(ks), rotate_list(rotations)
|
|
|
|
|
|
def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
|
|
"""Apply the Threefry 2x32 hash.
|
|
|
|
Args:
|
|
keypair: a pair of 32bit unsigned integers used for the key.
|
|
count: an array of dtype uint32 used for the counts.
|
|
|
|
Returns:
|
|
An array of dtype uint32 with the same shape as `count`.
|
|
"""
|
|
x = [x1, x2]
|
|
|
|
rotations = [np.array([13, 15, 26, 6], dtype=np.uint32),
|
|
np.array([17, 29, 16, 24], dtype=np.uint32)]
|
|
ks = [key1, key2, key1 ^ key2 ^ np.uint32(0x1BD11BDA)]
|
|
|
|
x[0] = x[0] + ks[0]
|
|
x[1] = x[1] + ks[1]
|
|
|
|
if use_rolled_loops:
|
|
x, _, _ = lax.fori_loop(0, 5, rolled_loop_step, (x, rotate_list(ks), rotations))
|
|
|
|
else:
|
|
for r in rotations[0]:
|
|
x = apply_round(x, r)
|
|
x[0] = x[0] + ks[1]
|
|
x[1] = x[1] + ks[2] + np.uint32(1)
|
|
|
|
for r in rotations[1]:
|
|
x = apply_round(x, r)
|
|
x[0] = x[0] + ks[2]
|
|
x[1] = x[1] + ks[0] + np.uint32(2)
|
|
|
|
for r in rotations[0]:
|
|
x = apply_round(x, r)
|
|
x[0] = x[0] + ks[0]
|
|
x[1] = x[1] + ks[1] + np.uint32(3)
|
|
|
|
for r in rotations[1]:
|
|
x = apply_round(x, r)
|
|
x[0] = x[0] + ks[1]
|
|
x[1] = x[1] + ks[2] + np.uint32(4)
|
|
|
|
for r in rotations[0]:
|
|
x = apply_round(x, r)
|
|
x[0] = x[0] + ks[2]
|
|
x[1] = x[1] + ks[0] + np.uint32(5)
|
|
|
|
return tuple(x)
|
|
|
|
|
|
def _threefry2x32_gpu_lowering(ctx, k1, k2, x1, x2):
|
|
aval_out, _ = ctx.avals_out
|
|
k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
|
|
rank = len(aval_out.shape)
|
|
if 0 in aval_out.shape:
|
|
zeros = mlir.full_like_aval(0, aval_out)
|
|
return [zeros, zeros]
|
|
def _broadcast(x, aval):
|
|
return mhlo.BroadcastInDimOp(
|
|
mlir.aval_to_ir_type(aval_out), x,
|
|
mlir.dense_int_elements(range(rank - len(aval.shape), rank))).result
|
|
if cuda_prng:
|
|
return cuda_prng.threefry2x32_lowering(
|
|
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
|
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))
|
|
else:
|
|
return hip_prng.threefry2x32_lowering(
|
|
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
|
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)))
|
|
|
|
|
|
threefry2x32_p = core.Primitive("threefry2x32")
|
|
threefry2x32_p.multiple_results = True
|
|
threefry2x32_p.def_impl(partial(xla.apply_primitive, threefry2x32_p))
|
|
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
|
|
batching.defbroadcasting(threefry2x32_p)
|
|
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
|
|
partial(_threefry2x32_lowering, use_rolled_loops=False),
|
|
multiple_results=True))
|
|
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
|
|
partial(_threefry2x32_lowering, use_rolled_loops=True),
|
|
multiple_results=True), platform='cpu')
|
|
|
|
if cuda_prng or hip_prng:
|
|
mlir.register_lowering(threefry2x32_p, _threefry2x32_gpu_lowering,
|
|
platform='gpu')
|
|
|
|
@partial(jit, inline=True)
|
|
def threefry_2x32(keypair, count):
|
|
"""Apply the Threefry 2x32 hash.
|
|
|
|
Args:
|
|
keypair: a pair of 32bit unsigned integers used for the key.
|
|
count: an array of dtype uint32 used for the counts.
|
|
|
|
Returns:
|
|
An array of dtype uint32 with the same shape as `count`.
|
|
"""
|
|
key1, key2 = keypair
|
|
if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == np.uint32:
|
|
msg = "threefry_2x32 requires uint32 arguments, got {}"
|
|
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
|
|
|
|
try:
|
|
odd_size = count.size % 2
|
|
except core.InconclusiveDimensionOperation as e:
|
|
msg = ("jax.random functions have limited support for shape polymorphism. "
|
|
"In particular, the product of the known dimensions must be even.")
|
|
raise core.InconclusiveDimensionOperation(msg) from e
|
|
|
|
if odd_size:
|
|
x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
|
|
else:
|
|
x = list(jnp.split(count.ravel(), 2))
|
|
|
|
x = threefry2x32_p.bind(key1, key2, x[0], x[1])
|
|
out = jnp.concatenate(x)
|
|
assert out.dtype == np.uint32
|
|
return lax.reshape(out[:-1] if odd_size else out, count.shape)
|
|
|
|
|
|
def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
|
return _threefry_split(key, int(num)) # type: ignore
|
|
|
|
@partial(jit, static_argnums=(1,), inline=True)
|
|
def _threefry_split(key, num) -> jnp.ndarray:
|
|
counts = lax.iota(np.uint32, num * 2)
|
|
return lax.reshape(threefry_2x32(key, counts), (num, 2))
|
|
|
|
|
|
def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
|
return _threefry_fold_in(key, jnp.uint32(data))
|
|
|
|
@partial(jit, inline=True)
|
|
def _threefry_fold_in(key, data):
|
|
return threefry_2x32(key, threefry_seed(data))
|
|
|
|
|
|
@partial(jit, static_argnums=(1, 2), inline=True)
|
|
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
|
|
"""Sample uniform random bits of given width and shape using PRNG key."""
|
|
if not _is_threefry_prng_key(key):
|
|
raise TypeError("threefry_random_bits got invalid prng key.")
|
|
if bit_width not in (8, 16, 32, 64):
|
|
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
|
shape = core.as_named_shape(shape)
|
|
for name, size in shape.named_items:
|
|
real_size = lax.psum(1, name)
|
|
if real_size != size:
|
|
raise ValueError(f"The shape of axis {name} was specified as {size}, "
|
|
f"but it really is {real_size}")
|
|
axis_index = lax.axis_index(name)
|
|
key = threefry_fold_in(key, axis_index)
|
|
size = prod(shape.positional)
|
|
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
|
|
# polymorphism
|
|
max_count, r = divmod(bit_width * size, 32)
|
|
if r > 0:
|
|
max_count += 1
|
|
|
|
if core.is_constant_dim(max_count):
|
|
nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
|
|
else:
|
|
nblocks, rem = 0, max_count
|
|
|
|
if not nblocks:
|
|
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
|
|
else:
|
|
keys = threefry_split(key, nblocks + 1)
|
|
subkeys, last_key = keys[:-1], keys[-1]
|
|
blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
|
|
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
|
|
bits = lax.concatenate([blocks.ravel(), last], 0)
|
|
|
|
dtype = UINT_DTYPES[bit_width]
|
|
if bit_width == 64:
|
|
bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
|
|
bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
|
|
elif bit_width in [8, 16]:
|
|
# this is essentially bits.view(dtype)[:size]
|
|
bits = lax.bitwise_and(
|
|
np.uint32(np.iinfo(dtype).max),
|
|
lax.shift_right_logical(
|
|
lax.broadcast(bits, (1,)),
|
|
lax.mul(
|
|
np.uint32(bit_width),
|
|
lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0)
|
|
)
|
|
)
|
|
)
|
|
bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
|
|
bits = lax.convert_element_type(bits, dtype)[:size]
|
|
return lax.reshape(bits, shape)
|
|
|
|
|
|
threefry_prng_impl = PRNGImpl(
|
|
key_shape=(2,),
|
|
seed=threefry_seed,
|
|
split=threefry_split,
|
|
random_bits=threefry_random_bits,
|
|
fold_in=threefry_fold_in)
|
|
|
|
|
|
# -- RngBitGenerator PRNG implementation --
|
|
|
|
# This code is experimental!
|
|
# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
|
|
# Notice that the RngBitGenerator operations are not guaranteed to be
|
|
# stable/deterministic across backends or compiler versions. Correspondingly, we
|
|
# reserve the right to change any of these implementations at any time!
|
|
|
|
def _rbg_seed(seed: int) -> jnp.ndarray:
|
|
halfkey = threefry_seed(seed)
|
|
return jnp.concatenate([halfkey, halfkey])
|
|
|
|
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
|
return vmap(_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
|
|
|
|
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
|
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)
|
|
|
|
def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
|
|
) -> jnp.ndarray:
|
|
if not key.shape == (4,) and key.dtype == jnp.dtype('uint32'):
|
|
raise TypeError("_rbg_random_bits got invalid prng key.")
|
|
if bit_width not in (8, 16, 32, 64):
|
|
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
|
_, bits = lax.rng_bit_generator(key, shape, dtype=UINT_DTYPES[bit_width])
|
|
return bits
|
|
|
|
rbg_prng_impl = PRNGImpl(
|
|
key_shape=(4,),
|
|
seed=_rbg_seed,
|
|
split=_rbg_split,
|
|
random_bits=_rbg_random_bits,
|
|
fold_in=_rbg_fold_in)
|
|
|
|
def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
|
# treat 10 iterations of random bits as a 'hash function'
|
|
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
|
|
return keys[::10]
|
|
|
|
def _unsafe_rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
|
_, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
|
|
return key ^ random_bits[-1]
|
|
|
|
unsafe_rbg_prng_impl = PRNGImpl(
|
|
key_shape=(4,),
|
|
seed=_rbg_seed,
|
|
split=_unsafe_rbg_split,
|
|
random_bits=_rbg_random_bits,
|
|
fold_in=_unsafe_rbg_fold_in)
|