mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Make blocked_fold_in consistent when the block sizes induce padding
Add coverage for padded shapes to unit tests. PiperOrigin-RevId: 738029476
This commit is contained in:
parent
1e36cbe597
commit
13541e9f12
@ -29,8 +29,8 @@ class SampleFn(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
def _compute_tile_index(block_index: Sequence[int],
|
def _compute_tile_index(block_index: Sequence[int],
|
||||||
total_size_in_blocks: Shape,
|
|
||||||
block_size_in_tiles: Shape,
|
block_size_in_tiles: Shape,
|
||||||
|
total_size_in_tiles: Shape,
|
||||||
tile_index_in_block: Sequence[int]) -> int:
|
tile_index_in_block: Sequence[int]) -> int:
|
||||||
ndims = len(block_index)
|
ndims = len(block_index)
|
||||||
dim_size = 1
|
dim_size = 1
|
||||||
@ -38,7 +38,7 @@ def _compute_tile_index(block_index: Sequence[int],
|
|||||||
for i in range(ndims-1, -1, -1):
|
for i in range(ndims-1, -1, -1):
|
||||||
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
|
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
|
||||||
total_idx += dim_idx * dim_size
|
total_idx += dim_idx * dim_size
|
||||||
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
|
dim_size *= total_size_in_tiles[i]
|
||||||
return total_idx
|
return total_idx
|
||||||
|
|
||||||
|
|
||||||
@ -103,15 +103,17 @@ def blocked_fold_in(
|
|||||||
_shape // _element for _shape, _element in zip(block_size, tile_size)
|
_shape // _element for _shape, _element in zip(block_size, tile_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
total_size_in_blocks = tuple(
|
# Round up to make sure every tile is numbered.
|
||||||
_shape // _element for _shape, _element in zip(total_size, block_size)
|
total_size_in_tiles = tuple(
|
||||||
|
(_shape + _element - 1) // _element
|
||||||
|
for _shape, _element in zip(total_size, tile_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _keygen_loop(axis, prefix):
|
def _keygen_loop(axis, prefix):
|
||||||
if axis == len(block_size_in_tiles):
|
if axis == len(block_size_in_tiles):
|
||||||
subtile_key = jax.random.fold_in(
|
subtile_key = jax.random.fold_in(
|
||||||
global_key, _compute_tile_index(
|
global_key, _compute_tile_index(
|
||||||
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
|
block_index, block_size_in_tiles, total_size_in_tiles, prefix))
|
||||||
return subtile_key
|
return subtile_key
|
||||||
else:
|
else:
|
||||||
keys = []
|
keys = []
|
||||||
|
@ -29,16 +29,23 @@ def call_kernel(
|
|||||||
kernel,
|
kernel,
|
||||||
grid: tuple[int, int],
|
grid: tuple[int, int],
|
||||||
transpose_grid: bool,
|
transpose_grid: bool,
|
||||||
*args
|
key: jax.Array,
|
||||||
|
total_size: tuple[int, int],
|
||||||
|
block_size: tuple[int, int],
|
||||||
|
tile_size: tuple[int, int],
|
||||||
):
|
):
|
||||||
"""Calls a kernel over a grid and concatenates results to a single array."""
|
"""Calls a kernel over a grid and concatenates results to a single array."""
|
||||||
if transpose_grid:
|
if transpose_grid:
|
||||||
grid = (grid[1], grid[0])
|
grid = (grid[1], grid[0])
|
||||||
m, n = grid
|
m, n = grid
|
||||||
return jnp.concatenate([
|
samples = jnp.concatenate([
|
||||||
jnp.concatenate([
|
jnp.concatenate([
|
||||||
kernel((i, j), *args) for j in range(n)], axis=1)
|
kernel((i, j), key, total_size, block_size, tile_size)
|
||||||
|
for j in range(n)], axis=1)
|
||||||
for i in range(m)], axis=0)
|
for i in range(m)], axis=0)
|
||||||
|
# Slice out the padding.
|
||||||
|
samples = samples[:total_size[0], :total_size[1]]
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def call_kernel_3d(
|
def call_kernel_3d(
|
||||||
@ -94,16 +101,25 @@ class BlockedSamplerTest(jtu.JaxTestCase):
|
|||||||
dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
|
dict(testcase_name='16x256_vs_32x128', total_size=(32, 256),
|
||||||
block_size_a=(16, 256), block_size_b=(32, 128),
|
block_size_a=(16, 256), block_size_b=(32, 128),
|
||||||
tile_size=(8, 128), transpose_grid=False),
|
tile_size=(8, 128), transpose_grid=False),
|
||||||
|
dict(testcase_name='128x128_vs_128x256_padding',
|
||||||
|
total_size=(256, 128), block_size_a=(128, 128),
|
||||||
|
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
|
||||||
|
dict(testcase_name='128x128_vs_128x256_padding2',
|
||||||
|
total_size=(257, 129), block_size_a=(128, 128),
|
||||||
|
block_size_b=(128, 256), tile_size=(128, 128), transpose_grid=False),
|
||||||
)
|
)
|
||||||
def test_block_shape_invariance(self, total_size, block_size_a,
|
def test_block_shape_invariance(self, total_size, block_size_a,
|
||||||
block_size_b, tile_size, transpose_grid):
|
block_size_b, tile_size, transpose_grid):
|
||||||
global_key = jax.random.key(0)
|
global_key = jax.random.key(0)
|
||||||
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
|
ceil_div = lambda x, y: (x + y - 1) // y
|
||||||
|
grid_a = tuple(ceil_div(_tot, _blk)
|
||||||
|
for _tot, _blk in zip(total_size, block_size_a))
|
||||||
result_a = call_kernel(
|
result_a = call_kernel(
|
||||||
uniform_kernel, grid_a, transpose_grid, global_key,
|
uniform_kernel, grid_a, transpose_grid, global_key,
|
||||||
total_size, block_size_a, tile_size)
|
total_size, block_size_a, tile_size)
|
||||||
|
|
||||||
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
|
grid_b = tuple(ceil_div(_tot, _blk)
|
||||||
|
for _tot, _blk in zip(total_size, block_size_b))
|
||||||
result_b = call_kernel(
|
result_b = call_kernel(
|
||||||
uniform_kernel, grid_b, transpose_grid, global_key,
|
uniform_kernel, grid_b, transpose_grid, global_key,
|
||||||
total_size, block_size_b, tile_size)
|
total_size, block_size_b, tile_size)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user