[pallas:mosaic_gpu] Shrink max_concurrent_steps based on the total number of steps

PiperOrigin-RevId: 680990842
This commit is contained in:
Sergei Lebedev 2024-10-01 06:18:54 -07:00 committed by jax authors
parent a644e23a4b
commit 0cfed4efad
2 changed files with 5 additions and 2 deletions

View File

@ -279,6 +279,10 @@ def lower_jaxpr_to_module(
num_steps = 1
out_sequential_invariant = [True] * len(grid_mapping.out_shapes)
# Shrink ``max_concurrent_steps`` if the total number of steps is lower to
# reduce the size of the allocated buffers below.
max_concurrent_steps = min(max_concurrent_steps, num_steps)
in_in_smem, out_in_smem = util.split_list(
[
bm.transformed_block_aval.memory_space in (None, gpu_core.SMEM)
@ -291,7 +295,6 @@ def lower_jaxpr_to_module(
block_mappings, [grid_mapping.num_inputs]
)
in_structs_gmem = [*grid_mapping.in_shapes]
# TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps.
# We allocate the fully transformed shapes here. All primitives have seen the
# inverse transformation stack and will understand how to handle it.
in_structs_smem = [

View File

@ -97,7 +97,7 @@ class PallasCallTest(PallasTest):
x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x + 1.0)
@parameterized.product(max_concurrent_steps=[1, 2, 3, 4])
@parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16])
def test_add_one_grid_pipelined(self, max_concurrent_steps):
@functools.partial(