Compute tile index using tile-based coordinates

This reduces the chances of overflowing a 32-bit integer when computing tile indices.
Add unit test to reproduce the overflow with the previous implementation of `blocked_fold_in`.

PiperOrigin-RevId: 737778853
This commit is contained in:
jax authors 2025-03-17 15:45:39 -07:00
parent b74b16f9b9
commit b4966130a3
2 changed files with 78 additions and 21 deletions

View File

@ -28,17 +28,17 @@ class SampleFn(Protocol):
...
def _compute_scalar_index(iteration_index: Sequence[int],
total_size: Shape,
block_size: Shape,
block_index: Sequence[int]) -> int:
ndims = len(iteration_index)
def _compute_tile_index(block_index: Sequence[int],
total_size_in_blocks: Shape,
block_size_in_tiles: Shape,
tile_index_in_block: Sequence[int]) -> int:
ndims = len(block_index)
dim_size = 1
total_idx = 0
for i in range(ndims-1, -1, -1):
dim_idx = block_index[i] + iteration_index[i] * block_size[i]
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
total_idx += dim_idx * dim_size
dim_size *= total_size[i]
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
return total_idx
@ -99,18 +99,23 @@ def blocked_fold_in(
An N-dimensional nested list of keys required to sample the tiles
corresponding to the block specified by `block_index`.
"""
size_in_blocks = tuple(
_shape // _element for _shape, _element in zip(block_size, tile_size))
block_size_in_tiles = tuple(
_shape // _element for _shape, _element in zip(block_size, tile_size)
)
total_size_in_blocks = tuple(
_shape // _element for _shape, _element in zip(total_size, block_size)
)
def _keygen_loop(axis, prefix):
if axis == len(size_in_blocks):
if axis == len(block_size_in_tiles):
subtile_key = jax.random.fold_in(
global_key, _compute_scalar_index(
block_index, total_size, size_in_blocks, prefix))
global_key, _compute_tile_index(
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
return subtile_key
else:
keys = []
for i in range(size_in_blocks[axis]):
for i in range(block_size_in_tiles[axis]):
keys.append(_keygen_loop(axis+1, prefix+(i,)))
return keys
return _keygen_loop(0, tuple())

View File

@ -37,18 +37,41 @@ def call_kernel(
m, n = grid
return jnp.concatenate([
jnp.concatenate([
kernel(i, j, *args) for j in range(n)], axis=1)
kernel((i, j), *args) for j in range(n)], axis=1)
for i in range(m)], axis=0)
def uniform_kernel(i: int, j: int, total_size, block_size, tile_size):
"""Uniform random sampling kernel function."""
global_key = jax.random.key(0)
keys = blocked_sampler.blocked_fold_in(global_key,
def call_kernel_3d(
kernel,
grid: tuple[int, int],
*args
):
"""Calls a kernel over a 3D grid and concatenates results to a single array."""
depth, rows, cols = grid
return jnp.concatenate([
jnp.concatenate([
jnp.concatenate([
jnp.array(kernel((i, j, k), *args))
for k in range(cols)], axis=2)
for j in range(rows)], axis=1)
for i in range(depth)], axis=0)
def blocked_fold_in(block_index, key, total_size, block_size, tile_size):
"""Folds in block_index into global_key."""
return blocked_sampler.blocked_fold_in(key,
total_size=total_size,
block_size=block_size,
tile_size=tile_size,
block_index=(i, j))
block_index=block_index)
def uniform_kernel(block_index, key, total_size, block_size, tile_size):
"""Uniform random sampling kernel function."""
keys = blocked_fold_in(block_index, key,
total_size=total_size,
block_size=block_size,
tile_size=tile_size)
return blocked_sampler.sample_block(jax.random.uniform,
keys,
block_size=block_size,
@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase):
)
def test_block_shape_invariance(self, total_size, block_size_a,
block_size_b, tile_size, transpose_grid):
global_key = jax.random.key(0)
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
result_a = call_kernel(
uniform_kernel, grid_a, transpose_grid,
uniform_kernel, grid_a, transpose_grid, global_key,
total_size, block_size_a, tile_size)
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
result_b = call_kernel(
uniform_kernel, grid_b, transpose_grid,
uniform_kernel, grid_b, transpose_grid, global_key,
total_size, block_size_b, tile_size)
np.testing.assert_array_equal(result_a, result_b)
class BlockedFoldInTest(jtu.JaxTestCase):
@parameterized.named_parameters(
# Check that sampling a tensor of total size > jnp.iinfo(jnp.uint32).max works
# as expected. Specifically, blocked key folding does not depend on the total
# size of the tensor, but only the total number of tiles.
# Using a 3D grid (with very large inner dimensions) triggers an overflow in a
# previous implementation of blocked_fold_in.
dict(testcase_name='4096x512_vs_1024x2048',
total_size=(2, 64 * 1024, 64 * 1024), block_size_a=(1, 4096, 512),
block_size_b=(1, 1024, 2048), tile_size=(1, 1024, 512)),
)
def test_blocked_fold_in_shape_invariance(self, total_size, block_size_a,
block_size_b, tile_size):
global_key = jax.random.key(0)
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
result_a = call_kernel_3d(
blocked_fold_in, grid_a, global_key, total_size,
block_size_a, tile_size)
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
result_b = call_kernel_3d(
blocked_fold_in, grid_b, global_key, total_size,
block_size_b, tile_size)
np.testing.assert_array_equal(jax.random.key_data(result_a),
jax.random.key_data(result_b))
if __name__ == "__main__":
absltest.main()