Make blocked_fold_in consistent when the block sizes induce padding

Add coverage for padded shapes to unit tests.

PiperOrigin-RevId: 738029476
jax authors 2025-03-18 09:11:28 -07:00
parent 1e36cbe597
commit 13541e9f12
2 changed files with 34 additions and 16 deletions
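
For context, the inconsistency this change fixes can be reproduced in plain Python. The sketch below mirrors _compute_tile_index before and after the change (argument names follow the diff); the shapes come from the new 128x128_vs_128x256_padding test case, and the block/tile coordinates are worked out by hand, so treat it as an illustration rather than library output.

def tile_index_old(block_index, total_size_in_blocks, block_size_in_tiles,
                   tile_index_in_block):
  dim_size, total_idx = 1, 0
  for i in range(len(block_index) - 1, -1, -1):
    dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
    total_idx += dim_idx * dim_size
    # Old stride: depends on the block decomposition, and floor division
    # makes it 0 whenever block_size exceeds total_size along a dimension.
    dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
  return total_idx

def tile_index_new(block_index, block_size_in_tiles, total_size_in_tiles,
                   tile_index_in_block):
  dim_size, total_idx = 1, 0
  for i in range(len(block_index) - 1, -1, -1):
    dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
    total_idx += dim_idx * dim_size
    # New stride: the rounded-up global tile grid, independent of block size.
    dim_size *= total_size_in_tiles[i]
  return total_idx

# total_size=(256, 128), tile_size=(128, 128): the global tile grid is 2x1.
# Look at the tile at global tile coordinate (1, 0).
# block_size_a=(128, 128): block_index=(1, 0), block_size_in_tiles=(1, 1),
#                          total_size_in_blocks=(2, 1)
# block_size_b=(128, 256): block_index=(1, 0), block_size_in_tiles=(1, 2),
#                          total_size_in_blocks=(2, 0)  (floor under padding)
print(tile_index_old((1, 0), (2, 1), (1, 1), (0, 0)))  # 1
print(tile_index_old((1, 0), (2, 0), (1, 2), (0, 0)))  # 0  -> keys disagree
print(tile_index_new((1, 0), (1, 1), (2, 1), (0, 0)))  # 1
print(tile_index_new((1, 0), (1, 2), (2, 1), (0, 0)))  # 1  -> keys agree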

jax/_src/blocked_sampler.py

@@ -29,8 +29,8 @@ class SampleFn(Protocol):
 def _compute_tile_index(block_index: Sequence[int],
-                        total_size_in_blocks: Shape,
                         block_size_in_tiles: Shape,
+                        total_size_in_tiles: Shape,
                         tile_index_in_block: Sequence[int]) -> int:
   ndims = len(block_index)
   dim_size = 1
@@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int],
   for i in range(ndims-1, -1, -1):
     dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
     total_idx += dim_idx * dim_size
-    dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
+    dim_size *= total_size_in_tiles[i]
   return total_idx
@@ -103,15 +103,17 @@ def blocked_fold_in(
   block_size_in_tiles = tuple(
       _shape // _element for _shape, _element in zip(block_size, tile_size)
   )
-  total_size_in_blocks = tuple(
-      _shape // _element for _shape, _element in zip(total_size, block_size)
+  # Round up to make sure every tile is numbered.
+  total_size_in_tiles = tuple(
+      (_shape + _element - 1) // _element
+      for _shape, _element in zip(total_size, tile_size)
   )

   def _keygen_loop(axis, prefix):
     if axis == len(block_size_in_tiles):
       subtile_key = jax.random.fold_in(
           global_key, _compute_tile_index(
-            block_index, total_size_in_blocks, block_size_in_tiles, prefix))
+            block_index, block_size_in_tiles, total_size_in_tiles, prefix))
       return subtile_key
     else:
       keys = []
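
A quick check of the rounded-up tile count, using the padded shape from the new 128x128_vs_128x256_padding2 test case (numbers are illustrative only):

total_size, tile_size = (257, 129), (128, 128)
floor_tiles = tuple(s // t for s, t in zip(total_size, tile_size))           # (2, 1)
ceil_tiles = tuple((s + t - 1) // t for s, t in zip(total_size, tile_size))  # (3, 2)
# Floor division misses the partial edge tiles; rounding up numbers all
# 3 * 2 = 6 tiles, so every tile in the padded grid gets its own fold-in key.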

tests/blocked_sampler_test.py

@@ -29,16 +29,23 @@ def call_kernel(
     kernel,
     grid: tuple[int, int],
     transpose_grid: bool,
-    *args
+    key: jax.Array,
+    total_size: tuple[int, int],
+    block_size: tuple[int, int],
+    tile_size: tuple[int, int],
 ):
   """Calls a kernel over a grid and concatenates results to a single array."""
   if transpose_grid:
     grid = (grid[1], grid[0])
   m, n = grid
-  return jnp.concatenate([
-      jnp.concatenate([
-          kernel((i, j), *args) for j in range(n)], axis=1)
-      for i in range(m)], axis=0)
+  samples = jnp.concatenate([
+      jnp.concatenate([
+          kernel((i, j), key, total_size, block_size, tile_size)
+          for j in range(n)], axis=1)
+      for i in range(m)], axis=0)
+  # Slice out the padding.
+  samples = samples[:total_size[0], :total_size[1]]
+  return samples


 def call_kernel_3d(
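
To make the padding handling concrete, here is a standalone sketch of the assemble-then-slice step that call_kernel now performs. The kernel below is a stand-in (not uniform_kernel), and the shapes follow the new padded test case; everything else is illustrative.

import jax.numpy as jnp

total_size, block_size = (257, 129), (128, 128)
ceil_div = lambda x, y: (x + y - 1) // y
grid = tuple(ceil_div(t, b) for t, b in zip(total_size, block_size))  # (3, 2)

def fake_kernel(block_index):
  # Every block, including the partial edge blocks, is produced at full
  # block_size, so the assembled grid covers a padded (384, 256) region.
  return jnp.full(block_size, block_index[0] * grid[1] + block_index[1])

samples = jnp.concatenate([
    jnp.concatenate([fake_kernel((i, j)) for j in range(grid[1])], axis=1)
    for i in range(grid[0])], axis=0)
print(samples.shape)                               # (384, 256)
samples = samples[:total_size[0], :total_size[1]]  # slice out the padding
print(samples.shape)                               # (257, 129)
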
@@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size):
                                          block_size=block_size,
                                          tile_size=tile_size)
   return blocked_sampler.sample_block(jax.random.uniform,
-                                       keys,
-                                       block_size=block_size,
-                                       tile_size=tile_size,
-                                       minval=0.0, maxval=1.0)
+                                      keys,
+                                      block_size=block_size,
+                                      tile_size=tile_size,
+                                      minval=0.0, maxval=1.0)


 class BlockedSamplerTest(jtu.JaxTestCase):
@@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
       dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
            block_size_a=(16, 256), block_size_b=(32, 128),
            tile_size=(8, 128), transpose_grid=False),
+      dict(testcase_name='128x128_vs_128x256_padding',
+           total_size=(256, 128), block_size_a=(128, 128),
+           block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
+      dict(testcase_name='128x128_vs_128x256_padding2',
+           total_size=(257, 129), block_size_a=(128, 128),
+           block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
   )
   def test_block_shape_invariance(self, total_size, block_size_a,
                                   block_size_b, tile_size, transpose_grid):
     global_key = jax.random.key(0)
-    grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
+    ceil_div = lambda x, y: (x + y - 1) // y
+    grid_a = tuple(ceil_div(_tot, _blk)
+                   for _tot, _blk in zip(total_size, block_size_a))
     result_a = call_kernel(
         uniform_kernel, grid_a, transpose_grid, global_key,
         total_size, block_size_a, tile_size)

-    grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
+    grid_b = tuple(ceil_div(_tot, _blk)
+                   for _tot, _blk in zip(total_size, block_size_b))
     result_b = call_kernel(
         uniform_kernel, grid_b, transpose_grid, global_key,
         total_size, block_size_b, tile_size)