Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
[pallas:mosaic_gpu] Shrink max_concurrent_steps based on the total number of steps
PiperOrigin-RevId: 680990842
This commit is contained in:
parent a644e23a4b
commit 0cfed4efad
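The change is small: before sizing the pipelined SMEM buffers, the lowering now clamps the requested pipeline depth to the number of grid steps that will actually run, so small grids no longer reserve buffer stages that could never be in flight. A minimal standalone sketch of that behaviour (clamp_concurrent_steps is an illustration written for this note, not a JAX function; the real lowering derives num_steps from the sequential grid axes rather than the whole grid):

import math

def clamp_concurrent_steps(max_concurrent_steps: int, grid: tuple[int, ...]) -> int:
    # Total number of pipeline steps is the product of the (sequential) grid
    # dimensions, or 1 when there is nothing to iterate over.
    num_steps = math.prod(grid) if grid else 1
    # The commit's core change: never pipeline deeper than there are steps.
    return min(max_concurrent_steps, num_steps)

# A 2-step grid asking for 16 concurrent steps now only provisions 2 stages.
assert clamp_concurrent_steps(16, (2,)) == 2
# When the grid is deep enough, the requested depth is left unchanged.
assert clamp_concurrent_steps(4, (8, 8)) == 4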
@@ -279,6 +279,10 @@ def lower_jaxpr_to_module(
     num_steps = 1
     out_sequential_invariant = [True] * len(grid_mapping.out_shapes)

+  # Shrink ``max_concurrent_steps`` if the total number of steps is lower to
+  # reduce the size of the allocated buffers below.
+  max_concurrent_steps = min(max_concurrent_steps, num_steps)
+
   in_in_smem, out_in_smem = util.split_list(
       [
           bm.transformed_block_aval.memory_space in (None, gpu_core.SMEM)
@@ -291,7 +295,6 @@ def lower_jaxpr_to_module(
       block_mappings, [grid_mapping.num_inputs]
   )
   in_structs_gmem = [*grid_mapping.in_shapes]
-  # TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps.
   # We allocate the fully transformed shapes here. All primitives have seen the
   # inverse transformation stack and will understand how to handle it.
   in_structs_smem = [
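The TODO removed above is exactly what this change implements: the in_structs_smem allocation that begins on the last line of this hunk is sized per concurrent step, so clamping max_concurrent_steps shrinks those buffers directly. A hedged sketch of that sizing (the leading max_concurrent_steps dimension is an assumption about the pipelined layout, and smem_struct_for is a hypothetical helper, not JAX API):

import jax
import jax.numpy as jnp

def smem_struct_for(block_shape, dtype, max_concurrent_steps):
    # Assumed layout: each pipelined input keeps one SMEM window per concurrent
    # step, i.e. a staging buffer of shape (max_concurrent_steps, *block_shape).
    return jax.ShapeDtypeStruct((max_concurrent_steps, *block_shape), dtype)

block = (128, 64)
# Without the clamp, a 2-step grid with max_concurrent_steps=16 would stage
# 16 * 128 * 64 * 4 bytes = 512 KiB per input buffer.
unclamped = smem_struct_for(block, jnp.float32, 16)
# With the clamp, min(16, num_steps=2) stages only 2 * 128 * 64 * 4 bytes = 64 KiB.
clamped = smem_struct_for(block, jnp.float32, min(16, 2))
print(unclamped.shape, clamped.shape)  # (16, 128, 64) (2, 128, 64)

512 KiB would not even fit in a single SM's shared memory on current GPUs, which is presumably why the new max_concurrent_steps=16 test case below only works with the clamp in place.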
@@ -97,7 +97,7 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(256).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x), x + 1.0)

-  @parameterized.product(max_concurrent_steps=[1, 2, 3, 4])
+  @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16])
   def test_add_one_grid_pipelined(self, max_concurrent_steps):

     @functools.partial(
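For context, the test updated above pipelines an add-one kernel over a small grid and now also requests more concurrent steps (16) than the grid has; with the clamp the request is reduced instead of over-allocating SMEM. A rough sketch of such a test body, assuming max_concurrent_steps is plumbed through plgpu.GPUCompilerParams (the block shapes, grid, and dimension semantics here are illustrative, not copied from the actual test, and running it requires a Mosaic-GPU-capable device):

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def make_add_one(max_concurrent_steps):
  # Two sequential grid steps, so any requested depth above 2 gets clamped.
  @functools.partial(
      pl.pallas_call,
      grid=(2,),
      in_specs=[pl.BlockSpec((128,), lambda i: (i,))],
      out_specs=pl.BlockSpec((128,), lambda i: (i,)),
      out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
      compiler_params=plgpu.GPUCompilerParams(
          dimension_semantics=["sequential"],
          max_concurrent_steps=max_concurrent_steps,
      ),
  )
  def add_one(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

  return add_one

x = jnp.arange(256).astype(jnp.float32)
np.testing.assert_array_equal(make_add_one(max_concurrent_steps=16)(x), x + 1.0)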