Add Pallas Philox implementation.
Implemented in the same style as the threefry kernel. Philox is roughly 2x faster than the existing JAX Threefry implementation in both runtime and compile time.

PiperOrigin-RevId: 707276043
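For reference, a minimal usage sketch (inferred from the impl name registered in this commit and from the added tests; the kernel requires a TPU backend):

    import jax
    from jax.experimental.pallas.ops.tpu.random import philox  # registers the impl

    key = jax.random.key(0, impl="pallas_philox4x32")
    values = jax.random.uniform(key, shape=(512, 512))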
This commit is contained in:
parent d4031e9963
commit 7e96914e61
jax/experimental/pallas/ops/tpu/random/philox.py (new file, 208 lines)
@@ -0,0 +1,208 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Philox PRNG as a Pallas kernel."""
from typing import Sequence
import jax
from jax import typing
from jax._src import prng
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas.ops.tpu.random import prng_utils

Shape = Sequence[int]

BLOCK_SIZE = (256, 256)

# Philox constants. See the original paper:
# "Parallel Random Numbers: As Easy as 1, 2, 3", Salmon et al., 2011.
K_HI_32 = 0x9E3779B9
K_LO_32 = 0xBB67AE85
MUL_A = 0xCD9E8D57
MUL_B = 0xD2511F53


def mul32_hi_lo(x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
  """Multiplies two 32-bit values and returns the hi and lo bits of the result."""
  xhi = x >> 16
  yhi = y >> 16
  xlo = x & 0xffff
  ylo = y & 0xffff

  xy_hi = xhi * yhi
  xy_lo = xlo * ylo
  cross_xy = xhi * ylo
  cross_yx = xlo * yhi
  carry = (cross_xy & 0xffff) + (cross_yx & 0xffff) + (xy_lo >> 16)
  # The wrapping uint32 multiply already yields the low 32 bits of the
  # product; the 16-bit limb decomposition above is only needed to recover
  # the high 32 bits.
  return xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16), x * y


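# Worked example (editorial note, not part of the original file): with
# x = 0x00010002 and y = 0x00030004 the full 64-bit product is
# 0x3_000A_0008, i.e. hi = 0x3 and lo = 0x000A0008. The limbs give
# xy_hi = 1*3 = 3, cross_xy = 1*4 = 4, cross_yx = 2*3 = 6, xy_lo = 2*4 = 8,
# so carry = 4 + 6 + 0 = 10 and hi = 3 + 0 + 0 + 0 = 3, while the wrapping
# multiply x * y returns the low word 0x000A0008 directly (note that xy_lo
# alone would drop the cross terms).

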
def philox_4x32(hi0, lo0, hi1, lo1, k_hi, k_lo, rounds=10):
  """Philox 4x32 keyed hash function."""
  k_hi_const = jnp.array(K_HI_32, dtype=jnp.uint32)
  k_lo_const = jnp.array(K_LO_32, dtype=jnp.uint32)
  mul_a = jnp.array(MUL_A, dtype=jnp.uint32)
  mul_b = jnp.array(MUL_B, dtype=jnp.uint32)

  for i in range(rounds):
    # Compute one Philox round.
    new_hi0, new_lo0 = mul32_hi_lo(mul_a, hi1)
    new_hi0 = new_hi0 ^ lo0 ^ k_hi
    new_hi1, new_lo1 = mul32_hi_lo(mul_b, hi0)
    new_hi1 = new_hi1 ^ lo1 ^ k_lo
    hi0, lo0, hi1, lo1 = new_hi0, new_lo0, new_hi1, new_lo1

    # Bump the key after every round except the last.
    if i != rounds - 1:
      k_hi = k_hi + k_hi_const
      k_lo = k_lo + k_lo_const
  return hi0, lo0, hi1, lo1


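# Editorial note (not part of the original file): (hi0, lo0, hi1, lo1) are
# the four 32-bit counter words and (k_hi, k_lo) the two 32-bit key words;
# with the default rounds=10, the multiplier and key-increment constants
# above match the standard Philox-4x32-10 parameters from Salmon et al.

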
def philox_4x32_kernel(key,
                       shape: Shape,
                       unpadded_shape: Shape,
                       block_size: tuple[int, int],
                       offset: typing.ArrayLike = 0,
                       fuse_output: bool = True):
  """Generates random bits using the Philox keyed hash function.

  Args:
    key: A Philox key of shape (2,).
    shape: The shape of the output. Must be divisible by `block_size`.
    unpadded_shape: If `shape` is padded, then this is the shape the output
      tensor would have without padding. This is important for indexing
      calculations within the kernel. If `shape` is not padded, then this
      should be equal to `shape`.
    block_size: The block size of the kernel.
    offset: An optional offset to the counts.
    fuse_output: Whether to fuse (XOR) the hi/lo output channels into a
      single array of random bits.

  Returns:
    A tensor of random bits of shape `shape` if fuse_output=True. Otherwise,
    this will return a tensor of shape (2, *shape) with the first channel
    being the high bits and the second channel being the low bits.
  """
  shape = tuple(shape)
  if np.prod(shape) > jnp.iinfo(jnp.uint32).max:
    raise ValueError(
        f"Shape too large: {np.prod(shape)} > {jnp.iinfo(jnp.uint32).max}")

  if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0):
    raise ValueError(
        f"Shape dimensions {shape[-2:]} must be divisible by {block_size}")
  grid_dims = shape[:-2] + (
      shape[-2] // block_size[-2], shape[-1] // block_size[-1],)
  offset = jnp.array(offset, dtype=jnp.uint32)
  if offset.ndim != 0:
    raise ValueError(f"Offset must be scalar, got {offset.shape}")
  offset = jnp.reshape(offset, (1,))

  def kernel(offset_ref, key_ref, out_ref):
    counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
    block_offset = prng_utils.compute_scalar_offset(
        counts_idx, unpadded_shape, block_shape)
    counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
    counts_lo = counts_lo + block_offset + offset_ref[0]
    counts_lo = counts_lo.astype(jnp.uint32)
    # TODO(justinfu): Support hi bits on count.
    _zeros = jnp.zeros_like(counts_lo)
    k1 = jnp.reshape(key_ref[0, 0], (1, 1))
    k2 = jnp.reshape(key_ref[0, 1], (1, 1))
    o1, o2, _, _ = philox_4x32(_zeros, counts_lo, _zeros, _zeros, k1, k2)
    if fuse_output:
      out_bits = o1 ^ o2
      out_ref[...] = out_bits.reshape(out_ref.shape)
    else:
      out_ref[0, ...] = o1.reshape(out_ref[0].shape)
      out_ref[1, ...] = o2.reshape(out_ref[1].shape)

  key = key.reshape((1, 2))
  block_shape = (1,) * (len(shape) - 2) + block_size
  if fuse_output:
    out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32)
    out_spec = pl.BlockSpec(block_shape, lambda *idxs: idxs)
  else:
    out = jax.ShapeDtypeStruct((2,) + shape, dtype=jnp.uint32)
    out_spec = pl.BlockSpec((2,) + block_shape, lambda *idxs: (0, *idxs))
  return pl.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
          pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
      ],
      out_specs=out_spec,
      grid=grid_dims,
      out_shape=out,
  )(offset, key)


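# Editorial note (not part of the original file): both scalar operands
# (offset and key) are placed in SMEM and read in full on every grid step,
# while the output is tiled into `block_shape` blocks via the grid index
# map, so each grid iteration hashes one independent block of counters.

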
def philox_4x32_count(key,
                      shape: Shape,
                      offset: typing.ArrayLike = 0,
                      fuse_output: bool = True):
  """Convenience wrapper around `philox_4x32_kernel` that handles padding."""
  if len(shape) == 0:
    return philox_4x32_count(
        key, (1, 1), offset=offset, fuse_output=fuse_output)[..., 0, 0]
  elif len(shape) == 1:
    return philox_4x32_count(
        key, (1, *shape), offset=offset, fuse_output=fuse_output)[..., 0, :]

  requires_pad = (
      shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
  if requires_pad:
    padded_shape = tuple(shape[:-2]) + (
        prng_utils.round_up(shape[-2], BLOCK_SIZE[-2]),
        prng_utils.round_up(shape[-1], BLOCK_SIZE[-1]),
    )
    padded_result = philox_4x32_kernel(
        key, padded_shape, shape,
        block_size=BLOCK_SIZE, offset=offset,
        fuse_output=fuse_output)
    return padded_result[..., :shape[-2], :shape[-1]]
  else:
    return philox_4x32_kernel(key, shape, shape,
                              block_size=BLOCK_SIZE, offset=offset,
                              fuse_output=fuse_output)


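# Illustrative usage (editorial, not part of the original file; requires a
# TPU backend): a non block-aligned request such as (137, 275) is padded up
# to the (256, 256) block grid, generated as (256, 512), and sliced back:
#
#   key_data = jnp.zeros((2,), dtype=jnp.uint32)
#   bits = philox_4x32_count(key_data, (137, 275))
#   assert bits.shape == (137, 275)

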
def philox_split(key, shape: Shape):
  """Splits a key into `shape` new keys, each with key data of shape (2,)."""
  bits1, bits2 = philox_4x32_count(key, shape, fuse_output=False)
  return jnp.stack([bits1, bits2], axis=bits1.ndim)


def philox_random_bits(key, bit_width: int, shape: Shape):
  if bit_width != 32:
    raise ValueError("Only 32-bit PRNG supported.")
  return philox_4x32_count(key, shape, fuse_output=True)


def philox_fold_in(key, data):
  assert data.ndim == 0
  return philox_4x32_count(key, (), offset=data, fuse_output=False)


plphilox_prng_impl = prng.PRNGImpl(
    key_shape=(2,),
    seed=prng.threefry_seed,
    split=philox_split,
    random_bits=philox_random_bits,
    fold_in=philox_fold_in,
    name="pallas_philox4x32",
    tag="pllox")

prng.register_prng(plphilox_prng_impl)
jax/experimental/pallas/ops/tpu/random/prng_utils.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for PRNG kernels."""
from typing import Sequence
from jax import lax
import jax.numpy as jnp

Shape = Sequence[int]

# Rounds x up to the nearest multiple of y.
round_up = lambda x, y: (x + y - 1) // y * y

def blocked_iota(block_shape: Shape,
                 total_shape: Shape):
  """Computes the sub-block at the origin of a larger shaped iota.

  Args:
    block_shape: The output block shape of the iota.
    total_shape: The total shape of the full iota.

  Returns:
    The origin block of the iota, with row-major strides taken from
    `total_shape`.
  """
  iota_data = jnp.zeros(block_shape, dtype=jnp.uint32)
  multiplier = 1
  for dim in range(len(block_shape) - 1, -1, -1):
    counts_lo = lax.broadcasted_iota(
        dtype=jnp.uint32, shape=block_shape, dimension=dim
    )
    # Accumulate this dimension's contribution to the linear index.
    iota_data += counts_lo * multiplier
    multiplier *= total_shape[dim]
  return iota_data


def compute_scalar_offset(iteration_index,
                          total_size: Shape,
                          block_size: Shape):
  """Computes the row-major linear offset of a block within the full array."""
  ndims = len(iteration_index)
  dim_size = 1
  total_idx = 0
  for i in range(ndims - 1, -1, -1):
    dim_idx = iteration_index[i] * block_size[i]
    total_idx += dim_idx * dim_size
    dim_size *= total_size[i]
  return total_idx
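Worked example for the two helpers above (editorial sketch, not part of the commit): with a total shape of (4, 6) and a block shape of (2, 3), blocked_iota returns the row-major linear indices of the origin block, and compute_scalar_offset returns the linear offset of any other block:

    blocked_iota((2, 3), (4, 6))
    # [[0, 1, 2],
    #  [6, 7, 8]]
    compute_scalar_offset((1, 1), (4, 6), (2, 3))
    # 15, the linear index of that block's origin element (2, 3) = 2*6 + 3

Adding the block offset to the blocked iota reproduces the global linear indices of the block, which is how the kernels derive a distinct Philox counter for every element.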
jax/experimental/pallas/ops/tpu/random/threefry.py
@@ -14,54 +14,17 @@
 """Implementation of the Threefry PRNG as a Pallas kernel."""
 from typing import Sequence
 import jax
-from jax import lax
 from jax._src import prng
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 import numpy as np
+from jax.experimental.pallas.ops.tpu.random import prng_utils

 Shape = Sequence[int]

 BLOCK_SIZE = (256, 256)

-_round_up = lambda x, y: (x + y - 1) // y * y
-
-
-def blocked_iota(block_shape: Shape,
-                 total_shape: Shape):
-  """Computes a sub-block of a larger shaped iota.
-
-  Args:
-    block_shape: The output block shape of the iota.
-    total_shape: The total shape of the input tensor.
-  Returns:
-    Result of the blocked iota.
-  """
-  iota_data = jnp.zeros(block_shape, dtype=jnp.uint32)
-  multiplier = 1
-  for dim in range(len(block_shape)-1, -1, -1):
-    block_mult = 1
-    counts_lo = lax.broadcasted_iota(
-        dtype=jnp.uint32, shape=block_shape, dimension=dim
-    )
-    iota_data += counts_lo * multiplier * block_mult
-    multiplier *= total_shape[dim]
-  return iota_data
-
-
-def _compute_scalar_offset(iteration_index,
-                           total_size: Shape,
-                           block_size: Shape):
-  ndims = len(iteration_index)
-  dim_size = 1
-  total_idx = 0
-  for i in range(ndims-1, -1, -1):
-    dim_idx = iteration_index[i] * block_size[i]
-    total_idx += dim_idx * dim_size
-    dim_size *= total_size[i]
-  return total_idx
-

 def threefry_2x32_count(key,
                         shape: Shape,
@@ -97,8 +60,9 @@ def threefry_2x32_count(key,

   def kernel(key_ref, out_ref):
     counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
-    offset = _compute_scalar_offset(counts_idx, unpadded_shape, block_shape)
-    counts_lo = blocked_iota(block_size, unpadded_shape)
+    offset = prng_utils.compute_scalar_offset(
+        counts_idx, unpadded_shape, block_shape)
+    counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
     counts_lo = counts_lo + offset
     counts_lo = counts_lo.astype(jnp.uint32)
     # TODO(justinfu): Support hi bits on count.
@@ -134,8 +98,8 @@ def plthreefry_random_bits(key, bit_width: int, shape: Shape):
       shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
   if requires_pad:
     padded_shape = tuple(shape[:-2]) + (
-        _round_up(shape[-2], BLOCK_SIZE[-2]),
-        _round_up(shape[-1], BLOCK_SIZE[-1]),
+        prng_utils.round_up(shape[-2], BLOCK_SIZE[-2]),
+        prng_utils.round_up(shape[-1], BLOCK_SIZE[-1]),
     )
     padded_result = threefry_2x32_count(
         key, padded_shape, shape, block_size=BLOCK_SIZE)
@@ -22,6 +22,7 @@ from jax._src.pallas.mosaic import random as plrandom
 from jax.experimental import pallas as pl
 from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
+from jax.experimental.pallas.ops.tpu.random import philox  # pylint: disable=unused-import  # noqa: F401
 from jax.experimental.pallas.ops.tpu.random import threefry  # pylint: disable=unused-import  # noqa: F401
 import jax.numpy as jnp
 import numpy as np
@@ -305,5 +306,56 @@ class ThreefryTest(parameterized.TestCase):
     np.testing.assert_array_equal(jax_gen, pl_gen)


+class PhiloxTest(parameterized.TestCase):
+
+  def setUp(self):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest("Need TPU devices")
+    super().setUp()
+
+  @parameterized.parameters(
+      ((512, 512),),
+      ((137, 275),),  # Non block-aligned shape
+      ((4, 512, 512),),  # Greater than 2D shape
+      ((34,),),  # 1D
+      (tuple(),),  # 0D
+  )
+  def test_generate_uniform(self, shape):
+    key = jax_random.key(0, impl="pallas_philox4x32")
+    values = jax_random.uniform(key, shape=shape)
+    minval = jnp.min(values)
+    maxval = jnp.max(values)
+    self.assertGreaterEqual(minval, 0.0)
+    self.assertLessEqual(maxval, 1.0)
+
+  def test_split(self):
+    key = jax_random.key(0, impl="pallas_philox4x32")
+    key1, key2 = jax_random.split(key)
+    key_data = jax.random.key_data(key)
+    key1_data = jax.random.key_data(key1)
+    key2_data = jax.random.key_data(key2)
+    # Assert all keys are different.
+    with self.assertRaises(AssertionError):
+      np.testing.assert_array_equal(key_data, key1_data)
+    with self.assertRaises(AssertionError):
+      np.testing.assert_array_equal(key_data, key2_data)
+    with self.assertRaises(AssertionError):
+      np.testing.assert_array_equal(key1_data, key2_data)
+
+  def test_foldin(self):
+    key = jax_random.key(0, impl="pallas_philox4x32")
+    new_key_1 = jax_random.fold_in(key, 1)
+    new_key_2 = jax_random.fold_in(key, 2)
+    key_data = jax.random.key_data(key)
+    key1_data = jax.random.key_data(new_key_1)
+    key2_data = jax.random.key_data(new_key_2)
+    # Assert all keys are different.
+    with self.assertRaises(AssertionError):
+      np.testing.assert_array_equal(key_data, key1_data)
+    with self.assertRaises(AssertionError):
+      np.testing.assert_array_equal(key_data, key2_data)
+    with self.assertRaises(AssertionError):
+      np.testing.assert_array_equal(key1_data, key2_data)
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())