From 13541e9f12d1589890a9384f35e26b51e2111cc8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 18 Mar 2025 09:11:28 -0700 Subject: [PATCH] Make blocked_fold_in consistent when the block sizes induce padding Add coverage for padded shapes to unit tests. PiperOrigin-RevId: 738029476 --- jax/_src/blocked_sampler.py | 12 ++++++----- tests/blocked_sampler_test.py | 38 +++++++++++++++++++++++++---------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py index e4d2e2855..3021b6a16 100644 --- a/jax/_src/blocked_sampler.py +++ b/jax/_src/blocked_sampler.py @@ -29,8 +29,8 @@ class SampleFn(Protocol): def _compute_tile_index(block_index: Sequence[int], - total_size_in_blocks: Shape, block_size_in_tiles: Shape, + total_size_in_tiles: Shape, tile_index_in_block: Sequence[int]) -> int: ndims = len(block_index) dim_size = 1 @@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int], for i in range(ndims-1, -1, -1): dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i] total_idx += dim_idx * dim_size - dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i] + dim_size *= total_size_in_tiles[i] return total_idx @@ -103,15 +103,17 @@ def blocked_fold_in( _shape // _element for _shape, _element in zip(block_size, tile_size) ) - total_size_in_blocks = tuple( - _shape // _element for _shape, _element in zip(total_size, block_size) + # Round up to make sure every tile is numbered. + total_size_in_tiles = tuple( + (_shape + _element - 1) // _element + for _shape, _element in zip(total_size, tile_size) ) def _keygen_loop(axis, prefix): if axis == len(block_size_in_tiles): subtile_key = jax.random.fold_in( global_key, _compute_tile_index( - block_index, total_size_in_blocks, block_size_in_tiles, prefix)) + block_index, block_size_in_tiles, total_size_in_tiles, prefix)) return subtile_key else: keys = [] diff --git a/tests/blocked_sampler_test.py b/tests/blocked_sampler_test.py index 4c27e850c..b5f87fe05 100644 --- a/tests/blocked_sampler_test.py +++ b/tests/blocked_sampler_test.py @@ -29,16 +29,23 @@ def call_kernel( kernel, grid: tuple[int, int], transpose_grid: bool, - *args + key: jax.Array, + total_size: tuple[int, int], + block_size: tuple[int, int], + tile_size: tuple[int, int], ): """Calls a kernel over a grid and concatenates results to a single array.""" if transpose_grid: grid = (grid[1], grid[0]) m, n = grid - return jnp.concatenate([ - jnp.concatenate([ - kernel((i, j), *args) for j in range(n)], axis=1) - for i in range(m)], axis=0) + samples = jnp.concatenate([ + jnp.concatenate([ + kernel((i, j), key, total_size, block_size, tile_size) + for j in range(n)], axis=1) + for i in range(m)], axis=0) + # Slice out the padding. + samples = samples[:total_size[0], :total_size[1]] + return samples def call_kernel_3d( @@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size): block_size=block_size, tile_size=tile_size) return blocked_sampler.sample_block(jax.random.uniform, - keys, - block_size=block_size, - tile_size=tile_size, - minval=0.0, maxval=1.0) + keys, + block_size=block_size, + tile_size=tile_size, + minval=0.0, maxval=1.0) class BlockedSamplerTest(jtu.JaxTestCase): @@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase): dict(testcase_name='16x256_vs_32x128', total_size=(32, 256), block_size_a=(16, 256), block_size_b=(32, 128), tile_size=(8, 128), transpose_grid=False), + dict(testcase_name='128x128_vs_128x256_padding', + total_size=(256, 128), block_size_a=(128, 128), + block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False), + dict(testcase_name='128x128_vs_128x256_padding2', + total_size=(257, 129), block_size_a=(128, 128), + block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False), ) def test_block_shape_invariance(self, total_size, block_size_a, block_size_b, tile_size, transpose_grid): global_key = jax.random.key(0) - grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a)) + ceil_div = lambda x, y: (x + y - 1) // y + grid_a = tuple(ceil_div(_tot, _blk) + for _tot, _blk in zip(total_size, block_size_a)) result_a = call_kernel( uniform_kernel, grid_a, transpose_grid, global_key, total_size, block_size_a, tile_size) - grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b)) + grid_b = tuple(ceil_div(_tot, _blk) + for _tot, _blk in zip(total_size, block_size_b)) result_b = call_kernel( uniform_kernel, grid_b, transpose_grid, global_key, total_size, block_size_b, tile_size)