rbg_split and rbg_fold_in: use vmap for fewer HLOs

This commit is contained in:
Matthew Johnson 2021-10-07 21:19:06 -07:00
parent b002bc178e
commit 022cb8c0fc

View File

@ -505,12 +505,10 @@ def _rbg_seed(seed: int) -> jnp.ndarray:
return jnp.concatenate([halfkey, halfkey])
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
return jnp.concatenate([_threefry_split(key[:2], num),
_threefry_split(key[2:], num)], axis=1)
return vmap(_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
return jnp.concatenate([_threefry_fold_in(key[:2], data),
_threefry_fold_in(key[2:], data)])
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)
def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
) -> jnp.ndarray: