Internal change

PiperOrigin-RevId: 468712508
jax authors 2022-08-19 08:56:17 -07:00
parent ff7cd9a136
commit a6c6416872
18 changed files with 258 additions and 1122 deletions

View File

@ -1107,9 +1107,6 @@ def _check_scalar(x):
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
_check_arg(x)
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
raise TypeError(
f"{name} with input element type {core.aval_eltype(aval).name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
@ -1128,9 +1125,6 @@ _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
raise TypeError(
f"{name} with output element type {core.aval_eltype(aval).name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
@ -1208,9 +1202,6 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
_check_arg(x)
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
raise TypeError(
f"jacfwd with input element type {core.aval_eltype(aval).name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
@ -2966,7 +2957,7 @@ class ShapeDtypeStruct:
__slots__ = ["shape", "dtype", "named_shape"]
def __init__(self, shape, dtype, named_shape=None):
self.shape = shape
self.dtype = dtype if core.is_custom_eltype(dtype) else np.dtype(dtype)
self.dtype = np.dtype(dtype)
self.named_shape = {} if named_shape is None else dict(named_shape)
size = property(lambda self: prod(self.shape))
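For reference, a minimal sketch of how ShapeDtypeStruct is typically used with jax.eval_shape (standard JAX API, shown purely as illustration):

import jax
import jax.numpy as jnp

# Abstract stand-in for a (3, 4) float32 array; no buffer is allocated.
x = jax.ShapeDtypeStruct((3, 4), jnp.float32)
out = jax.eval_shape(lambda a: a @ a.T, x)
print(out.shape, out.dtype)  # (3, 3) float32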

View File

@ -108,14 +108,9 @@ def apply_primitive(prim, *args, **params):
**params)
return compiled_fun(*args)
# TODO(phawkins,frostig,mattjj): update code referring to
# xla.apply_primitive to point here, or use simple_impl if that's why
# it is using apply_primitive to begin with
# TODO(phawkins): update code referring to xla.apply_primitive to point here.
xla.apply_primitive = apply_primitive
def simple_impl(prim):
prim.def_impl(partial(apply_primitive, prim))
RuntimeToken = Any
class RuntimeTokenSet(threading.local):

View File

@ -1574,12 +1574,9 @@ def _pred_bcast_select_mhlo(
assert x.type == y.type, (x.type, y.type)
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
pred_aval.shape, x_y_aval)
x_y_type = mlir.aval_to_ir_type(x_y_aval)
bcast_pred_type = ir.RankedTensorType.get(
x_y_type.shape, mlir.dtype_to_ir_type(np.dtype(np.bool_)))
bcast_pred = mhlo.BroadcastInDimOp(
bcast_pred_type, pred,
mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
return mhlo.SelectOp(bcast_pred, x, y).results
### fori_loop

View File

@ -1239,14 +1239,11 @@ def stop_gradient(x: T) -> T:
DeviceArray(0., dtype=float32, weak_type=True)
"""
def stop(x):
# only bind primitive on inexact dtypes, to avoid some staging
if core.has_custom_eltype(x):
return x
elif (dtypes.issubdtype(_dtype(x), np.floating) or
if (dtypes.issubdtype(_dtype(x), np.floating) or
dtypes.issubdtype(_dtype(x), np.complexfloating)):
return ad_util.stop_gradient_p.bind(x)
else:
return x
return x # only bind primitive on inexact dtypes, to avoid some staging
return tree_map(stop, x)
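A quick illustration of the semantics implemented here (standard jax.lax.stop_gradient behavior):

import jax

def f(x):
  # the second factor is treated as a constant under differentiation
  return x * jax.lax.stop_gradient(x)

print(jax.grad(f)(2.0))  # 2.0, not 4.0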
def reduce_precision(operand: Union[float, Array],
@ -1507,7 +1504,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
return result_dtype(*avals)
def broadcasting_shape_rule(name, *avals):
def _broadcasting_shape_rule(name, *avals):
shapes = [aval.shape for aval in avals if aval.shape]
if not shapes:
return ()
@ -1548,7 +1545,7 @@ def _naryop_weak_type_rule(name, *avals, **kwargs):
def naryop(result_dtype, accepted_dtypes, name):
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name)
shape_rule = partial(broadcasting_shape_rule, name)
shape_rule = partial(_broadcasting_shape_rule, name)
weak_type_rule = partial(_naryop_weak_type_rule, name)
prim = standard_primitive(shape_rule, dtype_rule, name,
weak_type_rule=weak_type_rule)
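A worked example of the numpy-style broadcasting that the shape rule above implements (numpy used only for illustration):

import numpy as np

# Trailing axis sizes must match or be 1; missing leading axes count as 1.
print(np.broadcast_shapes((3, 1, 5), (4, 5)))  # (3, 4, 5)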

View File

@ -527,7 +527,7 @@ view of the input.
@_wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a, axes=None):
_stackable(a) or _check_arraylike("transpose", a)
_check_arraylike("transpose", a)
axes = np.arange(ndim(a))[::-1] if axes is None else axes
return lax.transpose(a, axes)
@ -5088,31 +5088,27 @@ _set_shaped_array_attributes(ShapedArray)
_set_shaped_array_attributes(DShapedArray)
def _set_device_array_base_attributes(device_array, include=None):
def _set_device_array_base_attributes(device_array):
# Forward operators, methods, and properties on DeviceArray to lax_numpy
# functions (with no Tracers involved; this forwarding is direct)
def maybe_setattr(attr_name, target):
if not include or attr_name in include:
setattr(device_array, attr_name, target)
for operator_name, function in _operators.items():
maybe_setattr(f"__{operator_name}__", function)
setattr(device_array, f"__{operator_name}__", function)
for method_name in _nondiff_methods + _diff_methods:
maybe_setattr(method_name, globals()[method_name])
setattr(device_array, method_name, globals()[method_name])
# TODO(jakevdp): remove tile method after August 2022
maybe_setattr("tile", _deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead."))
maybe_setattr("reshape", _reshape)
maybe_setattr("transpose", _transpose)
maybe_setattr("flatten", ravel)
maybe_setattr("flat", property(_notimplemented_flat))
maybe_setattr("T", property(transpose))
maybe_setattr("real", property(real))
maybe_setattr("imag", property(imag))
maybe_setattr("astype", _astype)
maybe_setattr("view", _view)
maybe_setattr("nbytes", property(_nbytes))
maybe_setattr("itemsize", property(_itemsize))
maybe_setattr("clip", _clip)
setattr(device_array, "tile", _deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead."))
setattr(device_array, "reshape", _reshape)
setattr(device_array, "transpose", _transpose)
setattr(device_array, "flatten", ravel)
setattr(device_array, "flat", property(_notimplemented_flat))
setattr(device_array, "T", property(transpose))
setattr(device_array, "real", property(real))
setattr(device_array, "imag", property(imag))
setattr(device_array, "astype", _astype)
setattr(device_array, "view", _view)
setattr(device_array, "nbytes", property(_nbytes))
setattr(device_array, "itemsize", property(_itemsize))
setattr(device_array, "clip", _clip)
_set_device_array_base_attributes(device_array.DeviceArray)
_set_device_array_base_attributes(Array)
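An informal check of the forwarding set up above (illustrative, not part of this change):

import jax.numpy as jnp

arr = jnp.arange(6)
# reshape and T dispatch to the lax_numpy implementations registered above
print(arr.reshape(2, 3).T.shape)  # (3, 2)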

View File

@ -13,9 +13,8 @@
# limitations under the License.
import abc
from functools import partial
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence
from typing import Callable, Iterator, NamedTuple, Sequence
import warnings
import numpy as np
@ -29,29 +28,23 @@ from jax.config import config
from jax.dtypes import float0
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import dispatch
from jax._src import dtypes
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir.dialects import mhlo
from jax._src.numpy import lax_numpy
from jax._src.numpy.lax_numpy import (
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
_expand_bool_indices, _register_stackable)
import jax._src.pretty_printer as pp
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
from jax._src.util import canonicalize_axis, prod
from jax._src.lib import gpu_prng
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
UINT_DTYPES = {
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
# -- PRNG implementation interface
# -- PRNG implementation interface --
class PRNGImpl(NamedTuple):
"""Specifies PRNG key shape and operations.
@ -75,22 +68,15 @@ class PRNGImpl(NamedTuple):
split: Callable
random_bits: Callable
fold_in: Callable
tag: str = '?'
def __hash__(self) -> int:
return hash(self.tag)
def __str__(self) -> str:
return self.tag
def pprint(self):
return (pp.text(f"{self.__class__.__name__} [{self.tag}]:") +
return (pp.text(f"{self.__class__.__name__}:") +
pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
pp.text(f"{k} = {v}") for k, v in self._asdict().items()
]))))
# -- PRNG key arrays
# -- PRNG key arrays --
def _check_prng_key_data(impl, key_data: jnp.ndarray):
ndim = len(impl.key_shape)
@ -108,19 +94,8 @@ def _check_prng_key_data(impl, key_data: jnp.ndarray):
f"got dtype={key_data.dtype}")
class PRNGKeyArrayMeta(abc.ABCMeta):
"""Metaclass for overriding PRNGKeyArray isinstance checks."""
def __instancecheck__(self, instance):
try:
return (hasattr(instance, 'aval') and
isinstance(instance.aval, core.ShapedArray) and
type(instance.aval.dtype) is KeyTy)
except AttributeError:
super().__instancecheck__(instance)
class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
@tree_util.register_pytree_node_class
class PRNGKeyArray:
"""An array whose elements are PRNG keys.
This class lifts the definition of a PRNG, provided in the form of a
@ -135,30 +110,58 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
"""
impl: PRNGImpl
_base_array: jnp.ndarray
_keys: jnp.ndarray
def __init__(self, impl, key_data: Any):
assert not isinstance(key_data, core.Tracer)
_check_prng_key_data(impl, key_data)
def __init__(self, impl, key_data: jnp.ndarray):
# key_data might be a placeholder python `object` or `bool`
# instead of a jnp.ndarray due to tree_unflatten
if type(key_data) not in [object, bool]:
_check_prng_key_data(impl, key_data)
self.impl = impl
self._base_array = key_data
self._keys = key_data
def tree_flatten(self):
return (self._keys,), self.impl
# TODO(frostig): rename to unsafe_base_array, or just offer base_array attr?
def unsafe_raw_array(self):
"""Access the raw numerical array that carries underlying key data.
Returns:
A uint32 JAX array whose leading dimensions are ``self.shape``.
"""
return self._base_array
return self._keys
def block_until_ready(self):
_ = self._base_array.block_until_ready()
return self
@classmethod
def tree_unflatten(cls, impl, keys):
keys, = keys
return cls(impl, keys)
@property
def dtype(self):
# TODO(frostig): remove after deprecation window
if config.jax_enable_custom_prng:
raise AttributeError("'PRNGKeyArray' has no attribute 'dtype'")
else:
warnings.warn(
'deprecated `dtype` attribute of PRNG key arrays', FutureWarning)
return np.uint32
@property
def shape(self):
return base_arr_shape_to_keys_shape(self.impl, self._base_array.shape)
# TODO(frostig): simplify once we always enable_custom_prng
if config.jax_enable_custom_prng:
return self._shape
else:
warnings.warn(
'deprecated `shape` attribute of PRNG key arrays. In a future version '
'of JAX this attribute will be removed or its value may change.',
FutureWarning)
return self._keys.shape
@property
def _shape(self):
base_ndim = len(self.impl.key_shape)
return self._keys.shape[:-base_ndim]
@property
def ndim(self):
@ -166,527 +169,74 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
def _is_scalar(self):
base_ndim = len(self.impl.key_shape)
return self._base_array.ndim == base_ndim
return self._keys.ndim == base_ndim
def __len__(self):
if self._is_scalar():
raise TypeError('len() of unsized object')
return len(self._base_array)
return len(self._keys)
def __iter__(self) -> Iterator['PRNGKeyArray']:
if self._is_scalar():
raise TypeError('iteration over a 0-d key array')
# TODO(frostig): we may want to avoid iteration by slicing because
# a very common use of iteration is `k1, k2 = split(key)`, and
# slicing/indexing may be trickier to track for linearity checking
# purposes. Maybe we can:
# * introduce an unpack primitive+traceable (also allow direct use)
# * unpack upfront into shape[0] many keyarray slices
# * return iter over these unpacked slices
# Whatever we do, we'll want to do it by overriding
# ShapedArray._iter when the eltype is KeyTy...
return (PRNGKeyArray(self.impl, k) for k in iter(self._base_array))
raise TypeError('iteration over a 0-d single PRNG key')
return (PRNGKeyArray(self.impl, k) for k in iter(self._keys))
# TODO(frostig): are all of the stackable methods below (reshape,
# concat, broadcast_to, expand_dims), and the stackable registration,
# still needed? If, with some work, none are needed, then do we want
# to remove stackables altogether? This may be the only application.
def __getitem__(self, idx) -> 'PRNGKeyArray':
base_ndim = len(self.impl.key_shape)
ndim = self._keys.ndim - base_ndim
indexable_shape = self.impl.key_shape[:ndim]
idx = _eliminate_deprecated_list_indexing(idx)
idx = _expand_bool_indices(idx, indexable_shape)
idx = _canonicalize_tuple_index(ndim, idx, array_name='PRNGKeyArray')
return PRNGKeyArray(self.impl, self._keys[idx])
# TODO(frostig): Remove? Overwritten below in particular
def reshape(self, newshape, order=None) -> 'PRNGKeyArray':
reshaped_base = jnp.reshape(self._base_array, (*newshape, -1), order=order)
return PRNGKeyArray(self.impl, reshaped_base)
def _fold_in(self, data: int) -> 'PRNGKeyArray':
return PRNGKeyArray(self.impl, self.impl.fold_in(self._keys, data))
def _random_bits(self, bit_width, shape) -> jnp.ndarray:
return self.impl.random_bits(self._keys, bit_width, shape)
def _split(self, num: int) -> 'PRNGKeyArray':
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
def reshape(self, newshape, order=None):
reshaped_keys = jnp.reshape(self._keys, (*newshape, -1), order=order)
return PRNGKeyArray(self.impl, reshaped_keys)
def concatenate(self, key_arrs, axis, dtype=None):
if dtype is not None:
raise ValueError(
'dtype argument not supported for concatenating PRNGKeyArray')
raise ValueError('dtype argument not supported for concatenating PRNGKeyArray')
axis = canonicalize_axis(axis, self.ndim)
arrs = [self._base_array, *[k._base_array for k in key_arrs]]
arrs = [self._keys, *[k._keys for k in key_arrs]]
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))
def broadcast_to(self, shape):
if jnp.ndim(shape) == 0:
shape = (shape,)
new_shape = (*shape, *self.impl.key_shape)
return PRNGKeyArray(
self.impl, jnp.broadcast_to(self._base_array, new_shape))
return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape))
def expand_dims(self, dimensions: Sequence[int]):
# follows lax.expand_dims, not jnp.expand_dims, so dimensions is a sequence
ndim_out = self.ndim + len(set(dimensions))
dimensions = [canonicalize_axis(d, ndim_out) for d in dimensions]
return PRNGKeyArray(
self.impl, lax.expand_dims(self._base_array, dimensions))
return PRNGKeyArray(self.impl, lax.expand_dims(self._keys, dimensions))
def __repr__(self):
return (f'{self.__class__.__name__}[{self.impl.tag}]'
f' {{ {self._base_array} }}')
def pprint(self):
pp_keys = pp.text('shape = ') + pp.text(str(self.shape))
arr_shape = self._shape
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
pp_impl = pp.text('impl = ') + self.impl.pprint()
return str(pp.group(
pp.text('PRNGKeyArray:') +
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
# Hollow defs only for typing purposes, overwritten below
#
# TODO(frostig): there may be a better way to do this with
# `typing.type_check_only`.
@property
def T(self) -> 'PRNGKeyArray': assert False
def __getitem__(self, _) -> 'PRNGKeyArray': assert False
def ravel(self, *_, **__) -> 'PRNGKeyArray': assert False
def squeeze(self, *_, **__) -> 'PRNGKeyArray': assert False
def swapaxes(self, *_, **__) -> 'PRNGKeyArray': assert False
def take(self, *_, **__) -> 'PRNGKeyArray': assert False
def transpose(self, *_, **__) -> 'PRNGKeyArray': assert False
def flatten(self, *_, **__) -> 'PRNGKeyArray': assert False
lax_numpy._set_device_array_base_attributes(PRNGKeyArray, include=[
'__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape',
'transpose', 'flatten', 'T'])
lax_numpy._register_stackable(PRNGKeyArray)
# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
return random_seed(seed, impl=impl)
return PRNGKeyArray(impl, impl.seed(seed))
_register_stackable(PRNGKeyArray)
def keys_shaped_array(impl, shape):
return core.ShapedArray(shape, KeyTy(impl))
def keys_aval_to_base_arr_aval(keys_aval):
shape = (*keys_aval.shape, *keys_aval.dtype.impl.key_shape)
return core.ShapedArray(shape, np.dtype('uint32'))
def base_arr_shape_to_keys_shape(impl, base_arr_shape):
base_ndim = len(impl.key_shape)
return base_arr_shape[:-base_ndim]
class KeyTy:
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
def __init__(self, impl):
self.impl = impl
@property
def name(self) -> str:
return f'key<{self.impl.tag}>'
def __repr__(self) -> str:
return self.name
def __eq__(self, other):
return type(other) is KeyTy and self.impl is other.impl
def __hash__(self) -> int:
return hash((self.__class__, self.impl))
# handlers
@staticmethod
def physical_avals(aval):
return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),
jnp.dtype('uint32'))]
@staticmethod
def aval_to_ir_types(aval):
phys_aval, = KeyTy.physical_avals(aval)
return mlir.aval_to_ir_types(phys_aval)
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return PRNGKeyArray(aval.dtype.impl, buf)
return handler
@staticmethod
def sharded_result_handler(aval, sharding, indices):
phys_aval, = KeyTy.physical_avals(aval)
phys_handler_maker = pxla.local_result_handlers[
(core.ShapedArray, pxla.OutputType.ShardedDeviceArray)]
phys_handler = phys_handler_maker(phys_aval, sharding, indices)
def handler(bufs):
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
return handler
# eltype-polymorphic primitive lowering rules
@staticmethod
def empty_mlir(ctx):
aval_out, = ctx.avals_out
return mlir.ir_constants(np.empty(aval_out.dtype.impl.key_shape,
dtype=np.dtype('uint32')))
@staticmethod
def slice_mlir(ctx, x, start_indices, limit_indices, strides):
aval_out, = ctx.avals_out
key_shape = aval_out.dtype.impl.key_shape
trailing_zeros = [0] * len(key_shape)
trailing_ones = [1] * len(key_shape)
start_indices = (*start_indices, *trailing_zeros)
limit_indices = (*limit_indices, *key_shape)
strides = (*strides, *trailing_ones)
return mhlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
mlir.dense_int_elements(strides)).results
@staticmethod
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
aval_out, = ctx.avals_out
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
key_shape = aval_out.dtype.impl.key_shape
trailing_zeros = [mlir.ir_constant(np.array(0, dtype))] * len(key_shape)
start_indices = (*start_indices, *trailing_zeros)
slice_sizes_ = mlir.dense_int_elements((*slice_sizes, *key_shape))
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results
@staticmethod
def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
key_shape = aval_out.dtype.impl.key_shape
zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
start_indices = (*start_indices, *zeros)
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results
@staticmethod
def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
if dyn_shape: raise NotImplementedError
aval_out, = ctx.avals_out
key_shape = aval_out.dtype.impl.key_shape
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
return mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(broadcast_dimensions)).results
@staticmethod
def transpose_mlir(ctx, x, *, permutation):
aval_out, = ctx.avals_out
key_shape = aval_out.dtype.impl.key_shape
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
perm = [*permutation, *trailing_dims]
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results
@staticmethod
def gather_mlir(ctx, x, indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
aval_x, aval_indices = ctx.avals_in
aval_y, = ctx.avals_out
key_shape = aval_x.dtype.impl.key_shape
trailing_offset_dims = [aval_y.ndim + i for i in range(len(key_shape))]
dimension_numbers = dimension_numbers._replace(
offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims))
slice_sizes = (*slice_sizes, *key_shape)
gather_lower = partial(
lax_internal.slicing._gather_lower, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
return mlir.delegate_lowering(
ctx, gather_lower, x, indices,
avals_in=[keys_aval_to_base_arr_aval(aval_x), aval_indices],
avals_out=[keys_aval_to_base_arr_aval(aval_y)])
core.custom_eltypes.add(KeyTy)
core.pytype_aval_mappings[PRNGKeyArray] = (
lambda x: keys_shaped_array(x.impl, x.shape))
xla.pytype_aval_mappings[PRNGKeyArray] = (
lambda x: keys_shaped_array(x.impl, x.shape))
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
def device_put_key_array(x: PRNGKeyArray, device):
return dispatch.device_put(x.unsafe_raw_array(), device)
dispatch.device_put_handlers[PRNGKeyArray] = device_put_key_array
def key_array_shard_arg_handler(x: PRNGKeyArray, devices, indices, mode):
arr = x.unsafe_raw_array()
return pxla.shard_arg_handlers[type(arr)](arr, devices, indices, mode)
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler
def key_array_constant_handler(x, canonicalize_dtypes):
arr = x.unsafe_raw_array()
return mlir.get_constant_handler(type(arr))(arr, canonicalize_dtypes)
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
# -- primitives
def iterated_vmap_unary(n, f):
for _ in range(n):
f = jax.vmap(f)
return f
# TODO(frostig): Revise the following two functions? These basically
# undo the singleton dimensions added by `batching.defbroadcasting`.
# It works, but introduces some possibly-redundant squeezes. Can we
# borrow from other broadcasting primitives instead?
def squeeze_vmap(f, left):
def squeeze_vmap_f(x, y):
if left:
x = jnp.squeeze(x, axis=0)
axes = (None, 0)
else:
y = jnp.squeeze(y, axis=0)
axes = (0, None)
return jax.vmap(f, in_axes=axes, out_axes=0)(x, y)
return squeeze_vmap_f
def iterated_vmap_binary_bcast(shape1, shape2, f):
ndim1, ndim2 = len(shape1), len(shape2)
if ndim1 == ndim2 == 0:
return f
if 0 in [ndim1, ndim2]:
if ndim1 == 0:
return lambda x, y: iterated_vmap_unary(ndim2, lambda y: f(x, y))(y)
else:
return lambda x, y: iterated_vmap_unary(ndim1, lambda x: f(x, y))(x)
assert len(shape1) == len(shape2)
for sz1, sz2 in reversed(zip(shape1, shape2)):
if sz1 == sz2:
f = jax.vmap(f, out_axes=0)
else:
assert sz1 == 1 or sz2 == 1, (sz1, sz2)
f = squeeze_vmap(f, sz1 == 1)
return f
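A small sketch of what these helpers build, assuming the definitions above: iterated_vmap_unary(n, f) is n nested applications of jax.vmap:

import jax.numpy as jnp

f2 = iterated_vmap_unary(2, lambda x: x + 1)  # jax.vmap(jax.vmap(f))
print(f2(jnp.zeros((3, 4))).shape)  # (3, 4); f is applied per element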
def random_seed(seeds, impl):
# Avoid overflow error in X32 mode by first converting ints to int64.
# This breaks JIT invariance for large ints, but supports the common
# use-case of instantiating with Python hashes in X32 mode.
if isinstance(seeds, int):
seeds_arr = jnp.asarray(np.int64(seeds))
else:
seeds_arr = jnp.asarray(seeds)
return random_seed_p.bind(seeds_arr, impl=impl)
random_seed_p = core.Primitive('random_seed')
batching.defvectorized(random_seed_p)
@random_seed_p.def_abstract_eval
def random_seed_abstract_eval(seeds_aval, *, impl):
return keys_shaped_array(impl, seeds_aval.shape)
@random_seed_p.def_impl
def random_seed_impl(seeds, *, impl):
base_arr = random_seed_impl_base(seeds, impl=impl)
return PRNGKeyArray(impl, base_arr)
def random_seed_impl_base(seeds, *, impl):
seed = iterated_vmap_unary(seeds.ndim, impl.seed)
return seed(seeds)
def random_seed_lowering(ctx, seeds, *, impl):
aval, = ctx.avals_in
seed = iterated_vmap_unary(aval.ndim, impl.seed)
seed_lowering = mlir.lower_fun(seed, multiple_results=False)
return mlir.delegate_lowering(
ctx, seed_lowering, seeds,
avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
mlir.register_lowering(random_seed_p, random_seed_lowering)
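A hedged usage sketch of the primitive defined above, using the pre-rollback names from this module:

import jax.numpy as jnp

# seeds of any shape map to a key array of the same shape
keys = random_seed(jnp.arange(4), impl=threefry_prng_impl)  # keys.shape == (4,)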
def random_split(keys, count):
return random_split_p.bind(keys, count=count)
random_split_p = core.Primitive('random_split')
batching.defvectorized(random_split_p)
@random_split_p.def_abstract_eval
def random_split_abstract_eval(keys_aval, *, count):
return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, count))
@random_split_p.def_impl
def random_split_impl(keys, *, count):
base_arr = random_split_impl_base(
keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
return PRNGKeyArray(keys.impl, base_arr)
def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
return split(base_arr)
def random_split_lowering(ctx, keys, *, count):
aval, = ctx.avals_in
impl = aval.dtype.impl
split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, count))
split_lowering = mlir.lower_fun(split, multiple_results=False)
return mlir.delegate_lowering(
ctx, split_lowering, keys,
avals_in=[keys_aval_to_base_arr_aval(aval)],
avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
mlir.register_lowering(random_split_p, random_split_lowering)
def random_fold_in(keys, msgs):
return random_fold_in_p.bind(keys, jnp.asarray(msgs))
random_fold_in_p = core.Primitive('random_fold_in')
batching.defbroadcasting(random_fold_in_p)
@random_fold_in_p.def_abstract_eval
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
shape = lax_internal.broadcasting_shape_rule(
'random_fold_in', keys_aval, msgs_aval)
named_shape = lax_utils.standard_named_shape_rule(keys_aval, msgs_aval)
return core.ShapedArray(shape, keys_aval.dtype, named_shape=named_shape)
@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):
base_arr = random_fold_in_impl_base(
keys.impl, keys.unsafe_raw_array(), msgs, keys.shape)
return PRNGKeyArray(keys.impl, base_arr)
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
fold_in = iterated_vmap_binary_bcast(
keys_shape, np.shape(msgs), impl.fold_in)
return fold_in(base_arr, msgs)
def random_fold_in_lowering(ctx, keys, msgs):
keys_aval, msgs_aval = ctx.avals_in
impl = keys_aval.dtype.impl
fold_in = iterated_vmap_binary_bcast(
keys_aval.shape, msgs_aval.shape, impl.fold_in)
fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
return mlir.delegate_lowering(
ctx, fold_in_lowering, keys, msgs,
avals_in=[keys_aval_to_base_arr_aval(keys_aval), msgs_aval],
avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)
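For example (again with the pre-rollback names), deriving a distinct stream per step:

step_keys = random_fold_in(keys, 3)  # mixes the integer 3 into each key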
def random_bits(keys, bit_width, shape):
shape = core.as_named_shape(shape)
for name, size in shape.named_items:
# TODO(frostig,mattjj,apaszke): Is this real_size check necessary,
# and is it meant to raise a user-facing ValueError? Should it be
# an `assert` (or RuntimeError) instead? Why do we check it in
# calls to `random_bits` instead of a more common parallelism path?
real_size = lax.psum(1, name)
if real_size != size:
raise ValueError(f"The shape of axis {name} was specified as {size}, "
f"but it really is {real_size}")
axis_index = lax.axis_index(name)
keys = random_fold_in(keys, axis_index)
return random_bits_p.bind(keys, bit_width=bit_width, shape=shape.positional)
random_bits_p = core.Primitive('random_bits')
batching.defvectorized(random_bits_p)
@random_bits_p.def_abstract_eval
def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
out_shape = (*keys_aval.shape, *shape)
out_dtype = dtypes.dtype(f'uint{bit_width}')
return core.ShapedArray(out_shape, out_dtype)
@random_bits_p.def_impl
def random_bits_impl(keys, *, bit_width, shape):
return random_bits_impl_base(keys.impl, keys.unsafe_raw_array(), keys.ndim,
bit_width=bit_width, shape=shape)
def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):
bits = iterated_vmap_unary(
keys_ndim, lambda k: impl.random_bits(k, bit_width, shape))
return bits(base_arr)
def random_bits_lowering(ctx, keys, *, bit_width, shape):
aval, = ctx.avals_in
impl = aval.dtype.impl
bits = iterated_vmap_unary(
aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
bits_lowering = mlir.lower_fun(bits, multiple_results=False)
ctx_new = ctx.replace(avals_in=[keys_aval_to_base_arr_aval(aval)])
out = bits_lowering(ctx_new, keys)
ctx.set_tokens_out(ctx_new.tokens_out)
return out
mlir.register_lowering(random_bits_p, random_bits_lowering)
# The following wrap/unwrap primitives are at least a stopgap for
# backwards compatibility, namely when `config.jax_enable_custom_prng`
# is False. We need to convert key arrays to and from underlying
# uint32 base array, and we may need to do so under a jit. For
# example, we want to support:
#
# keys = jax.jit(random.split)(key)
#
# where `key` and `keys` are both acceptably old-style uint32 arrays
# so long as enable_custom_prng is False. The way we handle this is
# that `random.split` adapts the input/output by converting to/from
# key arrays across its call to `random_split`. So we rely on these
# wrap/unwrap casting primitives to allow that conversion under jit.
#
# We may want to keep both around for testing and debugging escape
# hatches. We can rename them `unsafe` for emphasis, and/or issue a
# warning on entry to the traceable.
#
# TODO(frostig): Consider removal once we always enable_custom_prng.
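A minimal sketch of the adapt pattern this comment describes, using the wrap/unwrap primitives defined below (illustrative only):

def split_compat(raw_key, num=2):
  # uint32 base array in, uint32 base array out, traceable under jit
  keys = random_wrap(raw_key, impl=threefry_prng_impl)
  return random_unwrap(random_split(keys, count=num))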
def random_wrap(base_arr, *, impl):
_check_prng_key_data(impl, base_arr)
return random_wrap_p.bind(base_arr, impl=impl)
random_wrap_p = core.Primitive('random_wrap')
@random_wrap_p.def_abstract_eval
def random_wrap_abstract_eval(base_arr_aval, *, impl):
shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape)
return keys_shaped_array(impl, shape)
@random_wrap_p.def_impl
def random_wrap_impl(base_arr, *, impl):
return PRNGKeyArray(impl, base_arr)
def random_wrap_lowering(ctx, base_arr, *, impl):
return [base_arr]
def random_wrap_batch_rule(batched_args, batch_dims, *, impl):
x, = batched_args
d, = batch_dims
x = batching.bdim_at_front(x, d, 1)
return random_wrap(x, impl=impl), 0
mlir.register_lowering(random_wrap_p, random_wrap_lowering)
batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule
def random_unwrap(keys):
assert isinstance(keys, PRNGKeyArray)
return random_unwrap_p.bind(keys)
random_unwrap_p = core.Primitive('random_unwrap')
batching.defvectorized(random_unwrap_p)
@random_unwrap_p.def_abstract_eval
def random_unwrap_abstract_eval(keys_aval):
return keys_aval_to_base_arr_aval(keys_aval)
@random_unwrap_p.def_impl
def random_unwrap_impl(keys):
return keys.unsafe_raw_array()
def random_unwrap_lowering(ctx, keys):
return [keys]
mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)
# -- threefry2x32 PRNG implementation
# -- threefry2x32 PRNG implementation --
def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
@ -696,8 +246,8 @@ def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
return False
def threefry_seed(seed: jnp.ndarray) -> jnp.ndarray:
"""Create a single raw threefry PRNG key from an integer seed.
def threefry_seed(seed: int) -> jnp.ndarray:
"""Create a single raw threefry PRNG key given an integer seed.
Args:
seed: a 64- or 32-bit integer used as the value of the key.
@ -708,17 +258,24 @@ def threefry_seed(seed: jnp.ndarray) -> jnp.ndarray:
bit-casting to a pair of uint32 values (or from a 32-bit seed by
first padding out with zeros).
"""
if seed.shape:
# Avoid OverflowError in X32 mode by first converting ints to int64.
# This breaks JIT invariance for large ints, but supports the common
# use-case of instantiating with Python hashes in X32 mode.
if isinstance(seed, int):
seed_arr = jnp.asarray(np.int64(seed))
else:
seed_arr = jnp.asarray(seed)
if seed_arr.shape:
raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
if not np.issubdtype(seed.dtype, np.integer):
if not np.issubdtype(seed_arr.dtype, np.integer):
raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
k1 = convert(
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
with jax.numpy_dtype_promotion('standard'):
# TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
# inputs. We should avoid this.
k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
return lax.concatenate([k1, k2], 0)
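A worked example of the bit-split described in the docstring (assuming X64 mode, so the seed is a 64-bit integer):

seed = (5 << 32) | 7
# k1 = uint32(seed >> 32) == 5 and k2 = uint32(seed & 0xFFFFFFFF) == 7,
# so threefry_seed(seed) yields the uint32 pair [5, 7].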
@ -748,7 +305,7 @@ def _threefry2x32_abstract_eval(*args):
raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
.format(args))
if all(isinstance(arg, core.ShapedArray) for arg in args):
shape = lax_internal.broadcasting_shape_rule(*args)
shape = lax_internal._broadcasting_shape_rule(*args)
named_shape = core.join_named_shapes(*(a.named_shape for a in args))
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
else:
@ -910,8 +467,7 @@ def _threefry_split(key, num) -> jnp.ndarray:
return lax.reshape(threefry_2x32(key, counts), (num, 2))
def threefry_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
assert not data.shape
def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
return _threefry_fold_in(key, jnp.uint32(data))
@partial(jit, inline=True)
@ -926,7 +482,15 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
raise TypeError("threefry_random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
size = prod(shape)
shape = core.as_named_shape(shape)
for name, size in shape.named_items:
real_size = lax.psum(1, name)
if real_size != size:
raise ValueError(f"The shape of axis {name} was specified as {size}, "
f"but it really is {real_size}")
axis_index = lax.axis_index(name)
key = threefry_fold_in(key, axis_index)
size = prod(shape.positional)
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
# polymorphism
max_count, r = divmod(bit_width * size, 32)
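A quick arithmetic check of the ceiling computation above: with bit_width=8 and size=5,

max_count, r = divmod(8 * 5, 32)  # == (1, 8)
# r != 0, so presumably one extra uint32 word is generated, for
# ceil(40 / 32) == 2 words in total.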
@ -973,11 +537,10 @@ threefry_prng_impl = PRNGImpl(
seed=threefry_seed,
split=threefry_split,
random_bits=threefry_random_bits,
fold_in=threefry_fold_in,
tag='fry')
fold_in=threefry_fold_in)
# -- RngBitGenerator PRNG implementation
# -- RngBitGenerator PRNG implementation --
# This code is experimental!
# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
@ -985,16 +548,14 @@ threefry_prng_impl = PRNGImpl(
# stable/deterministic across backends or compiler versions. Correspondingly, we
# reserve the right to change any of these implementations at any time!
def _rbg_seed(seed: jnp.ndarray) -> jnp.ndarray:
assert not seed.shape
def _rbg_seed(seed: int) -> jnp.ndarray:
halfkey = threefry_seed(seed)
return jnp.concatenate([halfkey, halfkey])
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
return vmap(_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
def _rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
assert not data.shape
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)
def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
@ -1011,16 +572,14 @@ rbg_prng_impl = PRNGImpl(
seed=_rbg_seed,
split=_rbg_split,
random_bits=_rbg_random_bits,
fold_in=_rbg_fold_in,
tag='rbg')
fold_in=_rbg_fold_in)
def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
# treat 10 iterations of random bits as a 'hash function'
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
return keys[::10]
def _unsafe_rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
assert not data.shape
def _unsafe_rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
_, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
return key ^ random_bits[-1]
@ -1029,5 +588,4 @@ unsafe_rbg_prng_impl = PRNGImpl(
seed=_rbg_seed,
split=_unsafe_rbg_split,
random_bits=_rbg_random_bits,
fold_in=_unsafe_rbg_fold_in,
tag='urbg')
fold_in=_unsafe_rbg_fold_in)

View File

@ -59,10 +59,9 @@ _lax_const = lax_internal._const
def _isnan(x):
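# NaN is the only IEEE-754 value that compares unequal to itself, so x != x tests for NaN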
return lax.ne(x, x)
def _check_prng_key(key):
# TODO(frostig): remove once we always enable_custom_prng
if isinstance(key, prng.PRNGKeyArray):
if type(key) is prng.PRNGKeyArray:
return key, False
elif _arraylike(key):
if config.jax_enable_custom_prng:
@ -70,23 +69,21 @@ def _check_prng_key(key):
'Raw arrays as random keys to jax.random functions are deprecated. '
'Assuming valid threefry2x32 key for now.',
FutureWarning)
return prng.random_wrap(key, impl=default_prng_impl()), True
return prng.PRNGKeyArray(default_prng_impl(), key), True
else:
raise TypeError(f'unexpected PRNG key type {type(key)}')
def _return_prng_keys(was_wrapped, key):
# TODO(frostig): remove once we always enable_custom_prng
assert isinstance(key, prng.PRNGKeyArray)
assert type(key) is prng.PRNGKeyArray, type(key)
if config.jax_enable_custom_prng:
return key
else:
return prng.random_unwrap(key) if was_wrapped else key
return key.unsafe_raw_array() if was_wrapped else key
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
assert isinstance(key, prng.PRNGKeyArray)
return prng.random_bits(key, bit_width=bit_width, shape=shape)
key, _ = _check_prng_key(key)
return key._random_bits(bit_width, shape)
PRNG_IMPLS = {
@ -161,8 +158,7 @@ def unsafe_rbg_key(seed: int) -> KeyArray:
def _fold_in(key: KeyArray, data: int) -> KeyArray:
# Alternative to fold_in() to use within random samplers.
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
assert isinstance(key, prng.PRNGKeyArray)
return prng.random_fold_in(key, jnp.uint32(data))
return key._fold_in(jnp.uint32(data))
def fold_in(key: KeyArray, data: int) -> KeyArray:
"""Folds in data to a PRNG key to form a new PRNG key.
@ -180,10 +176,8 @@ def fold_in(key: KeyArray, data: int) -> KeyArray:
def _split(key: KeyArray, num: int = 2) -> KeyArray:
# Alternative to split() to use within random samplers.
# TODO(frostig): remove and use split(); we no longer need to wait
# to always enable_custom_prng
assert isinstance(key, prng.PRNGKeyArray)
return prng.random_split(key, count=num)
# TODO(frostig): remove and use split() once we always enable_custom_prng
return key._split(num)
def split(key: KeyArray, num: int = 2) -> KeyArray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
@ -972,7 +966,8 @@ def _gamma_one(key: KeyArray, alpha, log_space):
return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)
def _gamma_grad(sample, a, *, log_space):
def _gamma_grad(sample, a, *, prng_impl, log_space):
del prng_impl # unused
samples = jnp.reshape(sample, -1)
alphas = jnp.reshape(a, -1)
if log_space:
@ -992,38 +987,34 @@ def _gamma_grad(sample, a, *, log_space):
grads = vmap(gamma_grad)(alphas, samples)
return grads.reshape(np.shape(a))
def _gamma_impl(key, a, *, log_space, use_vmap=False):
# split key to match the shape of a
def _gamma_impl(raw_key, a, *, prng_impl, log_space, use_vmap=False):
a_shape = jnp.shape(a)
split_count = prod(a_shape[key.ndim:])
keys = key.flatten()
keys = vmap(_split, in_axes=(0, None))(keys, split_count)
keys = keys.flatten()
alphas = a.flatten()
# split key to match the shape of a
key_ndim = len(raw_key.shape) - len(prng_impl.key_shape)
key = raw_key.reshape((-1,) + prng_impl.key_shape)
key = vmap(prng_impl.split, in_axes=(0, None))(key, prod(a_shape[key_ndim:]))
keys = key.reshape((-1,) + prng_impl.key_shape)
keys = prng.PRNGKeyArray(prng_impl, keys)
alphas = jnp.reshape(a, -1)
if use_vmap:
samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
else:
samples = lax.map(
lambda args: _gamma_one(*args, log_space=log_space), (keys, alphas))
samples = lax.map(lambda args: _gamma_one(*args, log_space=log_space), (keys, alphas))
return jnp.reshape(samples, a_shape)
def _gamma_batching_rule(batched_args, batch_dims, *, log_space):
k, a = batched_args
bk, ba = batch_dims
size = next(
t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
k = batching.bdim_at_front(k, bk, size)
a = batching.bdim_at_front(a, ba, size)
return random_gamma_p.bind(k, a, log_space=log_space), 0
def _gamma_batching_rule(batched_args, batch_dims, *, prng_impl, log_space):
k, a = batched_args
bk, ba = batch_dims
size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
k = batching.bdim_at_front(k, bk, size)
a = batching.bdim_at_front(a, ba, size)
return random_gamma_p.bind(k, a, prng_impl=prng_impl, log_space=log_space), 0
random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
ad.defjvp2(
random_gamma_p, None,
lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
mlir.register_lowering(random_gamma_p, mlir.lower_fun(
partial(_gamma_impl, use_vmap=True),
multiple_results=False))
@ -1117,7 +1108,7 @@ def _gamma(key, a, shape, dtype, log_space=False):
a = lax.convert_element_type(a, dtype)
if np.shape(a) != shape:
a = jnp.broadcast_to(a, shape)
return random_gamma_p.bind(key, a, log_space=log_space)
return random_gamma_p.bind(key.unsafe_raw_array(), a, prng_impl=key.impl, log_space=log_space)
@partial(jit, static_argnums=(2, 3, 4), inline=True)
@ -1226,13 +1217,10 @@ def poisson(key: KeyArray,
``shape`` is not None, or else by ``lam.shape``.
"""
key, _ = _check_prng_key(key)
# TODO(frostig): generalize underlying poisson implementation and
# remove this check (and use of core.get_aval)
key_impl = core.get_aval(key).dtype.impl
if key_impl is not prng.threefry_prng_impl:
if key.impl is not prng.threefry_prng_impl:
raise NotImplementedError(
'`poisson` is only implemented for the threefry2x32 RNG, '
f'not {key_impl}')
f'not {key.impl}')
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.canonicalize_shape(shape)
@ -1631,7 +1619,6 @@ def orthogonal(
Returns:
A random array of shape `(*shape, n, n)` and specified dtype.
"""
key, _ = _check_prng_key(key)
_check_shape("orthogonal", shape)
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
z = normal(key, (*shape, n, n), dtype)
@ -1657,7 +1644,6 @@ def generalized_normal(
Returns:
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key(key)
_check_shape("generalized_normal", shape)
keys = split(key)
g = gamma(keys[0], 1/p, shape, dtype)
@ -1686,10 +1672,9 @@ def ball(
Returns:
A random array of shape `(*shape, d)` and specified dtype.
"""
key, _ = _check_prng_key(key)
_check_shape("ball", shape)
d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
k1, k2 = split(key)
g = generalized_normal(k1, p, (*shape, d), dtype)
e = exponential(k2, shape, dtype)
keys = split(key)
g = generalized_normal(keys[0], p, (*shape, d), dtype)
e = exponential(keys[1], shape, dtype)
return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None]
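A usage sketch, assuming the signature implied by the surrounding docstring:

import jax

key = jax.random.PRNGKey(0)
pts = jax.random.ball(key, 2, shape=(100,))  # 100 points in the 2-D unit ball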

View File

@ -160,15 +160,7 @@ def count_device_put():
def device_put_and_count(*args, **kwargs):
count[0] += 1
# device_put handlers might call `dispatch.device_put` (e.g. on an
# underlying payload or several). We only want to count these
# recursive puts once, so we skip counting more than the outermost
# one in such a call stack.
dispatch.device_put = device_put
try:
return device_put(*args, **kwargs)
finally:
dispatch.device_put = device_put_and_count
return device_put(*args, **kwargs)
dispatch.device_put = device_put_and_count
try:

View File

@ -1190,25 +1190,8 @@ def concrete_or_error(force: Any, val: Any, context=""):
# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
custom_eltypes: Set[Any] = set()
# TODO(frostig): update inliners of the four functions below to call them
def has_custom_eltype(x: Any):
return aval_has_custom_eltype(get_aval(x))
def eltype(x: Any):
return aval_eltype(get_aval(x))
def aval_has_custom_eltype(aval: UnshapedArray):
return is_custom_eltype(aval.dtype)
def aval_eltype(aval: UnshapedArray):
return aval.dtype
def is_custom_eltype(eltype):
return type(eltype) in custom_eltypes
def _short_dtype_name(dtype) -> str:
if type(dtype) in custom_eltypes:
return str(dtype)
@ -1421,14 +1404,8 @@ class ConcreteArray(ShapedArray):
_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)
# TODO(frostig,mattjj): rename to primal_eltype_to_tangent_eltype
def primal_dtype_to_tangent_dtype(primal_dtype):
# TODO(frostig,mattjj): determines that all custom eltypes have
# float0 tangent type, which works fine for all our current custom
# eltype applications. We may some day want to delegate this
# decision to the eltype.
if (type(primal_dtype) in custom_eltypes or
not dtypes.issubdtype(primal_dtype, np.inexact)):
if not dtypes.issubdtype(primal_dtype, np.inexact):
return dtypes.float0
else:
return primal_dtype

View File

@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf
*Last generated on (YYYY-MM-DD): 2022-08-17*
*Last generated on (YYYY-MM-DD): 2022-08-10*
This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
@ -136,7 +136,6 @@ with jax2tf. The following table lists the cases when this does not quite hold:
| max | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
| min | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
| pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
| random_split | Returns JAX key arrays, so compare underlying base array | all | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_add | Numeric comparison disabled: Large deviations on TPU for enable_xla=False | float16, float32 | tpu | compiled, eager, graph |
| sort | Numeric comparison disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph |
| svd | custom numeric comparison when compute_uv on CPU/GPU | all | cpu, gpu | compiled, eager, graph |

View File

@ -698,7 +698,6 @@ def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
def _convert_jax_impl(jax_impl: Callable, *,
multiple_results=True,
with_physical_avals=False,
extra_name_stack: Optional[str] = None) -> Callable:
"""Convert the JAX implementation of a primitive.
@ -719,10 +718,6 @@ def _convert_jax_impl(jax_impl: Callable, *,
_out_aval: core.ShapedArray,
**kwargs) -> Sequence[TfVal]:
if with_physical_avals:
_in_avals = map(_jax_physical_aval, _in_avals)
_out_aval = _jax_physical_aval(_out_aval)
# We wrap the jax_impl under _interpret_fun to abstract the TF values
# from jax_impl and turn them into JAX abstract values.
def jax_impl_jax_args(*jax_args):
@ -765,31 +760,8 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal,
return tuple(v for v, _ in out_with_avals)
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
"""Converts JAX avals from logical to physical, if relevant.
A JAX aval's logical shape/dtype may differ from its physical one, and
only the physical view is expected to relate to TF. TF impl rules
should operate on the physical form.
A JAX logical aval might even correspond, in principle, to several
physical avals, but we don't support those here. Instead we assert
there is only one and return it.
"""
if type(aval.dtype) in core.custom_eltypes:
aval, = aval.dtype.physical_avals(aval)
return aval
return aval
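For example, under the custom-eltype scheme removed here, a logical aval ShapedArray((4,), key<fry>) corresponds to the physical aval ShapedArray((4, 2), uint32), since threefry keys carry a trailing key_shape of (2,).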
def _jax_physical_dtype(dtype):
# assuming () is a fine stand-in shape
return _jax_physical_aval(core.ShapedArray((), dtype)).dtype
def _aval_to_tf_shape(aval: core.ShapedArray) -> Tuple[Optional[int], ...]:
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
aval = _jax_physical_aval(aval)
return tuple(map(lambda d: None if shape_poly.is_poly_dim(d) else d,
aval.shape)) # type: ignore[attr-defined]
@ -799,12 +771,6 @@ _tf_np_dtype_for_float0 = np.int32
def _to_tf_dtype(jax_dtype):
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
# due to float0 and 64-bit behavior.
try:
jax_dtype = _jax_physical_dtype(jax_dtype)
except TypeError:
# `jax_dtype` isn't actually a valid jax dtype (e.g. it is
# tf.float32), so there is no physical dtype anyway
pass
if jax_dtype == dtypes.float0:
jax_dtype = _tf_np_dtype_for_float0
return tf.dtypes.as_dtype(jax_dtype)
@ -868,13 +834,9 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
return tf_val, jax_dtype
# TODO(frostig,mattjj): rename dtype argument to eltype, for now just
# being consistent.
def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
assert all(map(lambda x: x is not None, shape)), (
f"Argument shape should be a valid JAX shape but got {shape}")
if dtype is not None:
shape = _jax_physical_aval(core.ShapedArray(shape, dtype)).shape
dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
eval_shape, dim_avals = shape_poly.get_shape_evaluator(dim_vars, shape)
shape_values, _ = util.unzip2(_interpret_fun(lu.wrap_init(eval_shape),
@ -898,13 +860,11 @@ def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize
class TensorFlowTracer(core.Tracer):
"""Tracer class that boxes a TF value and a JAX abstract value.
In addition to the TF value we carry the JAX abstract value because
there are some cases when it cannot be recovered from the value:
when we are converting with polymorphic shapes or when the JAX aval
has a custom element type. In these cases the shape of the value may
have dimensions set to `None`, or it may only correspond to the JAX
"physical" (TF/lowering-compatible) shape, so the JAX abstract value
may contain more precise information.
In addition to the TF value we carry the JAX abstract value because there is
one case when it cannot be recovered from the value: when we are converting
with polymorphic shapes, in which case the shape of the value may have
dimensions set to `None`, whereas the JAX abstract value may contain more
precise information.
When the value has a partially-known shape, the dimensions marked as `None`
must correspond to non-constant dimensions in the abstract value.
@ -919,34 +879,32 @@ class TensorFlowTracer(core.Tracer):
aval: core.AbstractValue):
self._trace = trace
self._aval = aval
phys_aval = _jax_physical_aval(self._aval) # type: ignore[arg-type]
if isinstance(val, (tf.Tensor, tf.Variable)):
val_shape = val.shape
if config.jax_enable_checks:
assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than val_shape={val_shape}"
assert len(self._aval.shape) == len(val_shape), f"_aval.shape={self._aval.shape} different rank than val_shape={val_shape}"
# To compare types, we must handle float0 in JAX and x64 in TF
if phys_aval.dtype == dtypes.float0:
assert _to_tf_dtype(phys_aval.dtype) == val.dtype, f"expected {phys_aval.dtype} == {val.dtype}"
if self._aval.dtype == dtypes.float0:
assert _to_tf_dtype(self._aval.dtype) == val.dtype, f"expected {self._aval.dtype} == {val.dtype}"
else:
assert phys_aval.dtype == _to_jax_dtype(val.dtype), f"expected {phys_aval.dtype} == {val.dtype}"
assert self._aval.dtype == _to_jax_dtype(val.dtype), f"expected {self._aval.dtype} == {val.dtype}"
for aval_dim, val_dim in zip(phys_aval.shape, val_shape): # type: ignore[attr-defined]
for aval_dim, val_dim in zip(self._aval.shape, val_shape): # type: ignore[attr-defined]
if val_dim is None:
assert shape_poly.is_poly_dim(aval_dim), f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined]
assert shape_poly.is_poly_dim(aval_dim), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
elif not shape_poly.is_poly_dim(aval_dim):
assert aval_dim == val_dim, f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined]
assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
else:
# We have a TF value with known shape, and the abstract shape is a shape variable.
try:
aval_int = int(_eval_shape([aval_dim])) # type: ignore
except (TypeError, KeyError):
continue
assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore
assert aval_int == val_dim, f"expected {self._aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore
self.val = _tfval_to_tensor_jax_dtype(val,
phys_aval.dtype,
self._aval.dtype,
memoize_constants=True)[0] # type: ignore[attr-defined]
@property
@ -1666,7 +1624,7 @@ tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type
def _clamp(minval, operand, maxval, *, _in_avals, _out_aval):
# The below permits mirroring the behavior of JAX when maxval < minval
op_shape_tf_val = _eval_shape(_in_avals[1].shape, _in_avals[1].dtype)
op_shape_tf_val = _eval_shape(_in_avals[1].shape)
maxval = tf.broadcast_to(maxval, op_shape_tf_val)
minval = tf.math.minimum(tf.broadcast_to(minval, op_shape_tf_val), maxval)
return tf.clip_by_value(operand, minval, maxval)
@ -1821,12 +1779,11 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
# bcast_dims must be strictly increasing.
# len(bcast_dims) == len(operand.shape)
op_shape = _in_avals[0].shape
dtype = _in_avals[0].dtype
add_1s_shape = [1] * len(shape)
for i, broadcast_dim_i in enumerate(broadcast_dimensions):
add_1s_shape[broadcast_dim_i] = op_shape[i]
with_1s = tf.reshape(operand, _eval_shape(add_1s_shape, dtype=dtype))
return tf.broadcast_to(with_1s, _eval_shape(shape, dtype=dtype))
with_1s = tf.reshape(operand, _eval_shape(add_1s_shape))
return tf.broadcast_to(with_1s, _eval_shape(shape))
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
@ -1841,21 +1798,20 @@ def _empty(*, eltype):
tf_impl[lax_internal.empty_p] = _empty
def _reshape(operand, *, new_sizes, dimensions, _in_avals, _out_aval):
def _reshape(operand, *, new_sizes, dimensions):
if dimensions is None:
dimensions = tf.range(tf.rank(operand))
new_sizes_tf = _eval_shape(new_sizes, _in_avals[0].dtype)
new_sizes_tf = _eval_shape(new_sizes)
return tf.reshape(tf.transpose(operand, dimensions), new_sizes_tf)
tf_impl_with_avals[lax.reshape_p] = _reshape
tf_impl[lax.reshape_p] = _reshape
def _squeeze(operand, *, dimensions, _in_avals, _out_aval):
op_aval = _jax_physical_aval(_in_avals[0])
op_shape = op_aval.shape
op_shape = _in_avals[0].shape
new_shape = tuple(d for i, d in enumerate(op_shape) if i not in dimensions)
new_shape_tf = _eval_shape(new_shape, op_aval.dtype)
new_shape_tf = _eval_shape(new_shape)
return tf.reshape(operand, new_shape_tf)
@ -2286,82 +2242,6 @@ def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
def _random_seed_impl(seeds: TfVal, *, impl, _in_avals, _out_aval):
def impl_wrapper(seeds: TfVal, *, impl):
return jax._src.prng.random_seed_impl_base(seeds, impl=impl)
converted_impl = _convert_jax_impl(
impl_wrapper, multiple_results=False, with_physical_avals=True,
extra_name_stack="random_seed")
return converted_impl(
seeds, impl=impl, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[jax._src.prng.random_seed_p] = _random_seed_impl
def _random_split_impl(keys: TfVal, *, count, _in_avals, _out_aval):
keys_aval, = _in_avals
def impl_wrapper(keys: TfVal, *, count):
return jax._src.prng.random_split_impl_base(
keys_aval.dtype.impl, keys, keys_aval.ndim, count=count)
converted_impl = _convert_jax_impl(
impl_wrapper, multiple_results=False, with_physical_avals=True,
extra_name_stack="random_split")
return converted_impl(
keys, count=count, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[jax._src.prng.random_split_p] = _random_split_impl
def _random_fold_in_impl(keys: TfVal, msgs: TfVal, *, _in_avals, _out_aval):
keys_aval, _ = _in_avals
def impl_wrapper(keys: TfVal, msgs: TfVal):
return jax._src.prng.random_fold_in_impl_base(
keys_aval.dtype.impl, keys, msgs, keys_aval.shape)
converted_impl = _convert_jax_impl(
impl_wrapper, multiple_results=False, with_physical_avals=True,
extra_name_stack="random_fold_in")
return converted_impl(
keys, msgs, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[jax._src.prng.random_fold_in_p] = _random_fold_in_impl
def _random_bits_impl(keys: TfVal, *, bit_width, shape, _in_avals, _out_aval):
keys_aval, = _in_avals
def impl_wrapper(keys: TfVal, **kwargs):
return jax._src.prng.random_bits_impl_base(
keys_aval.dtype.impl, keys, keys_aval.ndim,
bit_width=bit_width, shape=shape)
converted_impl = _convert_jax_impl(
impl_wrapper, multiple_results=False, with_physical_avals=True,
extra_name_stack="random_bits")
return converted_impl(keys, bit_width=bit_width, shape=shape,
_in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[jax._src.prng.random_bits_p] = _random_bits_impl
def _random_wrap_impl(base_arr: TfVal, *, impl, _in_avals, _out_aval):
return base_arr
tf_impl_with_avals[jax._src.prng.random_wrap_p] = _random_wrap_impl
def _random_unwrap_impl(keys: TfVal, *, _in_avals, _out_aval):
return keys
tf_impl_with_avals[jax._src.prng.random_unwrap_p] = _random_unwrap_impl
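These handlers all follow one pattern: wrap the pure-JAX implementation and pass it through _convert_jax_impl, so the TF lowering reuses JAX's own logic rather than reimplementing it. A hedged sketch of that shape, with hypothetical names (my_prim_p, my_jax_impl) and relying on the jax2tf-internal helpers already in scope in this module:
def _my_prim_impl(x: TfVal, *, _in_avals, _out_aval):
  def impl_wrapper(x: TfVal):
    return my_jax_impl(x)  # the existing pure-JAX implementation
  converted_impl = _convert_jax_impl(
      impl_wrapper, multiple_results=False, extra_name_stack="my_prim")
  return converted_impl(x, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[my_prim_p] = _my_prim_impl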
def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
res = _convert_jax_impl(
partial(jax._src.prng._threefry2x32_lowering, use_rolled_loops=False),
@ -2449,7 +2329,7 @@ def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shap
start_indices = _maybe_cast_to_int64(start_indices)
proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
slice_sizes_tf = _eval_shape(slice_sizes, _in_avals[0].dtype)
slice_sizes_tf = _eval_shape(slice_sizes)
out = tfxla.gather(operand, start_indices, proto, slice_sizes_tf,
indices_are_sorted)
out.set_shape(_aval_to_tf_shape(_out_aval))
@ -2482,7 +2362,7 @@ def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
start_indices = _maybe_cast_to_int64(tf.stack(start_indices))
slice_sizes_tf = _eval_shape(slice_sizes, dtype=_in_avals[0].dtype)
slice_sizes_tf = _eval_shape(slice_sizes)
res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes_tf)
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
@ -2643,7 +2523,7 @@ def _batched_cond_while(*args: TfVal, cond_nconsts: int,
def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal:
pred_b_bcast = _broadcast_in_dim(
pred_b,
shape=_jax_physical_aval(c_aval).shape, # a JAX shape
shape=c_aval.shape, # a JAX shape
broadcast_dimensions=list(range(len(pred_b.shape))),
_in_avals=cond_jaxpr.out_avals,
_out_aval=core.ShapedArray(c_aval.shape, np.bool_))
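A rough jax.numpy sketch, with hypothetical shapes, of the per-carry select above: a per-batch predicate of shape (batch,) is broadcast to each carry's shape, so batch members whose predicate is False keep their previous carry.
import jax.numpy as jnp
pred_b = jnp.array([True, False, True])   # which batch members still iterate
new_c = jnp.ones((3, 4))                  # candidate new carry
c = jnp.zeros((3, 4))                     # previous carry
pred_bcast = jnp.broadcast_to(pred_b[:, None], c.shape)
print(jnp.where(pred_bcast, new_c, c))    # rows 0 and 2 take the new carry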

View File

@ -16,7 +16,6 @@
import itertools
from typing import Any, Callable, Optional, Sequence, Union
import jax
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
@ -131,7 +130,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
"lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
"random_categorical", "random_uniform", "random_randint",
"random_categorical", "random_split", "random_uniform", "random_randint",
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
"reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
@ -154,18 +153,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
dtypes=[np.complex64, np.complex128],
custom_assert=custom_assert)
@classmethod
def random_seed(cls, harness: primitive_harness.Harness):
return [custom_random_keys_output()]
@classmethod
def random_split(cls, harness: primitive_harness.Harness):
return [custom_random_keys_output()]
@classmethod
def random_fold_in(cls, harness: primitive_harness.Harness):
return [custom_random_keys_output()]
@classmethod
def acos(cls, harness: primitive_harness.Harness):
return [
@ -1283,24 +1270,6 @@ def custom_numeric(
enabled=enabled,
tol=tol)
def custom_random_keys_output():
def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
# TODO(frostig): this conditional is unnecessary once we always
# enable_custom_prng; we could then assert the isinstance instead.
def unwrap_keys(keys):
if isinstance(keys, jax.random.KeyArray):
return jax._src.prng.random_unwrap(keys)
else:
return keys
tst.assertAllClose(unwrap_keys(result_jax), result_tf,
atol=tol, rtol=tol, err_msg=err_msg)
return custom_numeric(
description="Returns JAX key arrays, so compare underlying base array",
modes=("eager", "graph", "compiled"),
custom_assert=custom_assert)
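The unwrap-then-compare step above, as a minimal standalone sketch (assuming jax_enable_custom_prng; the stand-in result_tf is hypothetical):
import numpy as np
import jax
result_jax = jax.random.split(jax.random.PRNGKey(0))  # a key array
raw = (result_jax.unsafe_raw_array()
       if isinstance(result_jax, jax.random.KeyArray) else result_jax)
result_tf = np.asarray(raw)               # stand-in for the TF-side uint32s
np.testing.assert_allclose(raw, result_tf)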
def missing_tf_kernel(*,
description="op not defined for dtype",

View File

@ -195,9 +195,6 @@ _constant_handlers : Dict[type, ConstantHandler] = {}
def register_constant_handler(type_: type, handler_fun: ConstantHandler):
_constant_handlers[type_] = handler_fun
def get_constant_handler(type_: type) -> ConstantHandler:
return _constant_handlers[type_]
def ir_constants(val: Any,
canonicalize_types: bool = True) -> Sequence[ir.Value]:
"""Translate a Python `val` to an IR constant, canonicalizing its dtype.
@ -1099,7 +1096,6 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
else:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
out, tokens = jaxpr_subcomp(
ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts),

View File

@ -570,11 +570,8 @@ local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRes
def sda_array_result_handler(aval: ShapedArray, sharding, indices):
sharding_spec = _get_sharding_specs([sharding], [aval])[0]
if type(aval.dtype) in core.custom_eltypes:
return aval.dtype.sharded_result_handler(aval, sharding, indices)
else:
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)
local_result_handlers[(ShapedArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler

View File

@ -119,23 +119,13 @@ Here is a short summary:
NOTE: RNGs are currently identical across shardings because the full random
value is first materialized, replicated, on every device, and each device then
slices out the part it needs.
"""
# TODO(frostig): replace with KeyArray from jax._src.random once we
# always enable_custom_prng
from jax._src.prng import PRNGKeyArray
# TODO(frostig): remove this typechecking workaround. Our move away
# from PRNGKeyArray as a pytree led to Python typechecker breakages in
# several downstream annotations (e.g. annotations in jax-dependent
# libraries that are violated by their callers). It may be that the
# pytree registration decorator invalidated the checks. This will be
# easier to handle after we always enable_custom_prng.
import typing
if typing.TYPE_CHECKING:
KeyArray = typing.Any
else:
# TODO(frostig): replace with KeyArray from jax._src.random once we
# always enable_custom_prng
KeyArray = PRNGKeyArray
KeyArray = PRNGKeyArray
from jax._src.random import (
PRNGKey as PRNGKey,

View File

@ -764,11 +764,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
def test_omnistaging(self):
# See https://github.com/google/jax/issues/5206
# TODO(frostig): remove `wrap` once we always enable_custom_prng
def wrap(arr):
# TODO(frostig): remove once we always enable_custom_prng
def _prng_key_as_array(key):
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
# TODO(frostig): remove once we always enable_custom_prng
def _array_as_prng_key(arr):
arr = np.array(arr, dtype=np.uint32)
if config.jax_enable_custom_prng:
return jax._src.prng.random_wrap(arr, impl=jax.random.default_prng_impl())
return jax._src.prng.PRNGKeyArray(
jax._src.prng.threefry_prng_impl, arr)
else:
return arr
@ -779,11 +784,10 @@ class CPPJitTest(jtu.BufferDonationTestCase):
key_list[0] = key
return jax.random.normal(subkey, ())
key_list[0] = wrap([2384771982, 3928867769])
key_list[0] = _array_as_prng_key([2384771982, 3928867769])
init()
self.jit(init)()
self.assertIsInstance(key_list[0], core.Tracer)
del key_list[0]
self.assertIsInstance(_prng_key_as_array(key_list[0]), core.Tracer)
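A short sketch of the round-trip these helpers rely on (assuming jax_enable_custom_prng), using the same constructor and accessor as the test above:
import numpy as np
import jax
arr = np.array([2384771982, 3928867769], dtype=np.uint32)
key = jax._src.prng.PRNGKeyArray(jax._src.prng.threefry_prng_impl, arr)
assert np.array_equal(key.unsafe_raw_array(), arr)  # wrap/unwrap round-trips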
def test_jit_wrapped_attributes(self):
def f(x: int) -> int:

View File

@ -919,7 +919,6 @@ class BatchingTest(jtu.JaxTestCase):
_ = hessian(f)(R) # don't crash on UnshapedArray
def testIssue489(self):
# https://github.com/google/jax/issues/489
def f(key):
def body_fn(uk):
key = uk[1]

View File

@ -37,9 +37,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu
from jax import vmap
from jax.interpreters import xla
import jax._src.random
from jax._src import prng as prng_internal
from jax.config import config
config.parse_flags_with_absl()
@ -54,11 +52,6 @@ def _prng_key_as_array(key):
# TODO(frostig): remove once we upgrade to always enable_custom_prng
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
def _maybe_unwrap(key):
# TODO(frostig): remove once we upgrade to always enable_custom_prng
unwrap = jax._src.prng.random_unwrap
return unwrap(key) if config.jax_enable_custom_prng else key
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
('rbg', prng.rbg_prng_impl),
@ -233,28 +226,22 @@ class PrngTest(jtu.JaxTestCase):
def testRngRandomBits(self):
# Test specific outputs to ensure consistent random values between JAX versions.
# TODO(frostig): remove once we always enable_custom_prng
def random_bits(key, *args):
key, _ = jax._src.random._check_prng_key(key)
return jax._src.random._random_bits(key, *args)
key = random.PRNGKey(1701)
bits8 = random_bits(key, 8, (3,))
bits8 = jax._src.random._random_bits(key, 8, (3,))
expected8 = np.array([216, 115, 43], dtype=np.uint8)
self.assertArraysEqual(bits8, expected8)
bits16 = random_bits(key, 16, (3,))
bits16 = jax._src.random._random_bits(key, 16, (3,))
expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
self.assertArraysEqual(bits16, expected16)
bits32 = random_bits(key, 32, (3,))
bits32 = jax._src.random._random_bits(key, 32, (3,))
expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
self.assertArraysEqual(bits32, expected32)
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
bits64 = random_bits(key, 64, (3,))
bits64 = jax._src.random._random_bits(key, 64, (3,))
if config.x64_enabled:
expected64 = np.array([3982329540505020460, 16822122385914693683,
7882654074788531506], dtype=np.uint64)
@ -270,28 +257,23 @@ class PrngTest(jtu.JaxTestCase):
# on every PRNG implementation. Instead of values, only checks
# that shapes/dtypes are as expected.
# TODO(frostig): remove once we always enable_custom_prng
def random_bits(key, *args):
key, _ = jax._src.random._check_prng_key(key)
return jax._src.random._random_bits(key, *args)
with jax.default_prng_impl(prng_name):
key = random.PRNGKey(1701)
bits8 = random_bits(key, 8, (3,))
bits8 = jax._src.random._random_bits(key, 8, (3,))
self.assertEqual(bits8.shape, (3,))
self.assertEqual(bits8.dtype, np.dtype('uint8'))
bits16 = random_bits(key, 16, (3,))
bits16 = jax._src.random._random_bits(key, 16, (3,))
self.assertEqual(bits16.shape, (3,))
self.assertEqual(bits16.dtype, np.dtype('uint16'))
bits32 = random_bits(key, 32, (3,))
bits32 = jax._src.random._random_bits(key, 32, (3,))
self.assertEqual(bits32.shape, (3,))
self.assertEqual(bits32.dtype, np.dtype('uint32'))
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
bits64 = random_bits(key, 64, (3,))
bits64 = jax._src.random._random_bits(key, 64, (3,))
expected_dtype = np.dtype('uint64' if config.x64_enabled else 'uint32')
self.assertEqual(bits64.shape, (3,))
self.assertEqual(bits64.dtype, expected_dtype)
@ -299,16 +281,11 @@ class PrngTest(jtu.JaxTestCase):
def testRngRandomBitsViewProperty(self):
# TODO: add 64-bit if it ever supports this property.
# TODO: will this property hold across endian-ness?
# TODO(frostig): remove once we always enable_custom_prng
def random_bits(key, *args):
key, _ = jax._src.random._check_prng_key(key)
return jax._src.random._random_bits(key, *args)
N = 10
key = random.PRNGKey(1701)
nbits = [8, 16, 32]
rand_bits = [random_bits(key, n, (N * 64 // n,)) for n in nbits]
rand_bits = [jax._src.random._random_bits(key, n, (N * 64 // n,))
for n in nbits]
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
assert np.all(rand_bits_32 == rand_bits_32[0])
@ -453,7 +430,8 @@ class PrngTest(jtu.JaxTestCase):
key = random.PRNGKey(1701)
self.assertEqual(key.shape, ())
self.assertEqual(key[None].shape, (1,))
self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
lambda: key[0])
def test_key_array_indexing_nd(self):
if not config.jax_enable_custom_prng:
@ -472,9 +450,9 @@ class PrngTest(jtu.JaxTestCase):
(1,) * 6)
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
self.assertRaisesRegex(IndexError, 'Too many indices.*',
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
lambda: keys[0, 1, 2])
self.assertRaisesRegex(IndexError, 'Too many indices.*',
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
lambda: keys[0, 1, None, 2])
@ -1427,18 +1405,18 @@ class LaxRandomTest(jtu.JaxTestCase):
jax.eval_shape(f, 0) # doesn't error
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_seed={seed}_type={type_}", "seed": seed, "type_": type_}
for type_ in ["int", "np.array", "jnp.array"]
{"testcase_name": f"_seed={seed}_type={type}", "seed": seed, "type": type}
for type in ["int", "np.array", "jnp.array"]
for seed in [-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)]))
def test_prng_jit_invariance(self, seed, type_):
if type_ == "int" and seed == (1 << 64) - 1:
def test_prng_jit_invariance(self, seed, type):
if type == "int" and seed == (1 << 64) - 1:
self.skipTest("Expected failure: Python int too large.")
if not config.x64_enabled and seed > np.iinfo(np.int32).max:
self.skipTest("Expected failure: Python int too large.")
type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_]
args_maker = lambda: [type_(seed)]
f = lambda s: _maybe_unwrap(self.seed_prng(s))
self._CompileAndCheck(f, args_maker)
type = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type]
args_maker = lambda: [type(seed)]
make_prng = lambda seed: _prng_key_as_array(self.seed_prng(seed))
self._CompileAndCheck(make_prng, args_maker)
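The invariance under test, in a small sketch (assuming the default, non-custom PRNG configuration, where keys are plain uint32 arrays):
import numpy as np
import jax
k_eager = jax.random.PRNGKey(42)
k_jit = jax.jit(jax.random.PRNGKey)(42)
assert np.array_equal(np.asarray(k_eager), np.asarray(k_jit))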
def test_prng_errors(self):
seed = np.iinfo(np.int64).max + 1
@ -1448,7 +1426,7 @@ class LaxRandomTest(jtu.JaxTestCase):
jax.jit(self.seed_prng)(seed)
def test_random_split_doesnt_device_put_during_tracing(self):
key = self.seed_prng(1).block_until_ready()
key = _prng_key_as_array(self.seed_prng(1)).block_until_ready()
with jtu.count_device_put() as count:
jax.jit(random.split)(key)
self.assertEqual(count[0], 1) # 1 for the argument device_put
@ -1490,120 +1468,6 @@ class LaxRandomTest(jtu.JaxTestCase):
jax.jit(f).lower()
class KeyArrayTest(jtu.JaxTestCase):
# Key arrays involve:
# * a Python key array type, backed by an underlying uint32 "base" array,
# * an abstract shaped array with key eltype,
# * primitives that return or operate on such shaped arrays,
# * compiler lowerings,
# * a device-side data representation...
# Test it all!
#
# A handful of these tests follow CustomElementTypesTest in
# lax_tests.py as an example. If you add a test here (e.g. testing
# lowering of a key-eltyped shaped array), consider whether it
# might also be a more general test of extended/custom eltypes. If
# so, add a corresponding test to CustomElementTypesTest as well.
def make_keys(self, *shape):
key = prng.seed_with_impl(prng.threefry_prng_impl, 28)
return jnp.reshape(jax.random.split(key, np.prod(shape)), shape)
# -- prng primitives
def test_random_wrap_vmap(self):
f = partial(prng_internal.random_wrap, impl=prng.threefry_prng_impl)
base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
keys = jax.vmap(f, in_axes=0)(base_arr)
self.assertIsInstance(keys, random.KeyArray)
self.assertEqual(keys.shape, (3,))
keys = jax.vmap(f, in_axes=1)(base_arr.T)
self.assertIsInstance(keys, random.KeyArray)
self.assertEqual(keys.shape, (3,))
# -- eltype-polymorphic operations
def test_scan_jaxpr(self):
ks = self.make_keys(3, 4, 5)
f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
jaxpr = jax.make_jaxpr(f)(ks).jaxpr
# { lambda ; a:key<fry>[3,4,5]. let
# b:key<fry>[3,5,4] = scan[
# jaxpr={ lambda ; c:key<fry>[4,5]. let
# d:key<fry>[5,4] = transpose[permutation=(1, 0)] c
# in (d,) }
# ] a
# in (b,) }
self.assertLen(jaxpr.invars, 1)
a, = jaxpr.invars
self.assertIsInstance(a.aval, core.ShapedArray)
self.assertEqual(a.aval.shape, (3, 4, 5))
self.assertIs(type(a.aval.dtype), jax._src.prng.KeyTy)
self.assertLen(jaxpr.eqns, 1)
e, = jaxpr.eqns
self.assertLen(e.outvars, 1)
b, = e.outvars
self.assertIsInstance(b.aval, core.ShapedArray)
self.assertEqual(b.aval.shape, (3, 5, 4))
self.assertIs(type(b.aval.dtype), jax._src.prng.KeyTy)
def test_scan_lowering(self):
ks = self.make_keys(3, 4)
f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
_, out = jax.jit(f)(ks) # doesn't crash
self.assertIsInstance(out, random.KeyArray)
self.assertEqual(out.shape, (3, 4))
def test_vmap(self):
ks = self.make_keys(3, 4, 5)
ys = jax.vmap(jax.jit(lambda k: k.T))(ks)
self.assertEqual(ys.shape, (3, 5, 4))
def test_slice(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (2, 4))
def test_dynamic_slice(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x, i: lax.dynamic_slice_in_dim(x, i, 2))(ks, 1)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (2, 4))
def test_transpose(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: x.T)(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (4, 3))
def test_gather(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: x[1])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (4,))
ks = self.make_keys(3, 4, 5)
ys = jax.jit(lambda x: x[1])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (4, 5))
ys = jax.jit(lambda x: x[1, 2:4])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (2, 5))
ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (2,))
ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
self.assertIsInstance(ys, random.KeyArray)
self.assertEqual(ys.shape, (3, 2, 1))
# TODO(frostig,mattjj): more polymorphic primitives tests
threefry_seed = jax._src.prng.threefry_seed
threefry_split = jax._src.prng.threefry_split
threefry_random_bits = jax._src.prng.threefry_random_bits
@ -1636,8 +1500,7 @@ double_threefry_prng_impl = prng.PRNGImpl(
seed=_double_threefry_seed,
split=_double_threefry_split,
random_bits=_double_threefry_random_bits,
fold_in=_double_threefry_fold_in,
tag='fry2')
fold_in=_double_threefry_fold_in)
@skipIf(not config.jax_enable_custom_prng,
'custom PRNG tests require config.jax_enable_custom_prng')
@ -1651,40 +1514,9 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
self.assertEqual(keys.shape, (10,))
def test_vmap_fold_in_shape(self):
# broadcast with scalar
keys = random.split(self.seed_prng(73), 2)
msgs = jnp.arange(3)
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
self.assertEqual(out.shape, (3,))
out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys)
self.assertEqual(out.shape, (2,))
out = vmap(random.fold_in, in_axes=(None, 0))(keys[0], msgs)
self.assertEqual(out.shape, (3,))
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
self.assertEqual(out.shape, (2,))
# vmap all
msgs = jnp.arange(2)
out = vmap(random.fold_in)(keys, msgs)
self.assertEqual(out.shape, (2,))
# nested vmap
keys = random.split(self.seed_prng(73), 2 * 3).reshape((2, 3))
msgs = jnp.arange(2 * 3).reshape((2, 3))
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys, msgs.T)
self.assertEqual(out.shape, (2, 3))
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys, msgs.T)
self.assertEqual(out.shape, (3, 2))
def test_vmap_split_mapped_key(self):
key = self.seed_prng(73)
mapped_keys = random.split(key, num=3)
forloop_keys = [random.split(k) for k in mapped_keys]
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2))
for fk, vk in zip(forloop_keys, vmapped_keys):
self.assertArraysEqual(fk.unsafe_raw_array(),
vk.unsafe_raw_array())
keys = vmap(lambda i: random.fold_in(key, i))(jnp.arange(3))
self.assertEqual(keys.shape, (3,))
def test_cannot_add(self):
key = self.seed_prng(73)
@ -1696,27 +1528,24 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
"https://github.com/numpy/numpy/issues/19305")
def test_grad_of_prng_key(self):
key = self.seed_prng(73)
with self.assertRaisesRegex(TypeError, 'input element type key<fry2>'):
jax.grad(lambda x: 1., allow_int=True)(key)
jax.grad(lambda x: 1., allow_int=True)(key) # does not crash
# TODO(frostig): remove `with_config` once we always enable_custom_prng
@jtu.with_config(jax_default_prng_impl='rbg')
@skipIf(not config.jax_enable_custom_prng,
'custom PRNG tests require config.jax_enable_custom_prng')
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def seed_prng(self, seed):
return random.rbg_key(seed)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
def test_split_shape(self):
key = self.seed_prng(73)
keys = random.split(key, 10)
self.assertEqual(keys.shape, (10,))
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
def test_vmap_fold_in_shape(self):
LaxRandomWithCustomPRNGTest.test_vmap_fold_in_shape(self)
key = self.seed_prng(73)
keys = vmap(lambda i: random.fold_in(key, i))(jnp.arange(3))
self.assertEqual(keys.shape, (3,))
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
def test_vmap_split_not_mapped_key(self):
key = self.seed_prng(73)
single_split_key = random.split(key)
@ -1726,7 +1555,6 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
self.assertArraysEqual(vk.unsafe_raw_array(),
single_split_key.unsafe_raw_array())
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
def test_vmap_split_mapped_key(self):
key = self.seed_prng(73)
mapped_keys = random.split(key, num=3)
@ -1746,7 +1574,6 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
self.assertEqual(rand_nums.shape, (3,))
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
def test_cannot_add(self):
key = self.seed_prng(73)
self.assertRaisesRegex(
@ -1755,11 +1582,9 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
@skipIf(np.__version__ == "1.21.0",
"https://github.com/numpy/numpy/issues/19305")
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
def test_grad_of_prng_key(self):
key = self.seed_prng(73)
with self.assertRaisesRegex(TypeError, 'input element type key<.?rbg>'):
jax.grad(lambda x: 1., allow_int=True)(key)
jax.grad(lambda x: 1., allow_int=True)(key) # does not crash
def test_random_split_doesnt_device_put_during_tracing(self):
return # this test doesn't apply to the RBG PRNG
@ -1768,19 +1593,16 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
# TODO(mattjj): enable this test if/when RngBitGenerator supports it
raise SkipTest('8-bit types not supported with RBG PRNG')
# TODO(frostig): remove `with_config` once we always enable_custom_prng
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
def seed_prng(self, seed):
return random.unsafe_rbg_key(seed)
return prng.seed_with_impl(prng.unsafe_rbg_prng_impl, seed)
def like(keys):
return jnp.ones(keys.shape)
@skipIf(not config.jax_enable_custom_prng,
'custom PRNG tests require config.jax_enable_custom_prng')
class JnpWithKeyArrayTest(jtu.JaxTestCase):
class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase):
def test_reshape(self):
key = random.PRNGKey(123)
keys = random.split(key, 4)
@ -1797,14 +1619,10 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.assertEqual(out.shape, (3,))
def test_concatenate(self):
self.skipTest('jnp.concatenate on key arrays') # TODO(frostig)
key = random.PRNGKey(123)
keys = random.split(key, 2)
ref = jnp.concatenate([like(keys)] * 3, axis=0)
out = jnp.concatenate([keys, keys, keys], axis=0)
self.assertEqual(out.shape, ref.shape)
self.assertEqual(out.shape, (6,))
out = jax.jit(lambda xs: jnp.concatenate(xs, axis=0))([keys, keys, keys])
ref = jnp.concatenate([like(keys)] * 3, axis=0)
self.assertEqual(out.shape, ref.shape)
self.assertEqual(out.shape, (6,))
@ -1836,19 +1654,15 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.assertEqual(out.shape, (3,))
def test_append(self):
self.skipTest('jnp.append on key arrays') # TODO(frostig)
key = random.PRNGKey(123)
out = jnp.append(key, key)
ref = jnp.append(like(key), like(key))
self.assertEqual(out.shape, ref.shape)
self.assertEqual(out.shape, (2,))
out1 = jnp.append(out, out)
ref1 = jnp.append(like(out), like(out))
self.assertEqual(out1.shape, ref1.shape)
self.assertEqual(out1.shape, (4,))
out2 = jax.jit(jnp.append)(key, key)
self.assertEqual(out2.shape, ref.shape)
self.assertEqual(out2.shape, (2,))
out_ = jnp.append(out, out)
ref_ = jnp.append(like(out), like(out))
self.assertEqual(out_.shape, ref_.shape)
self.assertEqual(out_.shape, (4,))
def test_ravel(self):
key = random.PRNGKey(123)