mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type during Python typechecking. PiperOrigin-RevId: 468835674
This commit is contained in:
parent
78cfbebfba
commit
9789e83b26
@ -1107,6 +1107,9 @@ def _check_scalar(x):
|
||||
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
|
||||
_check_arg(x)
|
||||
aval = core.get_aval(x)
|
||||
if core.aval_has_custom_eltype(aval):
|
||||
raise TypeError(
|
||||
f"{name} with input element type {core.aval_eltype(aval).name}")
|
||||
if holomorphic:
|
||||
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
|
||||
@ -1125,6 +1128,9 @@ _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
|
||||
|
||||
def _check_output_dtype_revderiv(name, holomorphic, x):
|
||||
aval = core.get_aval(x)
|
||||
if core.aval_has_custom_eltype(aval):
|
||||
raise TypeError(
|
||||
f"{name} with output element type {core.aval_eltype(aval).name}")
|
||||
if holomorphic:
|
||||
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
|
||||
@ -1202,6 +1208,9 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
|
||||
_check_arg(x)
|
||||
aval = core.get_aval(x)
|
||||
if core.aval_has_custom_eltype(aval):
|
||||
raise TypeError(
|
||||
f"jacfwd with input element type {core.aval_eltype(aval).name}")
|
||||
if holomorphic:
|
||||
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
|
||||
@ -2957,7 +2966,7 @@ class ShapeDtypeStruct:
|
||||
__slots__ = ["shape", "dtype", "named_shape"]
|
||||
def __init__(self, shape, dtype, named_shape=None):
|
||||
self.shape = shape
|
||||
self.dtype = np.dtype(dtype)
|
||||
self.dtype = dtype if core.is_custom_eltype(dtype) else np.dtype(dtype)
|
||||
self.named_shape = {} if named_shape is None else dict(named_shape)
|
||||
|
||||
size = property(lambda self: prod(self.shape))
|
||||
|
@ -112,9 +112,14 @@ def apply_primitive(prim, *args, **params):
|
||||
**params)
|
||||
return compiled_fun(*args)
|
||||
|
||||
# TODO(phawkins): update code referring to xla.apply_primitive to point here.
|
||||
# TODO(phawkins,frostig,mattjj): update code referring to
|
||||
# xla.apply_primitive to point here, or use simple_impl if that's why
|
||||
# it is using apply_primitive to begin with
|
||||
xla.apply_primitive = apply_primitive
|
||||
|
||||
def simple_impl(prim):
|
||||
prim.def_impl(partial(apply_primitive, prim))
|
||||
|
||||
RuntimeToken = Any
|
||||
|
||||
class RuntimeTokenSet(threading.local):
|
||||
|
@ -1574,9 +1574,12 @@ def _pred_bcast_select_mhlo(
|
||||
assert x.type == y.type, (x.type, y.type)
|
||||
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
|
||||
pred_aval.shape, x_y_aval)
|
||||
x_y_type = mlir.aval_to_ir_type(x_y_aval)
|
||||
bcast_pred_type = ir.RankedTensorType.get(
|
||||
x_y_type.shape, mlir.dtype_to_ir_type(np.dtype(np.bool_)))
|
||||
bcast_pred = mhlo.BroadcastInDimOp(
|
||||
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
|
||||
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
|
||||
bcast_pred_type, pred,
|
||||
mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
|
||||
return mhlo.SelectOp(bcast_pred, x, y).results
|
||||
|
||||
### fori_loop
|
||||
|
@ -1239,11 +1239,14 @@ def stop_gradient(x: T) -> T:
|
||||
DeviceArray(0., dtype=float32, weak_type=True)
|
||||
"""
|
||||
def stop(x):
|
||||
if (dtypes.issubdtype(_dtype(x), np.floating) or
|
||||
# only bind primitive on inexact dtypes, to avoid some staging
|
||||
if core.has_custom_eltype(x):
|
||||
return x
|
||||
elif (dtypes.issubdtype(_dtype(x), np.floating) or
|
||||
dtypes.issubdtype(_dtype(x), np.complexfloating)):
|
||||
return ad_util.stop_gradient_p.bind(x)
|
||||
else:
|
||||
return x # only bind primitive on inexact dtypes, to avoid some staging
|
||||
return x
|
||||
return tree_map(stop, x)
|
||||
|
||||
def reduce_precision(operand: Union[float, Array],
|
||||
@ -1504,7 +1507,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
|
||||
return result_dtype(*avals)
|
||||
|
||||
|
||||
def _broadcasting_shape_rule(name, *avals):
|
||||
def broadcasting_shape_rule(name, *avals):
|
||||
shapes = [aval.shape for aval in avals if aval.shape]
|
||||
if not shapes:
|
||||
return ()
|
||||
@ -1545,7 +1548,7 @@ def _naryop_weak_type_rule(name, *avals, **kwargs):
|
||||
|
||||
def naryop(result_dtype, accepted_dtypes, name):
|
||||
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name)
|
||||
shape_rule = partial(_broadcasting_shape_rule, name)
|
||||
shape_rule = partial(broadcasting_shape_rule, name)
|
||||
weak_type_rule = partial(_naryop_weak_type_rule, name)
|
||||
prim = standard_primitive(shape_rule, dtype_rule, name,
|
||||
weak_type_rule=weak_type_rule)
|
||||
|
@ -527,7 +527,7 @@ view of the input.
|
||||
|
||||
@_wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC)
|
||||
def transpose(a, axes=None):
|
||||
_check_arraylike("transpose", a)
|
||||
_stackable(a) or _check_arraylike("transpose", a)
|
||||
axes = np.arange(ndim(a))[::-1] if axes is None else axes
|
||||
return lax.transpose(a, axes)
|
||||
|
||||
@ -5107,27 +5107,31 @@ _set_shaped_array_attributes(ShapedArray)
|
||||
_set_shaped_array_attributes(DShapedArray)
|
||||
|
||||
|
||||
def _set_device_array_base_attributes(device_array):
|
||||
def _set_device_array_base_attributes(device_array, include=None):
|
||||
# Forward operators, methods, and properties on DeviceArray to lax_numpy
|
||||
# functions (with no Tracers involved; this forwarding is direct)
|
||||
def maybe_setattr(attr_name, target):
|
||||
if not include or attr_name in include:
|
||||
setattr(device_array, attr_name, target)
|
||||
|
||||
for operator_name, function in _operators.items():
|
||||
setattr(device_array, f"__{operator_name}__", function)
|
||||
maybe_setattr(f"__{operator_name}__", function)
|
||||
for method_name in _nondiff_methods + _diff_methods:
|
||||
setattr(device_array, method_name, globals()[method_name])
|
||||
maybe_setattr(method_name, globals()[method_name])
|
||||
# TODO(jakevdp): remove tile method after August 2022
|
||||
setattr(device_array, "tile", _deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead."))
|
||||
setattr(device_array, "reshape", _reshape)
|
||||
setattr(device_array, "transpose", _transpose)
|
||||
setattr(device_array, "flatten", ravel)
|
||||
setattr(device_array, "flat", property(_notimplemented_flat))
|
||||
setattr(device_array, "T", property(transpose))
|
||||
setattr(device_array, "real", property(real))
|
||||
setattr(device_array, "imag", property(imag))
|
||||
setattr(device_array, "astype", _astype)
|
||||
setattr(device_array, "view", _view)
|
||||
setattr(device_array, "nbytes", property(_nbytes))
|
||||
setattr(device_array, "itemsize", property(_itemsize))
|
||||
setattr(device_array, "clip", _clip)
|
||||
maybe_setattr("tile", _deprecate_function(tile, "arr.tile(...) is deprecated and will be removed. Use jnp.tile(arr, ...) instead."))
|
||||
maybe_setattr("reshape", _reshape)
|
||||
maybe_setattr("transpose", _transpose)
|
||||
maybe_setattr("flatten", ravel)
|
||||
maybe_setattr("flat", property(_notimplemented_flat))
|
||||
maybe_setattr("T", property(transpose))
|
||||
maybe_setattr("real", property(real))
|
||||
maybe_setattr("imag", property(imag))
|
||||
maybe_setattr("astype", _astype)
|
||||
maybe_setattr("view", _view)
|
||||
maybe_setattr("nbytes", property(_nbytes))
|
||||
maybe_setattr("itemsize", property(_itemsize))
|
||||
maybe_setattr("clip", _clip)
|
||||
|
||||
_set_device_array_base_attributes(device_array.DeviceArray)
|
||||
_set_device_array_base_attributes(Array)
|
||||
|
668
jax/_src/prng.py
668
jax/_src/prng.py
@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import abc
|
||||
from functools import partial
|
||||
from typing import Callable, Iterator, NamedTuple, Sequence
|
||||
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -28,23 +29,29 @@ from jax.config import config
|
||||
from jax.dtypes import float0
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
|
||||
_expand_bool_indices, _register_stackable)
|
||||
from jax._src.numpy import lax_numpy
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.util import canonicalize_axis, prod
|
||||
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
|
||||
|
||||
from jax._src.lib import gpu_prng
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
|
||||
UINT_DTYPES = {
|
||||
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
|
||||
|
||||
# -- PRNG implementation interface --
|
||||
# -- PRNG implementation interface
|
||||
|
||||
class PRNGImpl(NamedTuple):
|
||||
"""Specifies PRNG key shape and operations.
|
||||
@ -68,15 +75,22 @@ class PRNGImpl(NamedTuple):
|
||||
split: Callable
|
||||
random_bits: Callable
|
||||
fold_in: Callable
|
||||
tag: str = '?'
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.tag)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.tag
|
||||
|
||||
def pprint(self):
|
||||
return (pp.text(f"{self.__class__.__name__}:") +
|
||||
return (pp.text(f"{self.__class__.__name__} [{self.tag}]:") +
|
||||
pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
|
||||
pp.text(f"{k} = {v}") for k, v in self._asdict().items()
|
||||
]))))
|
||||
|
||||
|
||||
# -- PRNG key arrays --
|
||||
# -- PRNG key arrays
|
||||
|
||||
def _check_prng_key_data(impl, key_data: jnp.ndarray):
|
||||
ndim = len(impl.key_shape)
|
||||
@ -94,8 +108,19 @@ def _check_prng_key_data(impl, key_data: jnp.ndarray):
|
||||
f"got dtype={key_data.dtype}")
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class PRNGKeyArray:
|
||||
class PRNGKeyArrayMeta(abc.ABCMeta):
|
||||
"""Metaclass for overriding PRNGKeyArray isinstance checks."""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
try:
|
||||
return (hasattr(instance, 'aval') and
|
||||
isinstance(instance.aval, core.ShapedArray) and
|
||||
type(instance.aval.dtype) is KeyTy)
|
||||
except AttributeError:
|
||||
super().__instancecheck__(instance)
|
||||
|
||||
|
||||
class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
"""An array whose elements are PRNG keys.
|
||||
|
||||
This class lifts the definition of a PRNG, provided in the form of a
|
||||
@ -110,58 +135,30 @@ class PRNGKeyArray:
|
||||
"""
|
||||
|
||||
impl: PRNGImpl
|
||||
_keys: jnp.ndarray
|
||||
_base_array: jnp.ndarray
|
||||
|
||||
def __init__(self, impl, key_data: jnp.ndarray):
|
||||
# key_data might be a placeholder python `object` or `bool`
|
||||
# instead of a jnp.ndarray due to tree_unflatten
|
||||
if type(key_data) not in [object, bool]:
|
||||
_check_prng_key_data(impl, key_data)
|
||||
def __init__(self, impl, key_data: Any):
|
||||
assert not isinstance(key_data, core.Tracer)
|
||||
_check_prng_key_data(impl, key_data)
|
||||
self.impl = impl
|
||||
self._keys = key_data
|
||||
|
||||
def tree_flatten(self):
|
||||
return (self._keys,), self.impl
|
||||
self._base_array = key_data
|
||||
|
||||
# TODO(frostig): rename to unsafe_base_array, or just offer base_array attr?
|
||||
def unsafe_raw_array(self):
|
||||
"""Access the raw numerical array that carries underlying key data.
|
||||
|
||||
Returns:
|
||||
A uint32 JAX array whose leading dimensions are ``self.shape``.
|
||||
"""
|
||||
return self._keys
|
||||
return self._base_array
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, impl, keys):
|
||||
keys, = keys
|
||||
return cls(impl, keys)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
# TODO(frostig): remove after deprecation window
|
||||
if config.jax_enable_custom_prng:
|
||||
raise AttributeError("'PRNGKeyArray' has no attribute 'dtype'")
|
||||
else:
|
||||
warnings.warn(
|
||||
'deprecated `dtype` attribute of PRNG key arrays', FutureWarning)
|
||||
return np.uint32
|
||||
def block_until_ready(self):
|
||||
_ = self._base_array.block_until_ready()
|
||||
return self
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
# TODO(frostig): simplify once we always enable_custom_prng
|
||||
if config.jax_enable_custom_prng:
|
||||
return self._shape
|
||||
else:
|
||||
warnings.warn(
|
||||
'deprecated `shape` attribute of PRNG key arrays. In a future version '
|
||||
'of JAX this attribute will be removed or its value may change.',
|
||||
FutureWarning)
|
||||
return self._keys.shape
|
||||
|
||||
@property
|
||||
def _shape(self):
|
||||
base_ndim = len(self.impl.key_shape)
|
||||
return self._keys.shape[:-base_ndim]
|
||||
return base_arr_shape_to_keys_shape(self.impl, self._base_array.shape)
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
@ -169,74 +166,527 @@ class PRNGKeyArray:
|
||||
|
||||
def _is_scalar(self):
|
||||
base_ndim = len(self.impl.key_shape)
|
||||
return self._keys.ndim == base_ndim
|
||||
return self._base_array.ndim == base_ndim
|
||||
|
||||
def __len__(self):
|
||||
if self._is_scalar():
|
||||
raise TypeError('len() of unsized object')
|
||||
return len(self._keys)
|
||||
return len(self._base_array)
|
||||
|
||||
def __iter__(self) -> Iterator['PRNGKeyArray']:
|
||||
if self._is_scalar():
|
||||
raise TypeError('iteration over a 0-d single PRNG key')
|
||||
return (PRNGKeyArray(self.impl, k) for k in iter(self._keys))
|
||||
raise TypeError('iteration over a 0-d key array')
|
||||
# TODO(frostig): we may want to avoid iteration by slicing because
|
||||
# a very common use of iteration is `k1, k2 = split(key)`, and
|
||||
# slicing/indexing may be trickier to track for linearity checking
|
||||
# purposes. Maybe we can:
|
||||
# * introduce an unpack primitive+traceable (also allow direct use)
|
||||
# * unpack upfront into shape[0] many keyarray slices
|
||||
# * return iter over these unpacked slices
|
||||
# Whatever we do, we'll want to do it by overriding
|
||||
# ShapedArray._iter when the eltype is KeyTy...
|
||||
return (PRNGKeyArray(self.impl, k) for k in iter(self._base_array))
|
||||
|
||||
def __getitem__(self, idx) -> 'PRNGKeyArray':
|
||||
base_ndim = len(self.impl.key_shape)
|
||||
ndim = self._keys.ndim - base_ndim
|
||||
indexable_shape = self.impl.key_shape[:ndim]
|
||||
idx = _eliminate_deprecated_list_indexing(idx)
|
||||
idx = _expand_bool_indices(idx, indexable_shape)
|
||||
idx = _canonicalize_tuple_index(ndim, idx, array_name='PRNGKeyArray')
|
||||
return PRNGKeyArray(self.impl, self._keys[idx])
|
||||
# TODO(frostig): are all of the stackable methods below (reshape,
|
||||
# concat, broadcast_to, expand_dims), and the stackable registration,
|
||||
# still needed? If, with some work, none are needed, then do we want
|
||||
# to remove stackables altogether? This may be the only application.
|
||||
|
||||
def _fold_in(self, data: int) -> 'PRNGKeyArray':
|
||||
return PRNGKeyArray(self.impl, self.impl.fold_in(self._keys, data))
|
||||
|
||||
def _random_bits(self, bit_width, shape) -> jnp.ndarray:
|
||||
return self.impl.random_bits(self._keys, bit_width, shape)
|
||||
|
||||
def _split(self, num: int) -> 'PRNGKeyArray':
|
||||
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
|
||||
|
||||
def reshape(self, newshape, order=None):
|
||||
reshaped_keys = jnp.reshape(self._keys, (*newshape, -1), order=order)
|
||||
return PRNGKeyArray(self.impl, reshaped_keys)
|
||||
# TODO(frostig): Remove? Overwritten below in particular
|
||||
def reshape(self, newshape, order=None) -> 'PRNGKeyArray':
|
||||
reshaped_base = jnp.reshape(self._base_array, (*newshape, -1), order=order)
|
||||
return PRNGKeyArray(self.impl, reshaped_base)
|
||||
|
||||
def concatenate(self, key_arrs, axis, dtype=None):
|
||||
if dtype is not None:
|
||||
raise ValueError('dtype argument not supported for concatenating PRNGKeyArray')
|
||||
raise ValueError(
|
||||
'dtype argument not supported for concatenating PRNGKeyArray')
|
||||
axis = canonicalize_axis(axis, self.ndim)
|
||||
arrs = [self._keys, *[k._keys for k in key_arrs]]
|
||||
arrs = [self._base_array, *[k._base_array for k in key_arrs]]
|
||||
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))
|
||||
|
||||
def broadcast_to(self, shape):
|
||||
if jnp.ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
new_shape = (*shape, *self.impl.key_shape)
|
||||
return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape))
|
||||
return PRNGKeyArray(
|
||||
self.impl, jnp.broadcast_to(self._base_array, new_shape))
|
||||
|
||||
def expand_dims(self, dimensions: Sequence[int]):
|
||||
# follows lax.expand_dims, not jnp.expand_dims, so dimensions is a sequence
|
||||
ndim_out = self.ndim + len(set(dimensions))
|
||||
dimensions = [canonicalize_axis(d, ndim_out) for d in dimensions]
|
||||
return PRNGKeyArray(self.impl, lax.expand_dims(self._keys, dimensions))
|
||||
return PRNGKeyArray(
|
||||
self.impl, lax.expand_dims(self._base_array, dimensions))
|
||||
|
||||
def __repr__(self):
|
||||
arr_shape = self._shape
|
||||
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
|
||||
return (f'{self.__class__.__name__}[{self.impl.tag}]'
|
||||
f' {{ {self._base_array} }}')
|
||||
|
||||
def pprint(self):
|
||||
pp_keys = pp.text('shape = ') + pp.text(str(self.shape))
|
||||
pp_impl = pp.text('impl = ') + self.impl.pprint()
|
||||
return str(pp.group(
|
||||
pp.text('PRNGKeyArray:') +
|
||||
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))
|
||||
|
||||
# Hollow defs only for typing purposes, overwritten below
|
||||
#
|
||||
# TODO(frostig): there may be a better way to do this with
|
||||
# `typing.type_check_only`.
|
||||
|
||||
@property
|
||||
def T(self) -> 'PRNGKeyArray': assert False
|
||||
def __getitem__(self, _) -> 'PRNGKeyArray': assert False
|
||||
def ravel(self, *_, **__) -> 'PRNGKeyArray': assert False
|
||||
def squeeze(self, *_, **__) -> 'PRNGKeyArray': assert False
|
||||
def swapaxes(self, *_, **__) -> 'PRNGKeyArray': assert False
|
||||
def take(self, *_, **__) -> 'PRNGKeyArray': assert False
|
||||
def transpose(self, *_, **__) -> 'PRNGKeyArray': assert False
|
||||
def flatten(self, *_, **__) -> 'PRNGKeyArray': assert False
|
||||
|
||||
|
||||
lax_numpy._set_device_array_base_attributes(PRNGKeyArray, include=[
|
||||
'__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape',
|
||||
'transpose', 'flatten', 'T'])
|
||||
lax_numpy._register_stackable(PRNGKeyArray)
|
||||
|
||||
|
||||
# TODO(frostig): remove, rerouting callers directly to random_seed
|
||||
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
|
||||
return PRNGKeyArray(impl, impl.seed(seed))
|
||||
return random_seed(seed, impl=impl)
|
||||
|
||||
_register_stackable(PRNGKeyArray)
|
||||
|
||||
# -- threefry2x32 PRNG implementation --
|
||||
def keys_shaped_array(impl, shape):
|
||||
return core.ShapedArray(shape, KeyTy(impl))
|
||||
|
||||
def keys_aval_to_base_arr_aval(keys_aval):
|
||||
shape = (*keys_aval.shape, *keys_aval.dtype.impl.key_shape)
|
||||
return core.ShapedArray(shape, np.dtype('uint32'))
|
||||
|
||||
def base_arr_shape_to_keys_shape(impl, base_arr_shape):
|
||||
base_ndim = len(impl.key_shape)
|
||||
return base_arr_shape[:-base_ndim]
|
||||
|
||||
|
||||
class KeyTy:
|
||||
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
|
||||
def __init__(self, impl):
|
||||
self.impl = impl
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'key<{self.impl.tag}>'
|
||||
def __repr__(self) -> str:
|
||||
return self.name
|
||||
def __eq__(self, other):
|
||||
return type(other) is KeyTy and self.impl is other.impl
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.__class__, self.impl))
|
||||
|
||||
# handlers
|
||||
|
||||
@staticmethod
|
||||
def physical_avals(aval):
|
||||
return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),
|
||||
jnp.dtype('uint32'))]
|
||||
|
||||
@staticmethod
|
||||
def aval_to_ir_types(aval):
|
||||
phys_aval, = KeyTy.physical_avals(aval)
|
||||
return mlir.aval_to_ir_types(phys_aval)
|
||||
|
||||
@staticmethod
|
||||
def result_handler(sticky_device, aval):
|
||||
def handler(_, buf):
|
||||
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
|
||||
return PRNGKeyArray(aval.dtype.impl, buf)
|
||||
return handler
|
||||
|
||||
@staticmethod
|
||||
def sharded_result_handler(aval, sharding, indices):
|
||||
phys_aval, = KeyTy.physical_avals(aval)
|
||||
phys_handler_maker = pxla.local_result_handlers[
|
||||
(core.ShapedArray, pxla.OutputType.ShardedDeviceArray)]
|
||||
phys_handler = phys_handler_maker(phys_aval, sharding, indices)
|
||||
def handler(bufs):
|
||||
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
|
||||
return handler
|
||||
|
||||
# eltype-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
def empty_mlir(ctx):
|
||||
aval_out, = ctx.avals_out
|
||||
return mlir.ir_constants(np.empty(aval_out.dtype.impl.key_shape,
|
||||
dtype=np.dtype('uint32')))
|
||||
|
||||
@staticmethod
|
||||
def slice_mlir(ctx, x, start_indices, limit_indices, strides):
|
||||
aval_out, = ctx.avals_out
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_zeros = [0] * len(key_shape)
|
||||
trailing_ones = [1] * len(key_shape)
|
||||
start_indices = (*start_indices, *trailing_zeros)
|
||||
limit_indices = (*limit_indices, *key_shape)
|
||||
strides = (*strides, *trailing_ones)
|
||||
return mhlo.SliceOp(x,
|
||||
mlir.dense_int_elements(start_indices),
|
||||
mlir.dense_int_elements(limit_indices),
|
||||
mlir.dense_int_elements(strides)).results
|
||||
|
||||
@staticmethod
|
||||
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
|
||||
aval_out, = ctx.avals_out
|
||||
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_zeros = [mlir.ir_constant(np.array(0, dtype))] * len(key_shape)
|
||||
start_indices = (*start_indices, *trailing_zeros)
|
||||
slice_sizes_ = mlir.dense_int_elements((*slice_sizes, *key_shape))
|
||||
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results
|
||||
|
||||
@staticmethod
|
||||
def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
|
||||
aval_out, = ctx.avals_out
|
||||
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
|
||||
start_indices = (*start_indices, *zeros)
|
||||
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
|
||||
start_indices).results
|
||||
|
||||
@staticmethod
|
||||
def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
|
||||
if dyn_shape: raise NotImplementedError
|
||||
aval_out, = ctx.avals_out
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
|
||||
broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
|
||||
return mhlo.BroadcastInDimOp(
|
||||
mlir.aval_to_ir_type(aval_out), x,
|
||||
mlir.dense_int_elements(broadcast_dimensions)).results
|
||||
|
||||
@staticmethod
|
||||
def transpose_mlir(ctx, x, *, permutation):
|
||||
aval_out, = ctx.avals_out
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
|
||||
perm = [*permutation, *trailing_dims]
|
||||
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results
|
||||
|
||||
@staticmethod
|
||||
def gather_mlir(ctx, x, indices, *,
|
||||
dimension_numbers, slice_sizes, unique_indices,
|
||||
indices_are_sorted, mode, fill_value):
|
||||
aval_x, aval_indices = ctx.avals_in
|
||||
aval_y, = ctx.avals_out
|
||||
key_shape = aval_x.dtype.impl.key_shape
|
||||
trailing_offset_dims = [aval_y.ndim + i for i in range(len(key_shape))]
|
||||
dimension_numbers = dimension_numbers._replace(
|
||||
offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims))
|
||||
slice_sizes = (*slice_sizes, *key_shape)
|
||||
gather_lower = partial(
|
||||
lax_internal.slicing._gather_lower, dimension_numbers=dimension_numbers,
|
||||
slice_sizes=slice_sizes, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
|
||||
return mlir.delegate_lowering(
|
||||
ctx, gather_lower, x, indices,
|
||||
avals_in=[keys_aval_to_base_arr_aval(aval_x), aval_indices],
|
||||
avals_out=[keys_aval_to_base_arr_aval(aval_y)])
|
||||
|
||||
core.custom_eltypes.add(KeyTy)
|
||||
|
||||
|
||||
core.pytype_aval_mappings[PRNGKeyArray] = (
|
||||
lambda x: keys_shaped_array(x.impl, x.shape))
|
||||
|
||||
xla.pytype_aval_mappings[PRNGKeyArray] = (
|
||||
lambda x: keys_shaped_array(x.impl, x.shape))
|
||||
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
|
||||
|
||||
def device_put_key_array(x: PRNGKeyArray, device):
|
||||
return dispatch.device_put(x.unsafe_raw_array(), device)
|
||||
dispatch.device_put_handlers[PRNGKeyArray] = device_put_key_array
|
||||
|
||||
def key_array_shard_arg_handler(x: PRNGKeyArray, devices, indices, mode):
|
||||
arr = x.unsafe_raw_array()
|
||||
return pxla.shard_arg_handlers[type(arr)](arr, devices, indices, mode)
|
||||
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler
|
||||
|
||||
def key_array_constant_handler(x, canonicalize_dtypes):
|
||||
arr = x.unsafe_raw_array()
|
||||
return mlir.get_constant_handler(type(arr))(arr, canonicalize_dtypes)
|
||||
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
|
||||
|
||||
|
||||
# -- primitives
|
||||
|
||||
def iterated_vmap_unary(n, f):
|
||||
for _ in range(n):
|
||||
f = jax.vmap(f)
|
||||
return f
|
||||
|
||||
# TODO(frostig): Revise the following two functions? These basically
|
||||
# undo the singleton dimensions added by `batching.defbroadcasting`.
|
||||
# It works, but introduces some possibly-redundant squeezes. Can we
|
||||
# borrow from other broadcasting primitives instead?
|
||||
|
||||
def squeeze_vmap(f, left):
|
||||
def squeeze_vmap_f(x, y):
|
||||
if left:
|
||||
x = jnp.squeeze(x, axis=0)
|
||||
axes = (None, 0)
|
||||
else:
|
||||
y = jnp.squeeze(y, axis=0)
|
||||
axes = (0, None)
|
||||
return jax.vmap(f, in_axes=axes, out_axes=0)(x, y)
|
||||
return squeeze_vmap_f
|
||||
|
||||
def iterated_vmap_binary_bcast(shape1, shape2, f):
|
||||
ndim1, ndim2 = len(shape1), len(shape2)
|
||||
if ndim1 == ndim2 == 0:
|
||||
return f
|
||||
if 0 in [ndim1, ndim2]:
|
||||
if ndim1 == 0:
|
||||
return lambda x, y: iterated_vmap_unary(ndim2, lambda y: f(x, y))(y)
|
||||
else:
|
||||
return lambda x, y: iterated_vmap_unary(ndim1, lambda x: f(x, y))(x)
|
||||
assert len(shape1) == len(shape2)
|
||||
for sz1, sz2 in reversed(zip(shape1, shape2)):
|
||||
if sz1 == sz2:
|
||||
f = jax.vmap(f, out_axes=0)
|
||||
else:
|
||||
assert sz1 == 1 or sz2 == 1, (sz1, sz2)
|
||||
f = squeeze_vmap(f, sz1 == 1)
|
||||
return f
|
||||
|
||||
|
||||
def random_seed(seeds, impl):
|
||||
# Avoid overflow error in X32 mode by first converting ints to int64.
|
||||
# This breaks JIT invariance for large ints, but supports the common
|
||||
# use-case of instantiating with Python hashes in X32 mode.
|
||||
if isinstance(seeds, int):
|
||||
seeds_arr = jnp.asarray(np.int64(seeds))
|
||||
else:
|
||||
seeds_arr = jnp.asarray(seeds)
|
||||
return random_seed_p.bind(seeds_arr, impl=impl)
|
||||
|
||||
random_seed_p = core.Primitive('random_seed')
|
||||
batching.defvectorized(random_seed_p)
|
||||
|
||||
@random_seed_p.def_abstract_eval
|
||||
def random_seed_abstract_eval(seeds_aval, *, impl):
|
||||
return keys_shaped_array(impl, seeds_aval.shape)
|
||||
|
||||
@random_seed_p.def_impl
|
||||
def random_seed_impl(seeds, *, impl):
|
||||
base_arr = random_seed_impl_base(seeds, impl=impl)
|
||||
return PRNGKeyArray(impl, base_arr)
|
||||
|
||||
def random_seed_impl_base(seeds, *, impl):
|
||||
seed = iterated_vmap_unary(seeds.ndim, impl.seed)
|
||||
return seed(seeds)
|
||||
|
||||
def random_seed_lowering(ctx, seeds, *, impl):
|
||||
aval, = ctx.avals_in
|
||||
seed = iterated_vmap_unary(aval.ndim, impl.seed)
|
||||
seed_lowering = mlir.lower_fun(seed, multiple_results=False)
|
||||
return mlir.delegate_lowering(
|
||||
ctx, seed_lowering, seeds,
|
||||
avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
|
||||
|
||||
mlir.register_lowering(random_seed_p, random_seed_lowering)
|
||||
|
||||
|
||||
def random_split(keys, count):
|
||||
return random_split_p.bind(keys, count=count)
|
||||
|
||||
random_split_p = core.Primitive('random_split')
|
||||
batching.defvectorized(random_split_p)
|
||||
|
||||
@random_split_p.def_abstract_eval
|
||||
def random_split_abstract_eval(keys_aval, *, count):
|
||||
return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, count))
|
||||
|
||||
@random_split_p.def_impl
|
||||
def random_split_impl(keys, *, count):
|
||||
base_arr = random_split_impl_base(
|
||||
keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
|
||||
return PRNGKeyArray(keys.impl, base_arr)
|
||||
|
||||
def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
|
||||
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
|
||||
return split(base_arr)
|
||||
|
||||
def random_split_lowering(ctx, keys, *, count):
|
||||
aval, = ctx.avals_in
|
||||
impl = aval.dtype.impl
|
||||
split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, count))
|
||||
split_lowering = mlir.lower_fun(split, multiple_results=False)
|
||||
return mlir.delegate_lowering(
|
||||
ctx, split_lowering, keys,
|
||||
avals_in=[keys_aval_to_base_arr_aval(aval)],
|
||||
avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
|
||||
|
||||
mlir.register_lowering(random_split_p, random_split_lowering)
|
||||
|
||||
|
||||
def random_fold_in(keys, msgs):
|
||||
return random_fold_in_p.bind(keys, jnp.asarray(msgs))
|
||||
|
||||
random_fold_in_p = core.Primitive('random_fold_in')
|
||||
batching.defbroadcasting(random_fold_in_p)
|
||||
|
||||
@random_fold_in_p.def_abstract_eval
|
||||
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
|
||||
shape = lax_internal.broadcasting_shape_rule(
|
||||
'random_fold_in', keys_aval, msgs_aval)
|
||||
named_shape = lax_utils.standard_named_shape_rule(keys_aval, msgs_aval)
|
||||
return core.ShapedArray(shape, keys_aval.dtype, named_shape=named_shape)
|
||||
|
||||
@random_fold_in_p.def_impl
|
||||
def random_fold_in_impl(keys, msgs):
|
||||
base_arr = random_fold_in_impl_base(
|
||||
keys.impl, keys.unsafe_raw_array(), msgs, keys.shape)
|
||||
return PRNGKeyArray(keys.impl, base_arr)
|
||||
|
||||
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
|
||||
fold_in = iterated_vmap_binary_bcast(
|
||||
keys_shape, np.shape(msgs), impl.fold_in)
|
||||
return fold_in(base_arr, msgs)
|
||||
|
||||
def random_fold_in_lowering(ctx, keys, msgs):
|
||||
keys_aval, msgs_aval = ctx.avals_in
|
||||
impl = keys_aval.dtype.impl
|
||||
fold_in = iterated_vmap_binary_bcast(
|
||||
keys_aval.shape, msgs_aval.shape, impl.fold_in)
|
||||
fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
|
||||
return mlir.delegate_lowering(
|
||||
ctx, fold_in_lowering, keys, msgs,
|
||||
avals_in=[keys_aval_to_base_arr_aval(keys_aval), msgs_aval],
|
||||
avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
|
||||
|
||||
mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)
|
||||
|
||||
|
||||
def random_bits(keys, bit_width, shape):
|
||||
shape = core.as_named_shape(shape)
|
||||
for name, size in shape.named_items:
|
||||
# TODO(frostig,mattjj,apaszke): Is this real_size check necessary,
|
||||
# and is it meant to raise a user-facing ValueError? Should it be
|
||||
# an `assert` (or RuntimeError) instead? Why do we check it in
|
||||
# calls to `random_bits` instead of a more common paralleism path?
|
||||
real_size = lax.psum(1, name)
|
||||
if real_size != size:
|
||||
raise ValueError(f"The shape of axis {name} was specified as {size}, "
|
||||
f"but it really is {real_size}")
|
||||
axis_index = lax.axis_index(name)
|
||||
keys = random_fold_in(keys, axis_index)
|
||||
return random_bits_p.bind(keys, bit_width=bit_width, shape=shape.positional)
|
||||
|
||||
random_bits_p = core.Primitive('random_bits')
|
||||
batching.defvectorized(random_bits_p)
|
||||
|
||||
@random_bits_p.def_abstract_eval
|
||||
def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
|
||||
out_shape = (*keys_aval.shape, *shape)
|
||||
out_dtype = dtypes.dtype(f'uint{bit_width}')
|
||||
return core.ShapedArray(out_shape, out_dtype)
|
||||
|
||||
@random_bits_p.def_impl
|
||||
def random_bits_impl(keys, *, bit_width, shape):
|
||||
return random_bits_impl_base(keys.impl, keys.unsafe_raw_array(), keys.ndim,
|
||||
bit_width=bit_width, shape=shape)
|
||||
|
||||
def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):
|
||||
bits = iterated_vmap_unary(
|
||||
keys_ndim, lambda k: impl.random_bits(k, bit_width, shape))
|
||||
return bits(base_arr)
|
||||
|
||||
def random_bits_lowering(ctx, keys, *, bit_width, shape):
|
||||
aval, = ctx.avals_in
|
||||
impl = aval.dtype.impl
|
||||
bits = iterated_vmap_unary(
|
||||
aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
|
||||
bits_lowering = mlir.lower_fun(bits, multiple_results=False)
|
||||
ctx_new = ctx.replace(avals_in=[keys_aval_to_base_arr_aval(aval)])
|
||||
out = bits_lowering(ctx_new, keys)
|
||||
ctx.set_tokens_out(ctx_new.tokens_out)
|
||||
return out
|
||||
|
||||
mlir.register_lowering(random_bits_p, random_bits_lowering)
|
||||
|
||||
|
||||
# The following wrap/unwrap primitives are at least a stopgap for
|
||||
# backwards compatibility, namely when `config.jax_enable_custom_prng`
|
||||
# is False. We need to convert key arrays to and from underlying
|
||||
# uint32 base array, and we may need to do so under a jit. For
|
||||
# example, we want to support:
|
||||
#
|
||||
# keys = jax.jit(random.split)(key)
|
||||
#
|
||||
# where `key` and `keys` are both acceptably old-style uint32 arrays
|
||||
# so long as enable_custom_prng is False. The way we handle this is
|
||||
# that `random.split` adapts the input/output by converting to/from
|
||||
# key arrays across its call to `random_split`. So we rely on these
|
||||
# wrap/unwrap casting primitives to allow that conversion under jit.
|
||||
#
|
||||
# We may want to keep both around for testing and debugging escape
|
||||
# hatches. We can rename them `unsafe` for emphasis, and/or issue a
|
||||
# warning on entry to the traceable.
|
||||
#
|
||||
# TODO(frostig): Consider removal once we always enable_custom_prng.
|
||||
|
||||
def random_wrap(base_arr, *, impl):
|
||||
_check_prng_key_data(impl, base_arr)
|
||||
return random_wrap_p.bind(base_arr, impl=impl)
|
||||
|
||||
random_wrap_p = core.Primitive('random_wrap')
|
||||
|
||||
@random_wrap_p.def_abstract_eval
|
||||
def random_wrap_abstract_eval(base_arr_aval, *, impl):
|
||||
shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape)
|
||||
return keys_shaped_array(impl, shape)
|
||||
|
||||
@random_wrap_p.def_impl
|
||||
def random_wrap_impl(base_arr, *, impl):
|
||||
return PRNGKeyArray(impl, base_arr)
|
||||
|
||||
def random_wrap_lowering(ctx, base_arr, *, impl):
|
||||
return [base_arr]
|
||||
|
||||
def random_wrap_batch_rule(batched_args, batch_dims, *, impl):
|
||||
x, = batched_args
|
||||
d, = batch_dims
|
||||
x = batching.bdim_at_front(x, d, 1)
|
||||
return random_wrap(x, impl=impl), 0
|
||||
|
||||
mlir.register_lowering(random_wrap_p, random_wrap_lowering)
|
||||
batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule
|
||||
|
||||
|
||||
def random_unwrap(keys):
|
||||
assert isinstance(keys, PRNGKeyArray)
|
||||
return random_unwrap_p.bind(keys)
|
||||
|
||||
random_unwrap_p = core.Primitive('random_unwrap')
|
||||
batching.defvectorized(random_unwrap_p)
|
||||
|
||||
@random_unwrap_p.def_abstract_eval
|
||||
def random_unwrap_abstract_eval(keys_aval):
|
||||
return keys_aval_to_base_arr_aval(keys_aval)
|
||||
|
||||
@random_unwrap_p.def_impl
|
||||
def random_unwrap_impl(keys):
|
||||
return keys.unsafe_raw_array()
|
||||
|
||||
def random_unwrap_lowering(ctx, keys):
|
||||
return [keys]
|
||||
|
||||
mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)
|
||||
|
||||
|
||||
# -- threefry2x32 PRNG implementation
|
||||
|
||||
|
||||
def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
|
||||
@ -246,8 +696,8 @@ def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def threefry_seed(seed: int) -> jnp.ndarray:
|
||||
"""Create a single raw threefry PRNG key given an integer seed.
|
||||
def threefry_seed(seed: jnp.ndarray) -> jnp.ndarray:
|
||||
"""Create a single raw threefry PRNG key from an integer seed.
|
||||
|
||||
Args:
|
||||
seed: a 64- or 32-bit integer used as the value of the key.
|
||||
@ -258,24 +708,17 @@ def threefry_seed(seed: int) -> jnp.ndarray:
|
||||
bit-casting to a pair of uint32 values (or from a 32-bit seed by
|
||||
first padding out with zeros).
|
||||
"""
|
||||
# Avoid overflowerror in X32 mode by first converting ints to int64.
|
||||
# This breaks JIT invariance for large ints, but supports the common
|
||||
# use-case of instantiating with Python hashes in X32 mode.
|
||||
if isinstance(seed, int):
|
||||
seed_arr = jnp.asarray(np.int64(seed))
|
||||
else:
|
||||
seed_arr = jnp.asarray(seed)
|
||||
if seed_arr.shape:
|
||||
if seed.shape:
|
||||
raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
|
||||
if not np.issubdtype(seed_arr.dtype, np.integer):
|
||||
if not np.issubdtype(seed.dtype, np.integer):
|
||||
raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")
|
||||
convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1])
|
||||
k1 = convert(
|
||||
lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
|
||||
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
# TODO(jakevdp): in X64 mode, this can generate 64-bit computations for 32-bit
|
||||
# inputs. We should avoid this.
|
||||
k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
|
||||
k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
|
||||
return lax.concatenate([k1, k2], 0)
|
||||
|
||||
|
||||
@ -305,7 +748,7 @@ def _threefry2x32_abstract_eval(*args):
|
||||
raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
|
||||
.format(args))
|
||||
if all(isinstance(arg, core.ShapedArray) for arg in args):
|
||||
shape = lax_internal._broadcasting_shape_rule(*args)
|
||||
shape = lax_internal.broadcasting_shape_rule(*args)
|
||||
named_shape = core.join_named_shapes(*(a.named_shape for a in args))
|
||||
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
|
||||
else:
|
||||
@ -467,7 +910,8 @@ def _threefry_split(key, num) -> jnp.ndarray:
|
||||
return lax.reshape(threefry_2x32(key, counts), (num, 2))
|
||||
|
||||
|
||||
def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
||||
def threefry_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
|
||||
assert not data.shape
|
||||
return _threefry_fold_in(key, jnp.uint32(data))
|
||||
|
||||
@partial(jit, inline=True)
|
||||
@ -482,15 +926,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
|
||||
raise TypeError("threefry_random_bits got invalid prng key.")
|
||||
if bit_width not in (8, 16, 32, 64):
|
||||
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
||||
shape = core.as_named_shape(shape)
|
||||
for name, size in shape.named_items:
|
||||
real_size = lax.psum(1, name)
|
||||
if real_size != size:
|
||||
raise ValueError(f"The shape of axis {name} was specified as {size}, "
|
||||
f"but it really is {real_size}")
|
||||
axis_index = lax.axis_index(name)
|
||||
key = threefry_fold_in(key, axis_index)
|
||||
size = prod(shape.positional)
|
||||
size = prod(shape)
|
||||
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
|
||||
# polymorphism
|
||||
max_count, r = divmod(bit_width * size, 32)
|
||||
@ -537,10 +973,11 @@ threefry_prng_impl = PRNGImpl(
|
||||
seed=threefry_seed,
|
||||
split=threefry_split,
|
||||
random_bits=threefry_random_bits,
|
||||
fold_in=threefry_fold_in)
|
||||
fold_in=threefry_fold_in,
|
||||
tag='fry')
|
||||
|
||||
|
||||
# -- RngBitGenerator PRNG implementation --
|
||||
# -- RngBitGenerator PRNG implementation
|
||||
|
||||
# This code is experimental!
|
||||
# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
|
||||
@ -548,14 +985,16 @@ threefry_prng_impl = PRNGImpl(
|
||||
# stable/deterministic across backends or compiler versions. Correspondingly, we
|
||||
# reserve the right to change any of these implementations at any time!
|
||||
|
||||
def _rbg_seed(seed: int) -> jnp.ndarray:
|
||||
def _rbg_seed(seed: jnp.ndarray) -> jnp.ndarray:
|
||||
assert not seed.shape
|
||||
halfkey = threefry_seed(seed)
|
||||
return jnp.concatenate([halfkey, halfkey])
|
||||
|
||||
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
return vmap(_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
|
||||
|
||||
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
||||
def _rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
|
||||
assert not data.shape
|
||||
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)
|
||||
|
||||
def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
|
||||
@ -572,14 +1011,16 @@ rbg_prng_impl = PRNGImpl(
|
||||
seed=_rbg_seed,
|
||||
split=_rbg_split,
|
||||
random_bits=_rbg_random_bits,
|
||||
fold_in=_rbg_fold_in)
|
||||
fold_in=_rbg_fold_in,
|
||||
tag='rbg')
|
||||
|
||||
def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
# treat 10 iterations of random bits as a 'hash function'
|
||||
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
|
||||
return keys[::10]
|
||||
|
||||
def _unsafe_rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
||||
def _unsafe_rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
|
||||
assert not data.shape
|
||||
_, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
|
||||
return key ^ random_bits[-1]
|
||||
|
||||
@ -588,4 +1029,5 @@ unsafe_rbg_prng_impl = PRNGImpl(
|
||||
seed=_rbg_seed,
|
||||
split=_unsafe_rbg_split,
|
||||
random_bits=_rbg_random_bits,
|
||||
fold_in=_unsafe_rbg_fold_in)
|
||||
fold_in=_unsafe_rbg_fold_in,
|
||||
tag='urbg')
|
||||
|
@ -59,9 +59,10 @@ _lax_const = lax_internal._const
|
||||
def _isnan(x):
|
||||
return lax.ne(x, x)
|
||||
|
||||
|
||||
def _check_prng_key(key):
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
if type(key) is prng.PRNGKeyArray:
|
||||
if isinstance(key, prng.PRNGKeyArray):
|
||||
return key, False
|
||||
elif _arraylike(key):
|
||||
if config.jax_enable_custom_prng:
|
||||
@ -69,21 +70,23 @@ def _check_prng_key(key):
|
||||
'Raw arrays as random keys to jax.random functions are deprecated. '
|
||||
'Assuming valid threefry2x32 key for now.',
|
||||
FutureWarning)
|
||||
return prng.PRNGKeyArray(default_prng_impl(), key), True
|
||||
return prng.random_wrap(key, impl=default_prng_impl()), True
|
||||
else:
|
||||
raise TypeError(f'unexpected PRNG key type {type(key)}')
|
||||
|
||||
|
||||
def _return_prng_keys(was_wrapped, key):
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
assert type(key) is prng.PRNGKeyArray, type(key)
|
||||
assert isinstance(key, prng.PRNGKeyArray)
|
||||
if config.jax_enable_custom_prng:
|
||||
return key
|
||||
else:
|
||||
return key.unsafe_raw_array() if was_wrapped else key
|
||||
return prng.random_unwrap(key) if was_wrapped else key
|
||||
|
||||
|
||||
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
|
||||
key, _ = _check_prng_key(key)
|
||||
return key._random_bits(bit_width, shape)
|
||||
assert isinstance(key, prng.PRNGKeyArray)
|
||||
return prng.random_bits(key, bit_width=bit_width, shape=shape)
|
||||
|
||||
|
||||
PRNG_IMPLS = {
|
||||
@ -158,7 +161,8 @@ def unsafe_rbg_key(seed: int) -> KeyArray:
|
||||
def _fold_in(key: KeyArray, data: int) -> KeyArray:
|
||||
# Alternative to fold_in() to use within random samplers.
|
||||
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
|
||||
return key._fold_in(jnp.uint32(data))
|
||||
assert isinstance(key, prng.PRNGKeyArray)
|
||||
return prng.random_fold_in(key, jnp.uint32(data))
|
||||
|
||||
def fold_in(key: KeyArray, data: int) -> KeyArray:
|
||||
"""Folds in data to a PRNG key to form a new PRNG key.
|
||||
@ -176,8 +180,10 @@ def fold_in(key: KeyArray, data: int) -> KeyArray:
|
||||
|
||||
def _split(key: KeyArray, num: int = 2) -> KeyArray:
|
||||
# Alternative to split() to use within random samplers.
|
||||
# TODO(frostig): remove and use split() once we always enable_custom_prng
|
||||
return key._split(num)
|
||||
# TODO(frostig): remove and use split(); we no longer need to wait
|
||||
# to always enable_custom_prng
|
||||
assert isinstance(key, prng.PRNGKeyArray)
|
||||
return prng.random_split(key, count=num)
|
||||
|
||||
def split(key: KeyArray, num: int = 2) -> KeyArray:
|
||||
"""Splits a PRNG key into `num` new keys by adding a leading axis.
|
||||
@ -966,8 +972,7 @@ def _gamma_one(key: KeyArray, alpha, log_space):
|
||||
return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)
|
||||
|
||||
|
||||
def _gamma_grad(sample, a, *, prng_impl, log_space):
|
||||
del prng_impl # unused
|
||||
def _gamma_grad(sample, a, *, log_space):
|
||||
samples = jnp.reshape(sample, -1)
|
||||
alphas = jnp.reshape(a, -1)
|
||||
if log_space:
|
||||
@ -987,34 +992,38 @@ def _gamma_grad(sample, a, *, prng_impl, log_space):
|
||||
grads = vmap(gamma_grad)(alphas, samples)
|
||||
return grads.reshape(np.shape(a))
|
||||
|
||||
def _gamma_impl(raw_key, a, *, prng_impl, log_space, use_vmap=False):
|
||||
a_shape = jnp.shape(a)
|
||||
def _gamma_impl(key, a, *, log_space, use_vmap=False):
|
||||
# split key to match the shape of a
|
||||
key_ndim = len(raw_key.shape) - len(prng_impl.key_shape)
|
||||
key = raw_key.reshape((-1,) + prng_impl.key_shape)
|
||||
key = vmap(prng_impl.split, in_axes=(0, None))(key, prod(a_shape[key_ndim:]))
|
||||
keys = key.reshape((-1,) + prng_impl.key_shape)
|
||||
keys = prng.PRNGKeyArray(prng_impl, keys)
|
||||
alphas = jnp.reshape(a, -1)
|
||||
a_shape = jnp.shape(a)
|
||||
split_count = prod(a_shape[key.ndim:])
|
||||
keys = key.flatten()
|
||||
keys = vmap(_split, in_axes=(0, None))(keys, split_count)
|
||||
keys = keys.flatten()
|
||||
alphas = a.flatten()
|
||||
|
||||
if use_vmap:
|
||||
samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
|
||||
else:
|
||||
samples = lax.map(lambda args: _gamma_one(*args, log_space=log_space), (keys, alphas))
|
||||
samples = lax.map(
|
||||
lambda args: _gamma_one(*args, log_space=log_space), (keys, alphas))
|
||||
|
||||
return jnp.reshape(samples, a_shape)
|
||||
|
||||
def _gamma_batching_rule(batched_args, batch_dims, *, prng_impl, log_space):
|
||||
k, a = batched_args
|
||||
bk, ba = batch_dims
|
||||
size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
|
||||
k = batching.bdim_at_front(k, bk, size)
|
||||
a = batching.bdim_at_front(a, ba, size)
|
||||
return random_gamma_p.bind(k, a, prng_impl=prng_impl, log_space=log_space), 0
|
||||
def _gamma_batching_rule(batched_args, batch_dims, *, log_space):
|
||||
k, a = batched_args
|
||||
bk, ba = batch_dims
|
||||
size = next(
|
||||
t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
|
||||
k = batching.bdim_at_front(k, bk, size)
|
||||
a = batching.bdim_at_front(a, ba, size)
|
||||
return random_gamma_p.bind(k, a, log_space=log_space), 0
|
||||
|
||||
random_gamma_p = core.Primitive('random_gamma')
|
||||
random_gamma_p.def_impl(_gamma_impl)
|
||||
random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
|
||||
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
|
||||
ad.defjvp2(
|
||||
random_gamma_p, None,
|
||||
lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
|
||||
mlir.register_lowering(random_gamma_p, mlir.lower_fun(
|
||||
partial(_gamma_impl, use_vmap=True),
|
||||
multiple_results=False))
|
||||
@ -1108,7 +1117,7 @@ def _gamma(key, a, shape, dtype, log_space=False):
|
||||
a = lax.convert_element_type(a, dtype)
|
||||
if np.shape(a) != shape:
|
||||
a = jnp.broadcast_to(a, shape)
|
||||
return random_gamma_p.bind(key.unsafe_raw_array(), a, prng_impl=key.impl, log_space=log_space)
|
||||
return random_gamma_p.bind(key, a, log_space=log_space)
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(2, 3, 4), inline=True)
|
||||
@ -1217,10 +1226,13 @@ def poisson(key: KeyArray,
|
||||
``shape is not None, or else by ``lam.shape``.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
if key.impl is not prng.threefry_prng_impl:
|
||||
# TODO(frostig): generalize underlying poisson implementation and
|
||||
# remove this check (and use of core.get_aval)
|
||||
key_impl = core.get_aval(key).dtype.impl
|
||||
if key_impl is not prng.threefry_prng_impl:
|
||||
raise NotImplementedError(
|
||||
'`poisson` is only implemented for the threefry2x32 RNG, '
|
||||
f'not {key.impl}')
|
||||
f'not {key_impl}')
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if shape is not None:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
@ -1619,6 +1631,7 @@ def orthogonal(
|
||||
Returns:
|
||||
A random array of shape `(*shape, n, n)` and specified dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
_check_shape("orthogonal", shape)
|
||||
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
|
||||
z = normal(key, (*shape, n, n), dtype)
|
||||
@ -1644,6 +1657,7 @@ def generalized_normal(
|
||||
Returns:
|
||||
A random array with the specified shape and dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
_check_shape("generalized_normal", shape)
|
||||
keys = split(key)
|
||||
g = gamma(keys[0], 1/p, shape, dtype)
|
||||
@ -1672,9 +1686,10 @@ def ball(
|
||||
Returns:
|
||||
A random array of shape `(*shape, d)` and specified dtype.
|
||||
"""
|
||||
key, _ = _check_prng_key(key)
|
||||
_check_shape("ball", shape)
|
||||
d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()")
|
||||
keys = split(key)
|
||||
g = generalized_normal(keys[0], p, (*shape, d), dtype)
|
||||
e = exponential(keys[1], shape, dtype)
|
||||
k1, k2 = split(key)
|
||||
g = generalized_normal(k1, p, (*shape, d), dtype)
|
||||
e = exponential(k2, shape, dtype)
|
||||
return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None]
|
||||
|
@ -160,7 +160,15 @@ def count_device_put():
|
||||
|
||||
def device_put_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return device_put(*args, **kwargs)
|
||||
# device_put handlers might call `dispatch.device_put` (e.g. on an
|
||||
# underlying payload or several). We only want to count these
|
||||
# recursive puts once, so we skip counting more than the outermost
|
||||
# one in such a call stack.
|
||||
dispatch.device_put = device_put
|
||||
try:
|
||||
return device_put(*args, **kwargs)
|
||||
finally:
|
||||
dispatch.device_put = device_put_and_count
|
||||
|
||||
dispatch.device_put = device_put_and_count
|
||||
try:
|
||||
|
25
jax/core.py
25
jax/core.py
@ -1190,8 +1190,25 @@ def concrete_or_error(force: Any, val: Any, context=""):
|
||||
|
||||
|
||||
# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
|
||||
|
||||
custom_eltypes: Set[Any] = set()
|
||||
|
||||
# TODO(frostig): update inliners of the four functions below to call them
|
||||
def has_custom_eltype(x: Any):
|
||||
return aval_has_custom_eltype(get_aval(x))
|
||||
|
||||
def eltype(x: Any):
|
||||
return aval_eltype(get_aval(x))
|
||||
|
||||
def aval_has_custom_eltype(aval: UnshapedArray):
|
||||
return is_custom_eltype(aval.dtype)
|
||||
|
||||
def aval_eltype(aval: UnshapedArray):
|
||||
return aval.dtype
|
||||
|
||||
def is_custom_eltype(eltype):
|
||||
return type(eltype) in custom_eltypes
|
||||
|
||||
def _short_dtype_name(dtype) -> str:
|
||||
if type(dtype) in custom_eltypes:
|
||||
return str(dtype)
|
||||
@ -1404,8 +1421,14 @@ class ConcreteArray(ShapedArray):
|
||||
_float = concretization_function_error(float, True)
|
||||
_complex = concretization_function_error(complex, True)
|
||||
|
||||
# TODO(frostig,mattjj): rename to primal_eltype_to_tangent_eltype
|
||||
def primal_dtype_to_tangent_dtype(primal_dtype):
|
||||
if not dtypes.issubdtype(primal_dtype, np.inexact):
|
||||
# TODO(frostig,mattjj): determines that all custom eltypes have
|
||||
# float0 tangent type, which works fine for all our current custom
|
||||
# eltype applications. We may some day want to delegate this
|
||||
# decision to the eltype.
|
||||
if (type(primal_dtype) in custom_eltypes or
|
||||
not dtypes.issubdtype(primal_dtype, np.inexact)):
|
||||
return dtypes.float0
|
||||
else:
|
||||
return primal_dtype
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Primitives with limited support for jax2tf
|
||||
|
||||
*Last generated on (YYYY-MM-DD): 2022-08-10*
|
||||
*Last generated on (YYYY-MM-DD): 2022-08-17*
|
||||
|
||||
This document summarizes known limitations of the jax2tf conversion.
|
||||
There are several kinds of limitations.
|
||||
@ -136,6 +136,7 @@ with jax2tf. The following table lists that cases when this does not quite hold:
|
||||
| max | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| min | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
|
||||
| random_split | Returns JAX key arrays, so compare underlying base array | all | cpu, gpu, tpu | compiled, eager, graph |
|
||||
| reduce_window_add | Numeric comparison disabled: Large deviations on TPU for enable_xla=False | float16, float32 | tpu | compiled, eager, graph |
|
||||
| sort | Numeric comparison disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph |
|
||||
| svd | custom numeric comparison when compute_uv on CPU/GPU | all | cpu, gpu | compiled, eager, graph |
|
||||
|
@ -698,6 +698,7 @@ def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
|
||||
|
||||
def _convert_jax_impl(jax_impl: Callable, *,
|
||||
multiple_results=True,
|
||||
with_physical_avals=False,
|
||||
extra_name_stack: Optional[str] = None) -> Callable:
|
||||
"""Convert the JAX implementation of a primitive.
|
||||
|
||||
@ -718,6 +719,10 @@ def _convert_jax_impl(jax_impl: Callable, *,
|
||||
_out_aval: core.ShapedArray,
|
||||
**kwargs) -> Sequence[TfVal]:
|
||||
|
||||
if with_physical_avals:
|
||||
_in_avals = map(_jax_physical_aval, _in_avals)
|
||||
_out_aval = _jax_physical_aval(_out_aval)
|
||||
|
||||
# We wrap the jax_impl under _interpret_fun to abstract the TF values
|
||||
# from jax_impl and turn them into JAX abstract values.
|
||||
def jax_impl_jax_args(*jax_args):
|
||||
@ -760,8 +765,31 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal,
|
||||
return tuple(v for v, _ in out_with_avals)
|
||||
|
||||
|
||||
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
|
||||
"""Converts JAX avals from logical to physical, if relevant.
|
||||
|
||||
JAX might have avals whose logical vs physical shape/dtype may
|
||||
differ, and only the physical view is expected to possibly
|
||||
relate to TF. TF impl rules should operate on the physical form.
|
||||
|
||||
A JAX logical aval might even correspond, in principle, to several
|
||||
physical avals, but we don't support those here. Instead we assert
|
||||
there is only one and return it.
|
||||
"""
|
||||
if type(aval.dtype) in core.custom_eltypes:
|
||||
aval, = aval.dtype.physical_avals(aval)
|
||||
return aval
|
||||
return aval
|
||||
|
||||
def _jax_physical_dtype(dtype):
|
||||
# assuming () is a fine stand-in shape
|
||||
return _jax_physical_aval(core.ShapedArray((), dtype)).dtype
|
||||
|
||||
|
||||
def _aval_to_tf_shape(aval: core.ShapedArray) -> Tuple[Optional[int], ...]:
|
||||
|
||||
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
|
||||
aval = _jax_physical_aval(aval)
|
||||
return tuple(map(lambda d: None if shape_poly.is_poly_dim(d) else d,
|
||||
aval.shape)) # type: ignore[attr-defined]
|
||||
|
||||
@ -771,6 +799,12 @@ _tf_np_dtype_for_float0 = np.int32
|
||||
def _to_tf_dtype(jax_dtype):
|
||||
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
|
||||
# due to float0 and 64-bit behavior.
|
||||
try:
|
||||
jax_dtype = _jax_physical_dtype(jax_dtype)
|
||||
except TypeError:
|
||||
# `jax_dtype` isn't actually a valid jax dtype (e.g. it is
|
||||
# tf.float32), so there is no physical dtype anyway
|
||||
pass
|
||||
if jax_dtype == dtypes.float0:
|
||||
jax_dtype = _tf_np_dtype_for_float0
|
||||
return tf.dtypes.as_dtype(jax_dtype)
|
||||
@ -834,9 +868,13 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
|
||||
return tf_val, jax_dtype
|
||||
|
||||
|
||||
def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
|
||||
# TODO(frostig,mattjj): rename dtype argument to eltype, for now just
|
||||
# being consistent.
|
||||
def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
|
||||
assert all(map(lambda x: x is not None, shape)), (
|
||||
f"Argument shape should be a valid JAX shape but got {shape}")
|
||||
if dtype is not None:
|
||||
shape = _jax_physical_aval(core.ShapedArray(shape, dtype)).shape
|
||||
dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
|
||||
eval_shape, dim_avals = shape_poly.get_shape_evaluator(dim_vars, shape)
|
||||
shape_values, _ = util.unzip2(_interpret_fun(lu.wrap_init(eval_shape),
|
||||
@ -860,11 +898,13 @@ def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize
|
||||
class TensorFlowTracer(core.Tracer):
|
||||
"""Tracer class that boxes a TF value and a JAX abstract value.
|
||||
|
||||
In addition to the TF value we carry the JAX abstract value because there is
|
||||
one case when it cannot be recovered from the value: when we are converting
|
||||
with polymorphic shapes, in which case the shape of the value may have
|
||||
dimensions set to `None`, which the JAX abstract value may contain more
|
||||
precise information.
|
||||
In addition to the TF value we carry the JAX abstract value because
|
||||
there are some cases when it cannot be recovered from the value:
|
||||
when we are converting with polymorphic shapes or when the JAX aval
|
||||
has a custom element type. In these cases the shape of the value may
|
||||
have dimensions set to `None`, or it may only correspond to the JAX
|
||||
"physical" (TF/lowering-compatible) shape, so the JAX abstract value
|
||||
may contain more precise information.
|
||||
|
||||
When the value has a partially-known shape, the dimensions marked as `None`
|
||||
must correspond to non-constant dimensions in the abstract value.
|
||||
@ -879,32 +919,34 @@ class TensorFlowTracer(core.Tracer):
|
||||
aval: core.AbstractValue):
|
||||
self._trace = trace
|
||||
self._aval = aval
|
||||
phys_aval = _jax_physical_aval(self._aval) # type: ignore[arg-type]
|
||||
|
||||
if isinstance(val, (tf.Tensor, tf.Variable)):
|
||||
val_shape = val.shape
|
||||
|
||||
if config.jax_enable_checks:
|
||||
assert len(self._aval.shape) == len(val_shape), f"_aval.shape={self._aval.shape} different rank than val_shape={val_shape}"
|
||||
assert len(phys_aval.shape) == len(val_shape), f"_aval.shape={phys_aval.shape} different rank than val_shape={val_shape}"
|
||||
# To compare types, we must handle float0 in JAX and x64 in TF
|
||||
if self._aval.dtype == dtypes.float0:
|
||||
assert _to_tf_dtype(self._aval.dtype) == val.dtype, f"expected {self._aval.dtype} == {val.dtype}"
|
||||
if phys_aval.dtype == dtypes.float0:
|
||||
assert _to_tf_dtype(phys_aval.dtype) == val.dtype, f"expected {phys_aval.dtype} == {val.dtype}"
|
||||
else:
|
||||
assert self._aval.dtype == _to_jax_dtype(val.dtype), f"expected {self._aval.dtype} == {val.dtype}"
|
||||
assert phys_aval.dtype == _to_jax_dtype(val.dtype), f"expected {phys_aval.dtype} == {val.dtype}"
|
||||
|
||||
for aval_dim, val_dim in zip(self._aval.shape, val_shape): # type: ignore[attr-defined]
|
||||
for aval_dim, val_dim in zip(phys_aval.shape, val_shape): # type: ignore[attr-defined]
|
||||
if val_dim is None:
|
||||
assert shape_poly.is_poly_dim(aval_dim), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
|
||||
assert shape_poly.is_poly_dim(aval_dim), f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined]
|
||||
elif not shape_poly.is_poly_dim(aval_dim):
|
||||
assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
|
||||
assert aval_dim == val_dim, f"expected {phys_aval.shape} == {val_shape}" # type: ignore[attr-defined]
|
||||
else:
|
||||
# We have a TF value with known shape, and the abstract shape is a shape variable.
|
||||
try:
|
||||
aval_int = int(_eval_shape([aval_dim])) # type: ignore
|
||||
except (TypeError, KeyError):
|
||||
continue
|
||||
assert aval_int == val_dim, f"expected {self._aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore
|
||||
assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." # type: ignore
|
||||
|
||||
self.val = _tfval_to_tensor_jax_dtype(val,
|
||||
self._aval.dtype,
|
||||
phys_aval.dtype,
|
||||
memoize_constants=True)[0] # type: ignore[attr-defined]
|
||||
|
||||
@property
|
||||
@ -1624,7 +1666,7 @@ tf_impl[lax.bitcast_convert_type_p] = _bitcast_convert_type
|
||||
|
||||
def _clamp(minval, operand, maxval, *, _in_avals, _out_aval):
|
||||
# The below permits mirroring the behavior of JAX when maxval < minval
|
||||
op_shape_tf_val = _eval_shape(_in_avals[1].shape)
|
||||
op_shape_tf_val = _eval_shape(_in_avals[1].shape, _in_avals[1].dtype)
|
||||
maxval = tf.broadcast_to(maxval, op_shape_tf_val)
|
||||
minval = tf.math.minimum(tf.broadcast_to(minval, op_shape_tf_val), maxval)
|
||||
return tf.clip_by_value(operand, minval, maxval)
|
||||
@ -1779,11 +1821,12 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
|
||||
# bcast_dims must be strictly increasing.
|
||||
# len(bcast_dims) == len(operand.shape)
|
||||
op_shape = _in_avals[0].shape
|
||||
dtype = _in_avals[0].dtype
|
||||
add_1s_shape = [1] * len(shape)
|
||||
for i, broadcast_dim_i in enumerate(broadcast_dimensions):
|
||||
add_1s_shape[broadcast_dim_i] = op_shape[i]
|
||||
with_1s = tf.reshape(operand, _eval_shape(add_1s_shape))
|
||||
return tf.broadcast_to(with_1s, _eval_shape(shape))
|
||||
with_1s = tf.reshape(operand, _eval_shape(add_1s_shape, dtype=dtype))
|
||||
return tf.broadcast_to(with_1s, _eval_shape(shape, dtype=dtype))
|
||||
|
||||
|
||||
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
|
||||
@ -1798,20 +1841,21 @@ def _empty(*, eltype):
|
||||
tf_impl[lax_internal.empty_p] = _empty
|
||||
|
||||
|
||||
def _reshape(operand, *, new_sizes, dimensions):
|
||||
def _reshape(operand, *, new_sizes, dimensions, _in_avals, _out_aval):
|
||||
if dimensions is None:
|
||||
dimensions = tf.range(tf.rank(operand))
|
||||
new_sizes_tf = _eval_shape(new_sizes)
|
||||
new_sizes_tf = _eval_shape(new_sizes, _in_avals[0].dtype)
|
||||
return tf.reshape(tf.transpose(operand, dimensions), new_sizes_tf)
|
||||
|
||||
|
||||
tf_impl[lax.reshape_p] = _reshape
|
||||
tf_impl_with_avals[lax.reshape_p] = _reshape
|
||||
|
||||
|
||||
def _squeeze(operand, *, dimensions, _in_avals, _out_aval):
|
||||
op_shape = _in_avals[0].shape
|
||||
op_aval = _jax_physical_aval(_in_avals[0])
|
||||
op_shape = op_aval.shape
|
||||
new_shape = tuple(d for i, d in enumerate(op_shape) if i not in dimensions)
|
||||
new_shape_tf = _eval_shape(new_shape)
|
||||
new_shape_tf = _eval_shape(new_shape, op_aval.dtype)
|
||||
return tf.reshape(operand, new_shape_tf)
|
||||
|
||||
|
||||
@ -2242,6 +2286,82 @@ def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
|
||||
tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
|
||||
|
||||
|
||||
def _random_seed_impl(seeds: TfVal, *, impl, _in_avals, _out_aval):
|
||||
|
||||
def impl_wrapper(seeds: TfVal, *, impl):
|
||||
return jax._src.prng.random_seed_impl_base(seeds, impl=impl)
|
||||
|
||||
converted_impl = _convert_jax_impl(
|
||||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||||
extra_name_stack="random_seed")
|
||||
return converted_impl(
|
||||
seeds, impl=impl, _in_avals=_in_avals, _out_aval=_out_aval)
|
||||
|
||||
tf_impl_with_avals[jax._src.prng.random_seed_p] = _random_seed_impl
|
||||
|
||||
|
||||
def _random_split_impl(keys: TfVal, *, count, _in_avals, _out_aval):
|
||||
keys_aval, = _in_avals
|
||||
|
||||
def impl_wrapper(keys: TfVal, *, count):
|
||||
return jax._src.prng.random_split_impl_base(
|
||||
keys_aval.dtype.impl, keys, keys_aval.ndim, count=count)
|
||||
|
||||
converted_impl = _convert_jax_impl(
|
||||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||||
extra_name_stack="random_split")
|
||||
return converted_impl(
|
||||
keys, count=count, _in_avals=_in_avals, _out_aval=_out_aval)
|
||||
|
||||
tf_impl_with_avals[jax._src.prng.random_split_p] = _random_split_impl
|
||||
|
||||
|
||||
def _random_fold_in_impl(keys: TfVal, msgs: TfVal, *, _in_avals, _out_aval):
|
||||
keys_aval, _ = _in_avals
|
||||
|
||||
def impl_wrapper(keys: TfVal, msgs: TfVal):
|
||||
return jax._src.prng.random_fold_in_impl_base(
|
||||
keys_aval.dtype.impl, keys, msgs, keys_aval.shape)
|
||||
|
||||
converted_impl = _convert_jax_impl(
|
||||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||||
extra_name_stack="random_fold_in")
|
||||
return converted_impl(
|
||||
keys, msgs, _in_avals=_in_avals, _out_aval=_out_aval)
|
||||
|
||||
tf_impl_with_avals[jax._src.prng.random_fold_in_p] = _random_fold_in_impl
|
||||
|
||||
|
||||
def _random_bits_impl(keys: TfVal, *, bit_width, shape, _in_avals, _out_aval):
|
||||
keys_aval, = _in_avals
|
||||
|
||||
def impl_wrapper(keys: TfVal, **kwargs):
|
||||
return jax._src.prng.random_bits_impl_base(
|
||||
keys_aval.dtype.impl, keys, keys_aval.ndim,
|
||||
bit_width=bit_width, shape=shape)
|
||||
|
||||
converted_impl = _convert_jax_impl(
|
||||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||||
extra_name_stack="random_bits")
|
||||
return converted_impl(keys, bit_width=bit_width, shape=shape,
|
||||
_in_avals=_in_avals, _out_aval=_out_aval)
|
||||
|
||||
tf_impl_with_avals[jax._src.prng.random_bits_p] = _random_bits_impl
|
||||
|
||||
|
||||
def _random_wrap_impl(base_arr: TfVal, *, impl, _in_avals, _out_aval):
|
||||
return base_arr
|
||||
|
||||
tf_impl_with_avals[jax._src.prng.random_wrap_p] = _random_wrap_impl
|
||||
|
||||
|
||||
def _random_unwrap_impl(keys: TfVal, *, _in_avals, _out_aval):
|
||||
return keys
|
||||
|
||||
tf_impl_with_avals[jax._src.prng.random_unwrap_p] = _random_unwrap_impl
|
||||
|
||||
|
||||
|
||||
def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
|
||||
res = _convert_jax_impl(
|
||||
partial(jax._src.prng._threefry2x32_lowering, use_rolled_loops=False),
|
||||
@ -2329,7 +2449,7 @@ def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shap
|
||||
|
||||
start_indices = _maybe_cast_to_int64(start_indices)
|
||||
proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
|
||||
slice_sizes_tf = _eval_shape(slice_sizes)
|
||||
slice_sizes_tf = _eval_shape(slice_sizes, _in_avals[0].dtype)
|
||||
out = tfxla.gather(operand, start_indices, proto, slice_sizes_tf,
|
||||
indices_are_sorted)
|
||||
out.set_shape(_aval_to_tf_shape(_out_aval))
|
||||
@ -2362,7 +2482,7 @@ def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray):
|
||||
start_indices = _maybe_cast_to_int64(tf.stack(start_indices))
|
||||
slice_sizes_tf = _eval_shape(slice_sizes)
|
||||
slice_sizes_tf = _eval_shape(slice_sizes, dtype=_in_avals[0].dtype)
|
||||
|
||||
res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes_tf)
|
||||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||||
@ -2523,7 +2643,7 @@ def _batched_cond_while(*args: TfVal, cond_nconsts: int,
|
||||
def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal:
|
||||
pred_b_bcast = _broadcast_in_dim(
|
||||
pred_b,
|
||||
shape=c_aval.shape, # a JAX shape
|
||||
shape=_jax_physical_aval(c_aval).shape, # a JAX shape
|
||||
broadcast_dimensions=list(range(len(pred_b.shape))),
|
||||
_in_avals=cond_jaxpr.out_avals,
|
||||
_out_aval=core.ShapedArray(c_aval.shape, np.bool_))
|
||||
|
@ -16,6 +16,7 @@
|
||||
import itertools
|
||||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
@ -130,7 +131,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
|
||||
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
|
||||
"lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
|
||||
"random_categorical", "random_split", "random_uniform", "random_randint",
|
||||
"random_categorical", "random_uniform", "random_randint",
|
||||
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
|
||||
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
|
||||
"reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
|
||||
@ -153,6 +154,18 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
dtypes=[np.complex64, np.complex128],
|
||||
custom_assert=custom_assert)
|
||||
|
||||
@classmethod
|
||||
def random_seed(cls, handess: primitive_harness.Harness):
|
||||
return [custom_random_keys_output()]
|
||||
|
||||
@classmethod
|
||||
def random_split(cls, handess: primitive_harness.Harness):
|
||||
return [custom_random_keys_output()]
|
||||
|
||||
@classmethod
|
||||
def random_fold_in(cls, handess: primitive_harness.Harness):
|
||||
return [custom_random_keys_output()]
|
||||
|
||||
@classmethod
|
||||
def acos(cls, harness: primitive_harness.Harness):
|
||||
return [
|
||||
@ -1270,6 +1283,24 @@ def custom_numeric(
|
||||
enabled=enabled,
|
||||
tol=tol)
|
||||
|
||||
def custom_random_keys_output():
|
||||
def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
|
||||
# TODO(frostig): Don't need this conditional once we always
|
||||
# enable_custom_prng. We can even assert the isinstance instead.
|
||||
def unwrap_keys(keys):
|
||||
if isinstance(keys, jax.random.KeyArray):
|
||||
return jax._src.prng.random_unwrap(keys)
|
||||
else:
|
||||
return keys
|
||||
|
||||
tst.assertAllClose(unwrap_keys(result_jax), result_tf,
|
||||
atol=tol, rtol=tol, err_msg=err_msg)
|
||||
|
||||
return custom_numeric(
|
||||
description="Returns JAX key arrays, so compare underlying base array",
|
||||
modes=("eager", "graph", "compiled"),
|
||||
custom_assert=custom_assert)
|
||||
|
||||
|
||||
def missing_tf_kernel(*,
|
||||
description="op not defined for dtype",
|
||||
|
@ -195,6 +195,9 @@ _constant_handlers : Dict[type, ConstantHandler] = {}
|
||||
def register_constant_handler(type_: type, handler_fun: ConstantHandler):
|
||||
_constant_handlers[type_] = handler_fun
|
||||
|
||||
def get_constant_handler(type_: type) -> ConstantHandler:
|
||||
return _constant_handlers[type_]
|
||||
|
||||
def ir_constants(val: Any,
|
||||
canonicalize_types: bool = True) -> Sequence[ir.Value]:
|
||||
"""Translate a Python `val` to an IR constant, canonicalizing its dtype.
|
||||
@ -1096,6 +1099,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
|
||||
else:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
|
||||
|
||||
out, tokens = jaxpr_subcomp(
|
||||
ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts),
|
||||
|
@ -570,8 +570,11 @@ local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRes
|
||||
|
||||
def sda_array_result_handler(aval: ShapedArray, sharding, indices):
|
||||
sharding_spec = _get_sharding_specs([sharding], [aval])[0]
|
||||
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
|
||||
indices)
|
||||
if type(aval.dtype) in core.custom_eltypes:
|
||||
return aval.dtype.sharded_result_handler(aval, sharding, indices)
|
||||
else:
|
||||
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
|
||||
indices)
|
||||
local_result_handlers[(ShapedArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
|
||||
local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler
|
||||
|
||||
|
@ -119,13 +119,25 @@ Here is a short summary:
|
||||
NOTE: RNGs are currently identical across shardings because the random value
|
||||
is first materialized replicated on each device and then the slice that each
|
||||
device needs is later sliced out.
|
||||
|
||||
"""
|
||||
|
||||
# TODO(frostig): replace with KeyArray from jax._src.random once we
|
||||
# always enable_custom_prng
|
||||
from jax._src.prng import PRNGKeyArray
|
||||
KeyArray = PRNGKeyArray
|
||||
from jax._src.prng import PRNGKeyArray as _PRNGKeyArray
|
||||
# TODO(frostig): remove this typechecking workaround. Our move away
|
||||
# from PRNGKeyArray as a pytree led to Python typechecker breakages in
|
||||
# several downstream annotations (e.g. annotations in jax-dependent
|
||||
# libraries that are violated by their callers). It may be that the
|
||||
# pytree registration decorator invalidated the checks. This will be
|
||||
# easier to handle after we always enable_custom_prng.
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
PRNGKeyArray = typing.Any
|
||||
KeyArray = typing.Any
|
||||
else:
|
||||
# TODO(frostig): replace with KeyArray from jax._src.random once we
|
||||
# always enable_custom_prng
|
||||
PRNGKeyArray = _PRNGKeyArray
|
||||
KeyArray = PRNGKeyArray
|
||||
|
||||
|
||||
from jax._src.random import (
|
||||
PRNGKey as PRNGKey,
|
||||
|
@ -764,16 +764,11 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
def test_omnistaging(self):
|
||||
# See https://github.com/google/jax/issues/5206
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def _prng_key_as_array(key):
|
||||
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def _array_as_prng_key(arr):
|
||||
# TODO(frostig): remove `wrap` once we always enable_custom_prng
|
||||
def wrap(arr):
|
||||
arr = np.array(arr, dtype=np.uint32)
|
||||
if config.jax_enable_custom_prng:
|
||||
return jax._src.prng.PRNGKeyArray(
|
||||
jax._src.prng.threefry_prng_impl, arr)
|
||||
return jax._src.prng.random_wrap(arr, impl=jax.random.default_prng_impl())
|
||||
else:
|
||||
return arr
|
||||
|
||||
@ -784,10 +779,11 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
key_list[0] = key
|
||||
return jax.random.normal(subkey, ())
|
||||
|
||||
key_list[0] = _array_as_prng_key([2384771982, 3928867769])
|
||||
key_list[0] = wrap([2384771982, 3928867769])
|
||||
init()
|
||||
self.jit(init)()
|
||||
self.assertIsInstance(_prng_key_as_array(key_list[0]), core.Tracer)
|
||||
self.assertIsInstance(key_list[0], core.Tracer)
|
||||
del key_list[0]
|
||||
|
||||
def test_jit_wrapped_attributes(self):
|
||||
def f(x: int) -> int:
|
||||
|
@ -919,6 +919,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
_ = hessian(f)(R) # don't crash on UnshapedArray
|
||||
|
||||
def testIssue489(self):
|
||||
# https://github.com/google/jax/issues/489
|
||||
def f(key):
|
||||
def body_fn(uk):
|
||||
key = uk[1]
|
||||
|
@ -37,7 +37,9 @@ from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax import vmap
|
||||
from jax.interpreters import xla
|
||||
|
||||
import jax._src.random
|
||||
from jax._src import prng as prng_internal
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -52,6 +54,11 @@ def _prng_key_as_array(key):
|
||||
# TODO(frostig): remove once we upgrade to always enable_custom_prng
|
||||
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
|
||||
|
||||
def _maybe_unwrap(key):
|
||||
# TODO(frostig): remove once we upgrade to always enable_custom_prng
|
||||
unwrap = jax._src.prng.random_unwrap
|
||||
return unwrap(key) if config.jax_enable_custom_prng else key
|
||||
|
||||
|
||||
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
|
||||
('rbg', prng.rbg_prng_impl),
|
||||
@ -226,22 +233,28 @@ class PrngTest(jtu.JaxTestCase):
|
||||
|
||||
def testRngRandomBits(self):
|
||||
# Test specific outputs to ensure consistent random values between JAX versions.
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def random_bits(key, *args):
|
||||
key, _ = jax._src.random._check_prng_key(key)
|
||||
return jax._src.random._random_bits(key, *args)
|
||||
|
||||
key = random.PRNGKey(1701)
|
||||
|
||||
bits8 = jax._src.random._random_bits(key, 8, (3,))
|
||||
bits8 = random_bits(key, 8, (3,))
|
||||
expected8 = np.array([216, 115, 43], dtype=np.uint8)
|
||||
self.assertArraysEqual(bits8, expected8)
|
||||
|
||||
bits16 = jax._src.random._random_bits(key, 16, (3,))
|
||||
bits16 = random_bits(key, 16, (3,))
|
||||
expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
|
||||
self.assertArraysEqual(bits16, expected16)
|
||||
|
||||
bits32 = jax._src.random._random_bits(key, 32, (3,))
|
||||
bits32 = random_bits(key, 32, (3,))
|
||||
expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
|
||||
self.assertArraysEqual(bits32, expected32)
|
||||
|
||||
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
|
||||
bits64 = jax._src.random._random_bits(key, 64, (3,))
|
||||
bits64 = random_bits(key, 64, (3,))
|
||||
if config.x64_enabled:
|
||||
expected64 = np.array([3982329540505020460, 16822122385914693683,
|
||||
7882654074788531506], dtype=np.uint64)
|
||||
@ -257,23 +270,28 @@ class PrngTest(jtu.JaxTestCase):
|
||||
# on every PRNG implementation. Instead of values, only checks
|
||||
# that shapes/dtypes are as expected.
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def random_bits(key, *args):
|
||||
key, _ = jax._src.random._check_prng_key(key)
|
||||
return jax._src.random._random_bits(key, *args)
|
||||
|
||||
with jax.default_prng_impl(prng_name):
|
||||
key = random.PRNGKey(1701)
|
||||
|
||||
bits8 = jax._src.random._random_bits(key, 8, (3,))
|
||||
bits8 = random_bits(key, 8, (3,))
|
||||
self.assertEqual(bits8.shape, (3,))
|
||||
self.assertEqual(bits8.dtype, np.dtype('uint8'))
|
||||
|
||||
bits16 = jax._src.random._random_bits(key, 16, (3,))
|
||||
bits16 = random_bits(key, 16, (3,))
|
||||
self.assertEqual(bits16.shape, (3,))
|
||||
self.assertEqual(bits16.dtype, np.dtype('uint16'))
|
||||
|
||||
bits32 = jax._src.random._random_bits(key, 32, (3,))
|
||||
bits32 = random_bits(key, 32, (3,))
|
||||
self.assertEqual(bits32.shape, (3,))
|
||||
self.assertEqual(bits32.dtype, np.dtype('uint32'))
|
||||
|
||||
with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"):
|
||||
bits64 = jax._src.random._random_bits(key, 64, (3,))
|
||||
bits64 = random_bits(key, 64, (3,))
|
||||
expected_dtype = np.dtype('uint64' if config.x64_enabled else 'uint32')
|
||||
self.assertEqual(bits64.shape, (3,))
|
||||
self.assertEqual(bits64.dtype, expected_dtype)
|
||||
@ -281,11 +299,16 @@ class PrngTest(jtu.JaxTestCase):
|
||||
def testRngRandomBitsViewProperty(self):
|
||||
# TODO: add 64-bit if it ever supports this property.
|
||||
# TODO: will this property hold across endian-ness?
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def random_bits(key, *args):
|
||||
key, _ = jax._src.random._check_prng_key(key)
|
||||
return jax._src.random._random_bits(key, *args)
|
||||
|
||||
N = 10
|
||||
key = random.PRNGKey(1701)
|
||||
nbits = [8, 16, 32]
|
||||
rand_bits = [jax._src.random._random_bits(key, n, (N * 64 // n,))
|
||||
for n in nbits]
|
||||
rand_bits = [random_bits(key, n, (N * 64 // n,)) for n in nbits]
|
||||
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
|
||||
assert np.all(rand_bits_32 == rand_bits_32[0])
|
||||
|
||||
@ -430,8 +453,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
key = random.PRNGKey(1701)
|
||||
self.assertEqual(key.shape, ())
|
||||
self.assertEqual(key[None].shape, (1,))
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
|
||||
lambda: key[0])
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*', lambda: key[0])
|
||||
|
||||
def test_key_array_indexing_nd(self):
|
||||
if not config.jax_enable_custom_prng:
|
||||
@ -450,9 +472,9 @@ class PrngTest(jtu.JaxTestCase):
|
||||
(1,) * 6)
|
||||
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
|
||||
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*',
|
||||
lambda: keys[0, 1, 2])
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
|
||||
self.assertRaisesRegex(IndexError, 'Too many indices.*',
|
||||
lambda: keys[0, 1, None, 2])
|
||||
|
||||
|
||||
@ -1405,18 +1427,18 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
jax.eval_shape(f, 0) # doesn't error
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_seed={seed}_type={type}", "seed": seed, "type": type}
|
||||
for type in ["int", "np.array", "jnp.array"]
|
||||
{"testcase_name": f"_seed={seed}_type={type_}", "seed": seed, "type_": type_}
|
||||
for type_ in ["int", "np.array", "jnp.array"]
|
||||
for seed in [-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)]))
|
||||
def test_prng_jit_invariance(self, seed, type):
|
||||
if type == "int" and seed == (1 << 64) - 1:
|
||||
def test_prng_jit_invariance(self, seed, type_):
|
||||
if type_ == "int" and seed == (1 << 64) - 1:
|
||||
self.skipTest("Expected failure: Python int too large.")
|
||||
if not config.x64_enabled and seed > np.iinfo(np.int32).max:
|
||||
self.skipTest("Expected failure: Python int too large.")
|
||||
type = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type]
|
||||
args_maker = lambda: [type(seed)]
|
||||
make_prng = lambda seed: _prng_key_as_array(self.seed_prng(seed))
|
||||
self._CompileAndCheck(make_prng, args_maker)
|
||||
type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_]
|
||||
args_maker = lambda: [type_(seed)]
|
||||
f = lambda s: _maybe_unwrap(self.seed_prng(s))
|
||||
self._CompileAndCheck(f, args_maker)
|
||||
|
||||
def test_prng_errors(self):
|
||||
seed = np.iinfo(np.int64).max + 1
|
||||
@ -1426,7 +1448,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
jax.jit(self.seed_prng)(seed)
|
||||
|
||||
def test_random_split_doesnt_device_put_during_tracing(self):
|
||||
key = _prng_key_as_array(self.seed_prng(1)).block_until_ready()
|
||||
key = self.seed_prng(1).block_until_ready()
|
||||
with jtu.count_device_put() as count:
|
||||
jax.jit(random.split)(key)
|
||||
self.assertEqual(count[0], 1) # 1 for the argument device_put
|
||||
@ -1468,6 +1490,120 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
jax.jit(f).lower()
|
||||
|
||||
|
||||
class KeyArrayTest(jtu.JaxTestCase):
|
||||
# Key arrays involve:
|
||||
# * a Python key array type, backed by an underlying uint32 "base" array,
|
||||
# * an abstract shaped array with key eltype,
|
||||
# * primitives that return or operate on such shaped arrays,
|
||||
# * compiler lowerings,
|
||||
# * a device-side data representation...
|
||||
# Test it all!
|
||||
#
|
||||
# A handful of these tests follow CustomElementTypesTest in
|
||||
# lax_tests.py as an example. If you add a test here (e.g. testing
|
||||
# lowering of an key-eltyped shaped array), consider whether it
|
||||
# might also be a more general test of extended/custom eltypes. If
|
||||
# so, add a corresponding test to to CustomElementTypesTest as well.
|
||||
|
||||
def make_keys(self, *shape):
|
||||
key = prng.seed_with_impl(prng.threefry_prng_impl, 28)
|
||||
return jnp.reshape(jax.random.split(key, np.prod(shape)), shape)
|
||||
|
||||
# -- prng primitives
|
||||
|
||||
def test_random_wrap_vmap(self):
|
||||
f = partial(prng_internal.random_wrap, impl=prng.threefry_prng_impl)
|
||||
base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
|
||||
keys = jax.vmap(f, in_axes=0)(base_arr)
|
||||
self.assertIsInstance(keys, random.KeyArray)
|
||||
self.assertEqual(keys.shape, (3,))
|
||||
keys = jax.vmap(f, in_axes=1)(base_arr.T)
|
||||
self.assertIsInstance(keys, random.KeyArray)
|
||||
self.assertEqual(keys.shape, (3,))
|
||||
|
||||
# -- eltype-polymorphic operations
|
||||
|
||||
def test_scan_jaxpr(self):
|
||||
ks = self.make_keys(3, 4, 5)
|
||||
f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
|
||||
jaxpr = jax.make_jaxpr(f)(ks).jaxpr
|
||||
# { lambda ; a:key<fry>[3,4,5]. let
|
||||
# b:key<fry>[3,5,4] = scan[
|
||||
# jaxpr={ lambda ; c:key<fry>[4,5]. let
|
||||
# d:key<fry>[5,4] = transpose[permutation=(1, 0)] c
|
||||
# in (d,) }
|
||||
# ] a
|
||||
# in (b,) }
|
||||
self.assertLen(jaxpr.invars, 1)
|
||||
a, = jaxpr.invars
|
||||
self.assertIsInstance(a.aval, core.ShapedArray)
|
||||
self.assertEqual(a.aval.shape, (3, 4, 5))
|
||||
self.assertIs(type(a.aval.dtype), jax._src.prng.KeyTy)
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
e, = jaxpr.eqns
|
||||
self.assertLen(e.outvars, 1)
|
||||
b, = e.outvars
|
||||
self.assertIsInstance(b.aval, core.ShapedArray)
|
||||
self.assertEqual(b.aval.shape, (3, 5, 4))
|
||||
self.assertIs(type(b.aval.dtype), jax._src.prng.KeyTy)
|
||||
|
||||
def test_scan_lowering(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
|
||||
_, out = jax.jit(f)(ks) # doesn't crash
|
||||
self.assertIsInstance(out, random.KeyArray)
|
||||
self.assertEqual(out.shape, (3, 4))
|
||||
|
||||
def test_vmap(self):
|
||||
ks = self.make_keys(3, 4, 5)
|
||||
ys = jax.vmap(jax.jit(lambda k: k.T))(ks)
|
||||
self.assertEqual(ys.shape, (3, 5, 4))
|
||||
|
||||
def test_slice(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertEqual(ys.shape, (2, 4))
|
||||
|
||||
def test_dynamic_slice(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x, i: lax.dynamic_slice_in_dim(x, i, 2))(ks, 1)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertEqual(ys.shape, (2, 4))
|
||||
|
||||
def test_transpose(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x: x.T)(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertEqual(ys.shape, (4, 3))
|
||||
|
||||
def test_gather(self):
|
||||
ks = self.make_keys(3, 4)
|
||||
ys = jax.jit(lambda x: x[1])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertEqual(ys.shape, (4,))
|
||||
|
||||
ks = self.make_keys(3, 4, 5)
|
||||
|
||||
ys = jax.jit(lambda x: x[1])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertEqual(ys.shape, (4, 5))
|
||||
|
||||
ys = jax.jit(lambda x: x[1, 2:4])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertEqual(ys.shape, (2, 5))
|
||||
|
||||
ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertEqual(ys.shape, (2,))
|
||||
|
||||
ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
|
||||
self.assertIsInstance(ys, random.KeyArray)
|
||||
self.assertEqual(ys.shape, (3, 2, 1))
|
||||
|
||||
# TODO(frostig,mattjj): more polymorphic primitives tests
|
||||
|
||||
|
||||
threefry_seed = jax._src.prng.threefry_seed
|
||||
threefry_split = jax._src.prng.threefry_split
|
||||
threefry_random_bits = jax._src.prng.threefry_random_bits
|
||||
@ -1500,7 +1636,8 @@ double_threefry_prng_impl = prng.PRNGImpl(
|
||||
seed=_double_threefry_seed,
|
||||
split=_double_threefry_split,
|
||||
random_bits=_double_threefry_random_bits,
|
||||
fold_in=_double_threefry_fold_in)
|
||||
fold_in=_double_threefry_fold_in,
|
||||
tag='fry2')
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng,
|
||||
'custom PRNG tests require config.jax_enable_custom_prng')
|
||||
@ -1514,9 +1651,40 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
self.assertEqual(keys.shape, (10,))
|
||||
|
||||
def test_vmap_fold_in_shape(self):
|
||||
# broadcast with scalar
|
||||
keys = random.split(self.seed_prng(73), 2)
|
||||
msgs = jnp.arange(3)
|
||||
out = vmap(lambda i: random.fold_in(keys[0], i))(msgs)
|
||||
self.assertEqual(out.shape, (3,))
|
||||
out = vmap(lambda k: random.fold_in(k, msgs[0]))(keys)
|
||||
self.assertEqual(out.shape, (2,))
|
||||
out = vmap(random.fold_in, in_axes=(None, 0))(keys[0], msgs)
|
||||
self.assertEqual(out.shape, (3,))
|
||||
out = vmap(random.fold_in, in_axes=(0, None))(keys, msgs[0])
|
||||
self.assertEqual(out.shape, (2,))
|
||||
|
||||
# vmap all
|
||||
msgs = jnp.arange(2)
|
||||
out = vmap(random.fold_in)(keys, msgs)
|
||||
self.assertEqual(out.shape, (2,))
|
||||
|
||||
# nested vmap
|
||||
keys = random.split(self.seed_prng(73), 2 * 3).reshape((2, 3))
|
||||
msgs = jnp.arange(2 * 3).reshape((2, 3))
|
||||
out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys, msgs.T)
|
||||
self.assertEqual(out.shape, (2, 3))
|
||||
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys, msgs.T)
|
||||
self.assertEqual(out.shape, (3, 2))
|
||||
|
||||
def test_vmap_split_mapped_key(self):
|
||||
key = self.seed_prng(73)
|
||||
keys = vmap(lambda i: random.fold_in(key, i))(jnp.arange(3))
|
||||
self.assertEqual(keys.shape, (3,))
|
||||
mapped_keys = random.split(key, num=3)
|
||||
forloop_keys = [random.split(k) for k in mapped_keys]
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
self.assertEqual(vmapped_keys.shape, (3, 2))
|
||||
for fk, vk in zip(forloop_keys, vmapped_keys):
|
||||
self.assertArraysEqual(fk.unsafe_raw_array(),
|
||||
vk.unsafe_raw_array())
|
||||
|
||||
def test_cannot_add(self):
|
||||
key = self.seed_prng(73)
|
||||
@ -1528,24 +1696,27 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
"https://github.com/numpy/numpy/issues/19305")
|
||||
def test_grad_of_prng_key(self):
|
||||
key = self.seed_prng(73)
|
||||
jax.grad(lambda x: 1., allow_int=True)(key) # does not crash
|
||||
with self.assertRaisesRegex(TypeError, 'input element type key<fry2>'):
|
||||
jax.grad(lambda x: 1., allow_int=True)(key)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng,
|
||||
'custom PRNG tests require config.jax_enable_custom_prng')
|
||||
|
||||
# TODO(frostig): remove `with_config` we always enable_custom_prng
|
||||
@jtu.with_config(jax_default_prng_impl='rbg')
|
||||
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
def seed_prng(self, seed):
|
||||
return random.rbg_key(seed)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
|
||||
def test_split_shape(self):
|
||||
key = self.seed_prng(73)
|
||||
keys = random.split(key, 10)
|
||||
self.assertEqual(keys.shape, (10,))
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
|
||||
def test_vmap_fold_in_shape(self):
|
||||
key = self.seed_prng(73)
|
||||
keys = vmap(lambda i: random.fold_in(key, i))(jnp.arange(3))
|
||||
self.assertEqual(keys.shape, (3,))
|
||||
LaxRandomWithCustomPRNGTest.test_vmap_fold_in_shape(self)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
|
||||
def test_vmap_split_not_mapped_key(self):
|
||||
key = self.seed_prng(73)
|
||||
single_split_key = random.split(key)
|
||||
@ -1555,6 +1726,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
self.assertArraysEqual(vk.unsafe_raw_array(),
|
||||
single_split_key.unsafe_raw_array())
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
|
||||
def test_vmap_split_mapped_key(self):
|
||||
key = self.seed_prng(73)
|
||||
mapped_keys = random.split(key, num=3)
|
||||
@ -1574,6 +1746,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
self.assertEqual(rand_nums.shape, (3,))
|
||||
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
|
||||
def test_cannot_add(self):
|
||||
key = self.seed_prng(73)
|
||||
self.assertRaisesRegex(
|
||||
@ -1582,9 +1755,11 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
|
||||
@skipIf(np.__version__ == "1.21.0",
|
||||
"https://github.com/numpy/numpy/issues/19305")
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key arrays')
|
||||
def test_grad_of_prng_key(self):
|
||||
key = self.seed_prng(73)
|
||||
jax.grad(lambda x: 1., allow_int=True)(key) # does not crash
|
||||
with self.assertRaisesRegex(TypeError, 'input element type key<.?rbg>'):
|
||||
jax.grad(lambda x: 1., allow_int=True)(key)
|
||||
|
||||
def test_random_split_doesnt_device_put_during_tracing(self):
|
||||
return # this test doesn't apply to the RBG PRNG
|
||||
@ -1593,16 +1768,19 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
# TODO(mattjj): enable this test if/when RngBitGenerator supports it
|
||||
raise SkipTest('8-bit types not supported with RBG PRNG')
|
||||
|
||||
|
||||
# TODO(frostig): remove `with_config` we always enable_custom_prng
|
||||
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
|
||||
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
|
||||
def seed_prng(self, seed):
|
||||
return prng.seed_with_impl(prng.unsafe_rbg_prng_impl, seed)
|
||||
return random.unsafe_rbg_key(seed)
|
||||
|
||||
def like(keys):
|
||||
return jnp.ones(keys.shape)
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng,
|
||||
'custom PRNG tests require config.jax_enable_custom_prng')
|
||||
class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase):
|
||||
class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
def test_reshape(self):
|
||||
key = random.PRNGKey(123)
|
||||
keys = random.split(key, 4)
|
||||
@ -1619,10 +1797,14 @@ class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.shape, (3,))
|
||||
|
||||
def test_concatenate(self):
|
||||
self.skipTest('jnp.concatenate on key arrays') # TODO(frostig)
|
||||
key = random.PRNGKey(123)
|
||||
keys = random.split(key, 2)
|
||||
out = jnp.concatenate([keys, keys, keys], axis=0)
|
||||
ref = jnp.concatenate([like(keys)] * 3, axis=0)
|
||||
out = jnp.concatenate([keys, keys, keys], axis=0)
|
||||
self.assertEqual(out.shape, ref.shape)
|
||||
self.assertEqual(out.shape, (6,))
|
||||
out = jax.jit(lambda xs: jnp.concatenate(xs, axis=0))([keys, keys, keys])
|
||||
self.assertEqual(out.shape, ref.shape)
|
||||
self.assertEqual(out.shape, (6,))
|
||||
|
||||
@ -1654,15 +1836,19 @@ class JnpWithPRNGKeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.shape, (3,))
|
||||
|
||||
def test_append(self):
|
||||
self.skipTest('jnp.append on key arrays') # TODO(frostig)
|
||||
key = random.PRNGKey(123)
|
||||
out = jnp.append(key, key)
|
||||
ref = jnp.append(like(key), like(key))
|
||||
self.assertEqual(out.shape, ref.shape)
|
||||
self.assertEqual(out.shape, (2,))
|
||||
out_ = jnp.append(out, out)
|
||||
ref_ = jnp.append(like(out), like(out))
|
||||
self.assertEqual(out_.shape, ref_.shape)
|
||||
self.assertEqual(out_.shape, (4,))
|
||||
out1 = jnp.append(out, out)
|
||||
ref1 = jnp.append(like(out), like(out))
|
||||
self.assertEqual(out1.shape, ref1.shape)
|
||||
self.assertEqual(out1.shape, (4,))
|
||||
out2 = jax.jit(jnp.append)(key, key)
|
||||
self.assertEqual(out2.shape, ref.shape)
|
||||
self.assertEqual(out2.shape, (6,))
|
||||
|
||||
def test_ravel(self):
|
||||
key = random.PRNGKey(123)
|
||||
|
Loading…
x
Reference in New Issue
Block a user