mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Make blocked_fold_in consistent when the block sizes induce padding
Add coverage for padded shapes to unit tests. PiperOrigin-RevId: 738029476
This commit is contained in:
parent
1e36cbe597
commit
13541e9f12
@ -29,8 +29,8 @@ class SampleFn(Protocol):
|
||||
|
||||
|
||||
def _compute_tile_index(block_index: Sequence[int],
|
||||
total_size_in_blocks: Shape,
|
||||
block_size_in_tiles: Shape,
|
||||
total_size_in_tiles: Shape,
|
||||
tile_index_in_block: Sequence[int]) -> int:
|
||||
ndims = len(block_index)
|
||||
dim_size = 1
|
||||
@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int],
|
||||
for i in range(ndims-1, -1, -1):
|
||||
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
|
||||
total_idx += dim_idx * dim_size
|
||||
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
|
||||
dim_size *= total_size_in_tiles[i]
|
||||
return total_idx
|
||||
|
||||
|
||||
@ -103,15 +103,17 @@ def blocked_fold_in(
|
||||
_shape // _element for _shape, _element in zip(block_size, tile_size)
|
||||
)
|
||||
|
||||
total_size_in_blocks = tuple(
|
||||
_shape // _element for _shape, _element in zip(total_size, block_size)
|
||||
# Round up to make sure every tile is numbered.
|
||||
total_size_in_tiles = tuple(
|
||||
(_shape + _element - 1) // _element
|
||||
for _shape, _element in zip(total_size, tile_size)
|
||||
)
|
||||
|
||||
def _keygen_loop(axis, prefix):
|
||||
if axis == len(block_size_in_tiles):
|
||||
subtile_key = jax.random.fold_in(
|
||||
global_key, _compute_tile_index(
|
||||
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
|
||||
block_index, block_size_in_tiles, total_size_in_tiles, prefix))
|
||||
return subtile_key
|
||||
else:
|
||||
keys = []
|
||||
|
@ -29,16 +29,23 @@ def call_kernel(
|
||||
kernel,
|
||||
grid: tuple[int, int],
|
||||
transpose_grid: bool,
|
||||
*args
|
||||
key: jax.Array,
|
||||
total_size: tuple[int, int],
|
||||
block_size: tuple[int, int],
|
||||
tile_size: tuple[int, int],
|
||||
):
|
||||
"""Calls a kernel over a grid and concatenates results to a single array."""
|
||||
if transpose_grid:
|
||||
grid = (grid[1], grid[0])
|
||||
m, n = grid
|
||||
return jnp.concatenate([
|
||||
jnp.concatenate([
|
||||
kernel((i, j), *args) for j in range(n)], axis=1)
|
||||
for i in range(m)], axis=0)
|
||||
samples = jnp.concatenate([
|
||||
jnp.concatenate([
|
||||
kernel((i, j), key, total_size, block_size, tile_size)
|
||||
for j in range(n)], axis=1)
|
||||
for i in range(m)], axis=0)
|
||||
# Slice out the padding.
|
||||
samples = samples[:total_size[0], :total_size[1]]
|
||||
return samples
|
||||
|
||||
|
||||
def call_kernel_3d(
|
||||
@ -73,10 +80,10 @@ def uniform_kernel(block_index, key, total_size, block_size, tile_size):
|
||||
block_size=block_size,
|
||||
tile_size=tile_size)
|
||||
return blocked_sampler.sample_block(jax.random.uniform,
|
||||
keys,
|
||||
block_size=block_size,
|
||||
tile_size=tile_size,
|
||||
minval=0.0, maxval=1.0)
|
||||
keys,
|
||||
block_size=block_size,
|
||||
tile_size=tile_size,
|
||||
minval=0.0, maxval=1.0)
|
||||
|
||||
|
||||
class BlockedSamplerTest(jtu.JaxTestCase):
|
||||
@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
|
||||
dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
|
||||
block_size_a=(16, 256), block_size_b=(32, 128),
|
||||
tile_size=(8, 128), transpose_grid=False),
|
||||
dict(testcase_name='128x128_vs_128x256_padding',
|
||||
total_size=(256, 128), block_size_a=(128, 128),
|
||||
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
|
||||
dict(testcase_name='128x128_vs_128x256_padding2',
|
||||
total_size=(257, 129), block_size_a=(128, 128),
|
||||
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
|
||||
)
|
||||
def test_block_shape_invariance(self, total_size, block_size_a,
|
||||
block_size_b, tile_size, transpose_grid):
|
||||
global_key = jax.random.key(0)
|
||||
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
|
||||
ceil_div = lambda x, y: (x + y - 1) // y
|
||||
grid_a = tuple(ceil_div(_tot, _blk)
|
||||
for _tot, _blk in zip(total_size, block_size_a))
|
||||
result_a = call_kernel(
|
||||
uniform_kernel, grid_a, transpose_grid, global_key,
|
||||
total_size, block_size_a, tile_size)
|
||||
|
||||
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
|
||||
grid_b = tuple(ceil_div(_tot, _blk)
|
||||
for _tot, _blk in zip(total_size, block_size_b))
|
||||
result_b = call_kernel(
|
||||
uniform_kernel, grid_b, transpose_grid, global_key,
|
||||
total_size, block_size_b, tile_size)
|
||||
|
Loading…
x
Reference in New Issue
Block a user