mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
rbg_split and rbg_fold_in: use vmap for fewer HLOs
This commit is contained in:
parent
b002bc178e
commit
022cb8c0fc
@ -505,12 +505,10 @@ def _rbg_seed(seed: int) -> jnp.ndarray:
|
||||
return jnp.concatenate([halfkey, halfkey])
|
||||
|
||||
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
return jnp.concatenate([_threefry_split(key[:2], num),
|
||||
_threefry_split(key[2:], num)], axis=1)
|
||||
return vmap(_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
|
||||
|
||||
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
|
||||
return jnp.concatenate([_threefry_fold_in(key[:2], data),
|
||||
_threefry_fold_in(key[2:], data)])
|
||||
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)
|
||||
|
||||
def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
|
||||
) -> jnp.ndarray:
|
||||
|
Loading…
x
Reference in New Issue
Block a user