mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Compute tile index using tile-based coordinates
This reduces the chance of overflowing a 32-bit integer when computing tile indices. It also adds a unit test that reproduces the overflow in the previous implementation of `blocked_fold_in`.
PiperOrigin-RevId: 737778853
This commit is contained in:
parent
b74b16f9b9
commit
b4966130a3
@ -28,17 +28,17 @@ class SampleFn(Protocol):
|
||||
...
|
||||
|
||||
|
||||
def _compute_scalar_index(iteration_index: Sequence[int],
|
||||
total_size: Shape,
|
||||
block_size: Shape,
|
||||
block_index: Sequence[int]) -> int:
|
||||
ndims = len(iteration_index)
|
||||
def _compute_tile_index(block_index: Sequence[int],
|
||||
total_size_in_blocks: Shape,
|
||||
block_size_in_tiles: Shape,
|
||||
tile_index_in_block: Sequence[int]) -> int:
|
||||
ndims = len(block_index)
|
||||
dim_size = 1
|
||||
total_idx = 0
|
||||
for i in range(ndims-1, -1, -1):
|
||||
dim_idx = block_index[i] + iteration_index[i] * block_size[i]
|
||||
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
|
||||
total_idx += dim_idx * dim_size
|
||||
dim_size *= total_size[i]
|
||||
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
|
||||
return total_idx
|
||||
|
||||
|
||||
@ -99,18 +99,23 @@ def blocked_fold_in(
|
||||
An N-dimensional nested list of keys required to sample the tiles
|
||||
corresponding to the block specified by `block_index`.
|
||||
"""
|
||||
size_in_blocks = tuple(
|
||||
_shape // _element for _shape, _element in zip(block_size, tile_size))
|
||||
block_size_in_tiles = tuple(
|
||||
_shape // _element for _shape, _element in zip(block_size, tile_size)
|
||||
)
|
||||
|
||||
total_size_in_blocks = tuple(
|
||||
_shape // _element for _shape, _element in zip(total_size, block_size)
|
||||
)
|
||||
|
||||
def _keygen_loop(axis, prefix):
|
||||
if axis == len(size_in_blocks):
|
||||
if axis == len(block_size_in_tiles):
|
||||
subtile_key = jax.random.fold_in(
|
||||
global_key, _compute_scalar_index(
|
||||
block_index, total_size, size_in_blocks, prefix))
|
||||
global_key, _compute_tile_index(
|
||||
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
|
||||
return subtile_key
|
||||
else:
|
||||
keys = []
|
||||
for i in range(size_in_blocks[axis]):
|
||||
for i in range(block_size_in_tiles[axis]):
|
||||
keys.append(_keygen_loop(axis+1, prefix+(i,)))
|
||||
return keys
|
||||
return _keygen_loop(0, tuple())
|
||||
|
@ -37,18 +37,41 @@ def call_kernel(
|
||||
m, n = grid
|
||||
return jnp.concatenate([
|
||||
jnp.concatenate([
|
||||
kernel(i, j, *args) for j in range(n)], axis=1)
|
||||
kernel((i, j), *args) for j in range(n)], axis=1)
|
||||
for i in range(m)], axis=0)
|
||||
|
||||
|
||||
def uniform_kernel(i: int, j: int, total_size, block_size, tile_size):
|
||||
"""Uniform random sampling kernel function."""
|
||||
global_key = jax.random.key(0)
|
||||
keys = blocked_sampler.blocked_fold_in(global_key,
|
||||
def call_kernel_3d(
    kernel,
    grid: tuple[int, int, int],
    *args
):
  """Calls a kernel over a 3D grid and concatenates results to a single array.

  Args:
    kernel: Callable invoked as `kernel((i, j, k), *args)` once per grid
      position. Its result is converted with `jnp.array` and must have at
      least three dimensions, since results are joined along axes 0, 1 and 2.
    grid: Number of kernel invocations along axis 0, axis 1 and axis 2.
    *args: Extra positional arguments forwarded unchanged to every call.

  Returns:
    A single array built by concatenating the per-position results along
    axis 2 (innermost grid index), then axis 1, then axis 0.
  """
  depth, rows, cols = grid
  return jnp.concatenate([
      jnp.concatenate([
          jnp.concatenate([
              jnp.array(kernel((i, j, k), *args))
              for k in range(cols)], axis=2)
          for j in range(rows)], axis=1)
      for i in range(depth)], axis=0)
|
||||
|
||||
|
||||
def blocked_fold_in(block_index, key, total_size, block_size, tile_size):
|
||||
"""Folds in block_index into global_key."""
|
||||
return blocked_sampler.blocked_fold_in(key,
|
||||
total_size=total_size,
|
||||
block_size=block_size,
|
||||
tile_size=tile_size,
|
||||
block_index=(i, j))
|
||||
block_index=block_index)
|
||||
|
||||
|
||||
def uniform_kernel(block_index, key, total_size, block_size, tile_size):
|
||||
"""Uniform random sampling kernel function."""
|
||||
keys = blocked_fold_in(block_index, key,
|
||||
total_size=total_size,
|
||||
block_size=block_size,
|
||||
tile_size=tile_size)
|
||||
return blocked_sampler.sample_block(jax.random.uniform,
|
||||
keys,
|
||||
block_size=block_size,
|
||||
@ -74,17 +97,46 @@ class BlockedSamplerTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_block_shape_invariance(self, total_size, block_size_a,
|
||||
block_size_b, tile_size, transpose_grid):
|
||||
global_key = jax.random.key(0)
|
||||
grid_a = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_a))
|
||||
result_a = call_kernel(
|
||||
uniform_kernel, grid_a, transpose_grid,
|
||||
uniform_kernel, grid_a, transpose_grid, global_key,
|
||||
total_size, block_size_a, tile_size)
|
||||
|
||||
grid_b = tuple(_tot // _blk for _tot, _blk in zip(total_size, block_size_b))
|
||||
result_b = call_kernel(
|
||||
uniform_kernel, grid_b, transpose_grid,
|
||||
uniform_kernel, grid_b, transpose_grid, global_key,
|
||||
total_size, block_size_b, tile_size)
|
||||
np.testing.assert_array_equal(result_a, result_b)
|
||||
|
||||
|
||||
class BlockedFoldInTest(jtu.JaxTestCase):
  """Tests that keys from `blocked_fold_in` are independent of block shape."""

  @parameterized.named_parameters(
      # The tensor here has more than jnp.iinfo(jnp.uint32).max elements.
      # Blocked key folding should depend only on the total number of tiles,
      # not on the total size of the tensor. A 3D grid with very large inner
      # dimensions triggered an overflow in a previous implementation of
      # blocked_fold_in.
      dict(
          testcase_name='4096x512_vs_1024x2048',
          total_size=(2, 64 * 1024, 64 * 1024),
          block_size_a=(1, 4096, 512),
          block_size_b=(1, 1024, 2048),
          tile_size=(1, 1024, 512),
      ),
  )
  def test_blocked_fold_in_shape_invariance(self, total_size, block_size_a,
                                            block_size_b, tile_size):
    """Folds keys under two block shapes and checks the key data matches."""
    global_key = jax.random.key(0)

    def fold_in_over_grid(block_size):
      # One kernel call per block; the grid is the tensor extent in blocks.
      grid = tuple(
          extent // block for extent, block in zip(total_size, block_size))
      return call_kernel_3d(
          blocked_fold_in, grid, global_key, total_size,
          block_size, tile_size)

    keys_a = fold_in_over_grid(block_size_a)
    keys_b = fold_in_over_grid(block_size_b)

    np.testing.assert_array_equal(jax.random.key_data(keys_a),
                                  jax.random.key_data(keys_b))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user