mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
sketch: setup for new key array implementation based on eltypes
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
39d54bdbf6
commit
acb5e491ab
@ -108,9 +108,14 @@ def apply_primitive(prim, *args, **params):
|
||||
**params)
|
||||
return compiled_fun(*args)
|
||||
|
||||
# TODO(phawkins): update code referring to xla.apply_primitive to point here.
|
||||
# TODO(phawkins,frostig,mattjj): update code referring to
|
||||
# xla.apply_primitive to point here, or use simple_impl if that's why
|
||||
# it is using apply_primitive to begin with
|
||||
xla.apply_primitive = apply_primitive
|
||||
|
||||
def simple_impl(prim):
|
||||
prim.def_impl(partial(apply_primitive, prim))
|
||||
|
||||
RuntimeToken = Any
|
||||
|
||||
class RuntimeTokenSet(threading.local):
|
||||
|
282
jax/_src/prng.py
282
jax/_src/prng.py
@ -14,7 +14,7 @@
|
||||
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Iterator, NamedTuple, Sequence
|
||||
from typing import Callable, Hashable, Iterator, NamedTuple, Sequence
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -29,6 +29,9 @@ from jax.dtypes import float0
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
@ -36,15 +39,18 @@ from jax._src.numpy.lax_numpy import (
|
||||
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
|
||||
_expand_bool_indices, _register_stackable)
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.util import canonicalize_axis, prod
|
||||
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
|
||||
|
||||
from jax._src.lib import gpu_prng
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
|
||||
UINT_DTYPES = {
|
||||
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
|
||||
|
||||
# -- PRNG implementation interface --
|
||||
# -- PRNG implementation interface
|
||||
|
||||
class PRNGImpl(NamedTuple):
|
||||
"""Specifies PRNG key shape and operations.
|
||||
@ -68,15 +74,22 @@ class PRNGImpl(NamedTuple):
|
||||
split: Callable
|
||||
random_bits: Callable
|
||||
fold_in: Callable
|
||||
tag: str
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.tag)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.tag
|
||||
|
||||
def pprint(self):
|
||||
return (pp.text(f"{self.__class__.__name__}:") +
|
||||
return (pp.text(f"{self.__class__.__name__} [{self.tag}]:") +
|
||||
pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
|
||||
pp.text(f"{k} = {v}") for k, v in self._asdict().items()
|
||||
]))))
|
||||
|
||||
|
||||
# -- PRNG key arrays --
|
||||
# -- PRNG key arrays
|
||||
|
||||
def _check_prng_key_data(impl, key_data: jnp.ndarray):
|
||||
ndim = len(impl.key_shape)
|
||||
@ -94,7 +107,6 @@ def _check_prng_key_data(impl, key_data: jnp.ndarray):
|
||||
f"got dtype={key_data.dtype}")
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class PRNGKeyArray:
|
||||
"""An array whose elements are PRNG keys.
|
||||
|
||||
@ -117,6 +129,7 @@ class PRNGKeyArray:
|
||||
# instead of a jnp.ndarray due to tree_unflatten
|
||||
if type(key_data) not in [object, bool]:
|
||||
_check_prng_key_data(impl, key_data)
|
||||
assert not isinstance(key_data, core.Tracer)
|
||||
self.impl = impl
|
||||
self._keys = key_data
|
||||
|
||||
@ -148,6 +161,7 @@ class PRNGKeyArray:
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._shape
|
||||
# TODO(frostig): simplify once we always enable_custom_prng
|
||||
if config.jax_enable_custom_prng:
|
||||
return self._shape
|
||||
@ -191,12 +205,15 @@ class PRNGKeyArray:
|
||||
return PRNGKeyArray(self.impl, self._keys[idx])
|
||||
|
||||
def _fold_in(self, data: int) -> 'PRNGKeyArray':
|
||||
assert False
|
||||
return PRNGKeyArray(self.impl, self.impl.fold_in(self._keys, data))
|
||||
|
||||
def _random_bits(self, bit_width, shape) -> jnp.ndarray:
|
||||
assert False
|
||||
return self.impl.random_bits(self._keys, bit_width, shape)
|
||||
|
||||
def _split(self, num: int) -> 'PRNGKeyArray':
|
||||
assert False
|
||||
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
|
||||
|
||||
def reshape(self, newshape, order=None):
|
||||
@ -223,6 +240,9 @@ class PRNGKeyArray:
|
||||
return PRNGKeyArray(self.impl, lax.expand_dims(self._keys, dimensions))
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{self.impl.tag}] {{ {self._keys} }}'
|
||||
|
||||
def pprint(self):
|
||||
arr_shape = self._shape
|
||||
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
|
||||
pp_impl = pp.text('impl = ') + self.impl.pprint()
|
||||
@ -232,11 +252,242 @@ class PRNGKeyArray:
|
||||
|
||||
|
||||
def seed_with_impl(impl: PRNGImpl, seed: int) -> PRNGKeyArray:
|
||||
return PRNGKeyArray(impl, impl.seed(seed))
|
||||
return random_seed(seed, impl=impl)
|
||||
|
||||
_register_stackable(PRNGKeyArray)
|
||||
|
||||
# -- threefry2x32 PRNG implementation --
|
||||
|
||||
class KeyTy:
|
||||
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
|
||||
def __init__(self, impl):
|
||||
self.impl = impl
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'key<{self.impl.tag}>'
|
||||
def __repr__(self) -> str:
|
||||
return self.name
|
||||
def __eq__(self, other):
|
||||
return type(other) is KeyTy and self.impl is other.impl
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.__class__, self.impl))
|
||||
|
||||
# handlers
|
||||
|
||||
@staticmethod
|
||||
def aval_to_ir_types(aval):
|
||||
phys_aval = core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),
|
||||
jnp.dtype('uint32'))
|
||||
return mlir.aval_to_ir_types(phys_aval)
|
||||
|
||||
@staticmethod
|
||||
def result_handler(sticky_device, aval):
|
||||
def handler(_, buf):
|
||||
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
|
||||
return PRNGKeyArray(aval.dtype.impl, buf)
|
||||
return handler
|
||||
|
||||
# eltype-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
def empty_mlir(ctx):
|
||||
aval_out, = ctx.aval_out
|
||||
return mlir.ir_constants(np.empty(aval_out.dtype.impl.key_shape,
|
||||
dtype=np.dtype('uint32')))
|
||||
|
||||
@staticmethod
|
||||
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
|
||||
aval_out, = ctx.aval_out
|
||||
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_zeros = [mlir.ir_constant(np.zeros(0, dtype))] * len(key_shape)
|
||||
start_indices = (*start_indices, *trailing_zeros)
|
||||
slice_sizes_ = mlir.dense_int_elements((*slice_sizes, *key_shape))
|
||||
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results
|
||||
|
||||
@staticmethod
|
||||
def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
|
||||
aval_out, = ctx.aval_out
|
||||
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
|
||||
start_indices = (*start_indices, *zeros)
|
||||
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
|
||||
start_indices).results
|
||||
|
||||
@staticmethod
|
||||
def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
|
||||
if dyn_shape: raise NotImplementedError
|
||||
aval_out, = ctx.avals_out
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
|
||||
broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
|
||||
return mhlo.BroadcastInDimOp(
|
||||
mlir.aval_to_ir_type(aval_out), x,
|
||||
mlir.dense_int_elements(broadcast_dimensions)).results
|
||||
|
||||
@staticmethod
|
||||
def transpose_mlir(ctx, x, *, permutation):
|
||||
aval_out, = ctx.avals_out
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
|
||||
perm = [*permutation, *trailing_dims]
|
||||
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results
|
||||
|
||||
core.custom_eltypes.add(KeyTy)
|
||||
|
||||
|
||||
def key_shaped_array(impl, shape):
|
||||
return core.ShapedArray(shape, KeyTy(impl))
|
||||
|
||||
def key_aval_to_raw_aval(key_array_aval):
|
||||
shape = (*key_array_aval.shape, *key_array_aval.dtype.impl.key_shape)
|
||||
return core.ShapedArray(shape, np.dtype('uint32'))
|
||||
|
||||
core.pytype_aval_mappings[PRNGKeyArray] = (
|
||||
lambda x: key_shaped_array(x.impl, x._shape))
|
||||
|
||||
xla.pytype_aval_mappings[PRNGKeyArray] = (
|
||||
lambda x: key_shaped_array(x.impl, x._shape))
|
||||
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
|
||||
|
||||
def device_put_key_array(x: PRNGKeyArray, device):
|
||||
return dispatch._device_put_array(x.unsafe_raw_array(), device)
|
||||
dispatch.device_put_handlers[PRNGKeyArray] = device_put_key_array
|
||||
|
||||
def key_array_constant_handler(val, canonicalize_dtypes):
|
||||
return mlir._device_array_constant_handler(
|
||||
val.unsafe_raw_array(), canonicalize_dtypes)
|
||||
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)
|
||||
|
||||
|
||||
# -- primitives
|
||||
|
||||
def vmap_n(n, f):
|
||||
for _ in range(n):
|
||||
f = jax.vmap(f)
|
||||
return f
|
||||
|
||||
|
||||
def random_seed(seeds, impl):
|
||||
return random_seed_p.bind(seeds, impl=impl)
|
||||
|
||||
random_seed_p = core.Primitive('random_seed')
|
||||
batching.defvectorized(random_seed_p)
|
||||
|
||||
@random_seed_p.def_abstract_eval
|
||||
def random_seed_abstract_eval(seeds_aval, *, impl):
|
||||
return key_shaped_array(impl, seeds_aval.shape)
|
||||
|
||||
@random_seed_p.def_impl
|
||||
def random_seed_impl(seeds, *, impl):
|
||||
seed = vmap_n(seeds.ndim, impl.seed)
|
||||
return PRNGKeyArray(impl, seed(seeds))
|
||||
|
||||
def random_seed_lowering(ctx, seeds, *, impl):
|
||||
aval, = ctx.avals_in
|
||||
seed = vmap_n(aval.ndim, impl.seed)
|
||||
seed_lowering = mlir.lower_fun(seed, multiple_results=False)
|
||||
ctx_new = ctx.replace(avals_out=map(key_aval_to_raw_aval, ctx.avals_out))
|
||||
out = seed_lowering(ctx_new, seeds)
|
||||
ctx.set_tokens_out(ctx_new.tokens_out)
|
||||
return out
|
||||
|
||||
mlir.register_lowering(random_seed_p, random_seed_lowering)
|
||||
|
||||
|
||||
def random_split(keys, count):
|
||||
return random_split_p.bind(keys, count=count)
|
||||
|
||||
random_split_p = core.Primitive('random_split')
|
||||
batching.defvectorized(random_split_p)
|
||||
|
||||
@random_split_p.def_abstract_eval
|
||||
def random_split_abstract_eval(keys_aval, *, count):
|
||||
return key_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, count))
|
||||
|
||||
@random_split_p.def_impl
|
||||
def random_split_impl(keys, *, count):
|
||||
impl = keys.impl
|
||||
split = vmap_n(keys.ndim, impl.split)
|
||||
return PRNGKeyArray(impl, split(keys.unsafe_raw_array(), count))
|
||||
|
||||
def random_split_lowering(ctx, keys, *, count):
|
||||
aval, = ctx.avals_in
|
||||
impl = aval.dtype.impl
|
||||
split = vmap_n(aval.ndim, impl.split)
|
||||
split_lowering = mlir.lower_fun(split, multiple_results=False)
|
||||
ctx_new = ctx.replace(avals_in=[key_aval_to_raw_aval(aval)],
|
||||
avals_out=map(key_aval_to_raw_aval, ctx.avals_out))
|
||||
out = split_lowering(ctx_new, keys)
|
||||
ctx.set_tokens_out(ctx_new.tokens_out)
|
||||
return out
|
||||
|
||||
mlir.register_lowering(random_split_p, random_split_lowering)
|
||||
|
||||
|
||||
def random_fold_in(keys, msgs):
|
||||
return random_fold_in_p.bind(keys, msgs)
|
||||
|
||||
random_fold_in_p = core.Primitive('random_fold_in')
|
||||
batching.defvectorized(random_fold_in_p)
|
||||
|
||||
@random_fold_in_p.def_abstract_eval
|
||||
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
|
||||
return keys_aval
|
||||
|
||||
@random_fold_in_p.def_impl
|
||||
def random_fold_in_impl(keys, msgs):
|
||||
impl = keys.impl
|
||||
fold_in = vmap_n(keys.ndim, impl.fold_in)
|
||||
return PRNGKeyArray(impl, fold_in(keys.unsafe_raw_array(), msgs))
|
||||
|
||||
def random_fold_in_lowering(ctx, keys, msgs):
|
||||
keys_aval, msgs_aval = ctx.avals_in
|
||||
impl = keys_aval.dtype.impl
|
||||
fold_in = vmap_n(keys_aval.ndim, impl.fold_in)
|
||||
fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
|
||||
ctx_new = ctx.replace(avals_in=[key_aval_to_raw_aval(keys_aval), msgs_aval],
|
||||
avals_out=map(key_aval_to_raw_aval, ctx.avals_out))
|
||||
out = fold_in_lowering(ctx_new, keys, msgs)
|
||||
ctx.set_tokens_out(ctx_new.tokens_out)
|
||||
return out
|
||||
|
||||
mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)
|
||||
|
||||
|
||||
def random_bits(keys, bit_width, shape):
|
||||
return random_bits_p.bind(keys, bit_width=bit_width, shape=shape)
|
||||
|
||||
random_bits_p = core.Primitive('random_bits')
|
||||
batching.defvectorized(random_bits_p)
|
||||
|
||||
@random_bits_p.def_abstract_eval
|
||||
def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
|
||||
out_shape = (*keys_aval.shape, *shape)
|
||||
out_dtype = dtypes.dtype(f'uint{bit_width}')
|
||||
return core.ShapedArray(out_shape, out_dtype)
|
||||
|
||||
@random_bits_p.def_impl
|
||||
def random_bits_impl(keys, *, bit_width, shape):
|
||||
impl = keys.impl
|
||||
bits = vmap_n(keys.ndim, lambda k: impl.random_bits(k, bit_width, shape))
|
||||
return bits(keys.unsafe_raw_array())
|
||||
|
||||
def random_bits_lowering(ctx, keys, *, bit_width, shape):
|
||||
aval, = ctx.avals_in
|
||||
impl = aval.dtype.impl
|
||||
bits = vmap_n(aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
|
||||
bits_lowering = mlir.lower_fun(bits, multiple_results=False)
|
||||
ctx_new = ctx.replace(avals_in=[key_aval_to_raw_aval(aval)])
|
||||
out = bits_lowering(ctx_new, keys)
|
||||
ctx.set_tokens_out(ctx_new.tokens_out)
|
||||
return out
|
||||
|
||||
mlir.register_lowering(random_bits_p, random_bits_lowering)
|
||||
|
||||
|
||||
# -- threefry2x32 PRNG implementation
|
||||
|
||||
|
||||
def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
|
||||
@ -424,6 +675,7 @@ mlir.register_lowering(
|
||||
platform='rocm')
|
||||
|
||||
|
||||
# TODO(frostig): no longer need to jit?
|
||||
@partial(jit, inline=True)
|
||||
def threefry_2x32(keypair, count):
|
||||
"""Apply the Threefry 2x32 hash.
|
||||
@ -461,6 +713,7 @@ def threefry_2x32(keypair, count):
|
||||
def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
return _threefry_split(key, int(num)) # type: ignore
|
||||
|
||||
# TODO(frostig): no longer need to jit?
|
||||
@partial(jit, static_argnums=(1,), inline=True)
|
||||
def _threefry_split(key, num) -> jnp.ndarray:
|
||||
counts = lax.iota(np.uint32, num * 2)
|
||||
@ -470,11 +723,13 @@ def _threefry_split(key, num) -> jnp.ndarray:
|
||||
def threefry_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
||||
return _threefry_fold_in(key, jnp.uint32(data))
|
||||
|
||||
# TODO(frostig): no longer need to jit?
|
||||
@partial(jit, inline=True)
|
||||
def _threefry_fold_in(key, data):
|
||||
return threefry_2x32(key, threefry_seed(data))
|
||||
|
||||
|
||||
# TODO(frostig): no longer need to jit?
|
||||
@partial(jit, static_argnums=(1, 2), inline=True)
|
||||
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
|
||||
"""Sample uniform random bits of given width and shape using PRNG key."""
|
||||
@ -537,10 +792,11 @@ threefry_prng_impl = PRNGImpl(
|
||||
seed=threefry_seed,
|
||||
split=threefry_split,
|
||||
random_bits=threefry_random_bits,
|
||||
fold_in=threefry_fold_in)
|
||||
fold_in=threefry_fold_in,
|
||||
tag='fry')
|
||||
|
||||
|
||||
# -- RngBitGenerator PRNG implementation --
|
||||
# -- RngBitGenerator PRNG implementation
|
||||
|
||||
# This code is experimental!
|
||||
# https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
|
||||
@ -572,7 +828,8 @@ rbg_prng_impl = PRNGImpl(
|
||||
seed=_rbg_seed,
|
||||
split=_rbg_split,
|
||||
random_bits=_rbg_random_bits,
|
||||
fold_in=_rbg_fold_in)
|
||||
fold_in=_rbg_fold_in,
|
||||
tag='rbg')
|
||||
|
||||
def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
# treat 10 iterations of random bits as a 'hash function'
|
||||
@ -588,4 +845,5 @@ unsafe_rbg_prng_impl = PRNGImpl(
|
||||
seed=_rbg_seed,
|
||||
split=_unsafe_rbg_split,
|
||||
random_bits=_rbg_random_bits,
|
||||
fold_in=_unsafe_rbg_fold_in)
|
||||
fold_in=_unsafe_rbg_fold_in,
|
||||
tag='urbg')
|
||||
|
@ -61,20 +61,14 @@ def _isnan(x):
|
||||
|
||||
def _check_prng_key(key):
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
if type(key) is prng.PRNGKeyArray:
|
||||
if config.jax_enable_custom_prng:
|
||||
return key, False
|
||||
elif _arraylike(key):
|
||||
if config.jax_enable_custom_prng:
|
||||
warnings.warn(
|
||||
'Raw arrays as random keys to jax.random functions are deprecated. '
|
||||
'Assuming valid threefry2x32 key for now.',
|
||||
FutureWarning)
|
||||
return prng.PRNGKeyArray(default_prng_impl(), key), True
|
||||
else:
|
||||
raise TypeError(f'unexpected PRNG key type {type(key)}')
|
||||
|
||||
def _return_prng_keys(was_wrapped, key):
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
return key
|
||||
assert type(key) is prng.PRNGKeyArray, type(key)
|
||||
if config.jax_enable_custom_prng:
|
||||
return key
|
||||
@ -82,8 +76,7 @@ def _return_prng_keys(was_wrapped, key):
|
||||
return key.unsafe_raw_array() if was_wrapped else key
|
||||
|
||||
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
|
||||
key, _ = _check_prng_key(key)
|
||||
return key._random_bits(bit_width, shape)
|
||||
return prng.random_bits(key, bit_width=bit_width, shape=shape)
|
||||
|
||||
|
||||
PRNG_IMPLS = {
|
||||
@ -158,7 +151,7 @@ def unsafe_rbg_key(seed: int) -> KeyArray:
|
||||
def _fold_in(key: KeyArray, data: int) -> KeyArray:
|
||||
# Alternative to fold_in() to use within random samplers.
|
||||
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
|
||||
return key._fold_in(jnp.uint32(data))
|
||||
return prng.random_bits(key, jnp.uint32(data))
|
||||
|
||||
def fold_in(key: KeyArray, data: int) -> KeyArray:
|
||||
"""Folds in data to a PRNG key to form a new PRNG key.
|
||||
@ -177,7 +170,7 @@ def fold_in(key: KeyArray, data: int) -> KeyArray:
|
||||
def _split(key: KeyArray, num: int = 2) -> KeyArray:
|
||||
# Alternative to split() to use within random samplers.
|
||||
# TODO(frostig): remove and use split() once we always enable_custom_prng
|
||||
return key._split(num)
|
||||
return prng.random_split(key, count=num)
|
||||
|
||||
def split(key: KeyArray, num: int = 2) -> KeyArray:
|
||||
"""Splits a PRNG key into `num` new keys by adding a leading axis.
|
||||
|
@ -1096,6 +1096,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
|
||||
else:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
|
||||
|
||||
out, tokens = jaxpr_subcomp(
|
||||
ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts),
|
||||
|
Loading…
x
Reference in New Issue
Block a user