From b4966130a355d64759bcfae66a81153f32b68c89 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 17 Mar 2025 15:45:39 -0700 Subject: [PATCH] Compute tile index using tile-based coordinates This reduces the chances of overflowing a 32-bit integer when computing tile indices. Add unit test to reproduce the overflow with the previous implementation of `blocked_fold_in`. PiperOrigin-RevId: 737778853 --- jax/_src/blocked_sampler.py | 31 +++++++++------- tests/blocked_sampler_test.py | 68 ++++++++++++++++++++++++++++++----- 2 files changed, 78 insertions(+), 21 deletions(-) diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py index 3bc592d88..e4d2e2855 100644 --- a/jax/_src/blocked_sampler.py +++ b/jax/_src/blocked_sampler.py @@ -28,17 +28,17 @@ class SampleFn(Protocol): ... -def _compute_scalar_index(iteration_index: Sequence[int], - total_size: Shape, - block_size: Shape, - block_index: Sequence[int]) -> int: - ndims = len(iteration_index) +def _compute_tile_index(block_index: Sequence[int], + total_size_in_blocks: Shape, + block_size_in_tiles: Shape, + tile_index_in_block: Sequence[int]) -> int: + ndims = len(block_index) dim_size = 1 total_idx = 0 for i in range(ndims-1, -1, -1): - dim_idx = block_index[i] + iteration_index[i] * block_size[i] + dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i] total_idx += dim_idx * dim_size - dim_size *= total_size[i] + dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i] return total_idx @@ -99,18 +99,23 @@ def blocked_fold_in( An N-dimensional nested list of keys required to sample the tiles corresponding to the block specified by `block_index`. """ - size_in_blocks = tuple( - _shape // _element for _shape, _element in zip(block_size, tile_size)) + block_size_in_tiles = tuple( + _shape // _element for _shape, _element in zip(block_size, tile_size) + ) + + total_size_in_blocks = tuple( + _shape // _element for _shape, _element in zip(total_size, block_size) + ) def _keygen_loop(axis, prefix): - if axis == len(size_in_blocks): + if axis == len(block_size_in_tiles): subtile_key = jax.random.fold_in( - global_key, _compute_scalar_index( - block_index, total_size, size_in_blocks, prefix)) + global_key, _compute_tile_index( + block_index, total_size_in_blocks, block_size_in_tiles, prefix)) return subtile_key else: keys = [] - for i in range(size_in_blocks[axis]): + for i in range(block_size_in_tiles[axis]): keys.append(_keygen_loop(axis+1, prefix+(i,))) return keys return _keygen_loop(0, tuple()) diff --git a/tests/blocked_sampler_test.py b/tests/blocked_sampler_test.py index 1f8f2b645..4c27e850c 100644 --- a/tests/blocked_sampler_test.py +++ b/tests/blocked_sampler_test.py @@ -37,18 +37,41 @@ def call_kernel( m, n = grid return jnp.concatenate([ jnp.concatenate([ - kernel(i, j, *args) for j in range(n)], axis=1) + kernel((i, j), *args) for j in range(n)], axis=1) for i in range(m)], axis=0) -def uniform_kernel(i: int, j: int, total_size, block_size, tile_size): - """Uniform random sampling kernel function.""" - global_key = jax.random.key(0) - keys = blocked_sampler.blocked_fold_in(global_key, +def call_kernel_3d( + kernel, + grid: tuple[int, int], + *args + ): + """Calls a kernel over a 3D grid and concatenates results to a single array.""" + depth, rows, cols = grid + return jnp.concatenate([ + jnp.concatenate([ + jnp.concatenate([ + jnp.array(kernel((i, j, k), *args)) + for k in range(cols)], axis=2) + for j in range(rows)], axis=1) + for i in range(depth)], axis=0) + + +def blocked_fold_in(block_index, key, total_size, block_size, tile_size): + """Folds in block_index into global_key.""" + return blocked_sampler.blocked_fold_in(key, total_size=total_size, block_size=block_size, tile_size=tile_size, - block_index=(i, j)) + block_index=block_index) + + +def uniform_kernel(block_index, key, total_size, block_size, tile_size): + """Uniform random sampling kernel function.""" + keys = blocked_fold_in(block_index, key, + total_size=total_size, + block_size=block_size, + tile_size=tile_size) return blocked_sampler.sample_block(jax.random.uniform, keys, block_size=block_size, @@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase): ) def test_block_shape_invariance(self, total_size, block_size_a, block_size_b, tile_size, transpose_grid): + global_key = jax.random.key(0) grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a)) result_a = call_kernel( - uniform_kernel, grid_a, transpose_grid, + uniform_kernel, grid_a, transpose_grid, global_key, total_size, block_size_a, tile_size) grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b)) result_b = call_kernel( - uniform_kernel, grid_b, transpose_grid, + uniform_kernel, grid_b, transpose_grid, global_key, total_size, block_size_b, tile_size) np.testing.assert_array_equal(result_a, result_b) +class BlockedFoldInTest(jtu.JaxTestCase): + @parameterized.named_parameters( + # Check that sampling a tensor of total size > jnp.iinfo(jnp.uint32).max works + # as expected. Specifically, blocked key folding does not depend on the total + # size of the tensor, but only the total number of tiles. + # Using a 3D grid (with very large inner dimensions) triggers an overflow in a + # previous implementation of blocked_fold_in. + dict(testcase_name='4096x512_vs_1024x2048', + total_size=(2, 64 * 1024, 64 * 1024), block_size_a=(1, 4096, 512), + block_size_b=(1, 1024, 2048), tile_size=(1, 1024, 512)), + ) + def test_blocked_fold_in_shape_invariance(self, total_size, block_size_a, + block_size_b, tile_size): + global_key = jax.random.key(0) + grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a)) + result_a = call_kernel_3d( + blocked_fold_in, grid_a, global_key, total_size, + block_size_a, tile_size) + + grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b)) + result_b = call_kernel_3d( + blocked_fold_in, grid_b, global_key, total_size, + block_size_b, tile_size) + np.testing.assert_array_equal(jax.random.key_data(result_a), + jax.random.key_data(result_b)) + + + if __name__ == "__main__": absltest.main()