From 0cfed4efadc6fd6dd9bd7245225ba3c872fa2fc4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 1 Oct 2024 06:18:54 -0700 Subject: [PATCH] [pallas:mosaic_gpu] Shrink `max_concurrent_steps` based on the total number of steps PiperOrigin-RevId: 680990842 --- jax/_src/pallas/mosaic_gpu/lowering.py | 5 ++++- tests/pallas/mosaic_gpu_test.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 58dbf9c29..b74fd7168 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -279,6 +279,10 @@ def lower_jaxpr_to_module( num_steps = 1 out_sequential_invariant = [True] * len(grid_mapping.out_shapes) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to + # reduce the size of the allocated buffers below. + max_concurrent_steps = min(max_concurrent_steps, num_steps) + in_in_smem, out_in_smem = util.split_list( [ bm.transformed_block_aval.memory_space in (None, gpu_core.SMEM) @@ -291,7 +295,6 @@ def lower_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) in_structs_gmem = [*grid_mapping.in_shapes] - # TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps. # We allocate the fully transformed shapes here. All primitives have seen the # inverse transformation stack and will understand how to handle it. in_structs_smem = [ diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 44131d360..0e20e78c7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -97,7 +97,7 @@ class PallasCallTest(PallasTest): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) - @parameterized.product(max_concurrent_steps=[1, 2, 3, 4]) + @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16]) def test_add_one_grid_pipelined(self, max_concurrent_steps): @functools.partial(