mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.
The existing lowering path supports only while_loops which can be converted to fori_loop. That path makes it significantly easier to optimize and unroll, but cannot support a large class of interesting loop formulations. This patch draws from the Pallas -> Triton while_loop lowering rule to support such loops in Pallas. Matching is still performed against fori_loop, to lower via that mechanism if possible -- as it is likely more straightforwardly optimizable compared to general "while". PiperOrigin-RevId: 626089180
This commit is contained in:
parent
6ca69f3824
commit
9c9e805e82
@ -1833,20 +1833,15 @@ lowering_rules[lax.scan_p] = _scan_lowering_rule
|
||||
skip_mlir_conversions.add(lax.scan_p)
|
||||
|
||||
|
||||
def _while_lowering_rule(
|
||||
def _lower_while_via_fori(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
fori_jaxpr,
|
||||
cond_nconsts,
|
||||
cond_jaxpr,
|
||||
body_nconsts,
|
||||
body_jaxpr,
|
||||
):
|
||||
jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
|
||||
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
|
||||
)
|
||||
if jaxpr is None:
|
||||
raise NotImplementedError(err)
|
||||
|
||||
_, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
|
||||
(lb, ub), args = carry[:2], carry[2:]
|
||||
for_out = _lower_jaxpr_to_for_loop(
|
||||
@ -1854,7 +1849,7 @@ def _while_lowering_rule(
|
||||
block_shapes=ctx.block_shapes[: body_nconsts + 1]
|
||||
+ ctx.block_shapes[body_nconsts + 2 :],
|
||||
),
|
||||
jaxpr,
|
||||
fori_jaxpr,
|
||||
lb,
|
||||
arith.subi(ub, lb),
|
||||
body_consts,
|
||||
@ -1865,6 +1860,84 @@ def _while_lowering_rule(
|
||||
return [ub, ub, *for_out]
|
||||
|
||||
|
||||
def _while_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
cond_nconsts,
|
||||
cond_jaxpr,
|
||||
body_nconsts,
|
||||
body_jaxpr,
|
||||
):
|
||||
# First try to lower via a simpler fori loop, which may optimize better.
|
||||
fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
|
||||
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
|
||||
)
|
||||
if fori_jaxpr is not None:
|
||||
return _lower_while_via_fori(
|
||||
ctx,
|
||||
*args,
|
||||
fori_jaxpr=fori_jaxpr,
|
||||
cond_nconsts=cond_nconsts,
|
||||
cond_jaxpr=cond_jaxpr,
|
||||
body_nconsts=body_nconsts,
|
||||
body_jaxpr=body_jaxpr,
|
||||
)
|
||||
|
||||
# If we fail conversion to fori, fallback to an ordinary while loop.
|
||||
cond_consts, body_consts, carry = split_list(
|
||||
args, [cond_nconsts, body_nconsts]
|
||||
)
|
||||
cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = (
|
||||
split_list(ctx.block_shapes, [cond_nconsts, body_nconsts])
|
||||
)
|
||||
cond_const_types = [a.type for a in cond_consts]
|
||||
body_const_types = [a.type for a in body_consts]
|
||||
carry_types = [a.type for a in carry]
|
||||
all_types = [*cond_const_types, *body_const_types, *carry_types]
|
||||
while_op = scf.WhileOp(all_types, args)
|
||||
|
||||
before_block = while_op.before.blocks.append(*all_types)
|
||||
cond_consts_, _, carry_ = split_list(
|
||||
before_block.arguments,
|
||||
[cond_nconsts, body_nconsts],
|
||||
)
|
||||
cond_args = [*cond_consts_, *carry_]
|
||||
with ir.InsertionPoint.at_block_begin(before_block):
|
||||
[cond] = jaxpr_subcomp(
|
||||
ctx.lowering_context.replace(
|
||||
block_shapes=[*cond_const_block_shapes, *carry_block_shapes]
|
||||
),
|
||||
cond_jaxpr.jaxpr,
|
||||
*cond_args,
|
||||
)
|
||||
scf.condition(cond, before_block.arguments)
|
||||
|
||||
after_block = while_op.after.blocks.append(*all_types)
|
||||
cond_consts_, body_consts_, carry_ = split_list(
|
||||
after_block.arguments,
|
||||
[cond_nconsts, body_nconsts],
|
||||
)
|
||||
all_args = [*cond_consts_, *body_consts_, *carry_]
|
||||
cond_const_args, body_const_args, carry_args = split_list(
|
||||
all_args, [cond_nconsts, body_nconsts]
|
||||
)
|
||||
with ir.InsertionPoint.at_block_begin(after_block):
|
||||
loop_out = jaxpr_subcomp(
|
||||
ctx.lowering_context.replace(
|
||||
block_shapes=[*body_const_block_shapes, *carry_block_shapes],
|
||||
),
|
||||
body_jaxpr.jaxpr,
|
||||
*body_const_args,
|
||||
*carry_args,
|
||||
)
|
||||
all_handles = [*cond_const_args, *body_const_args, *loop_out]
|
||||
if all_handles:
|
||||
scf.yield_(all_handles)
|
||||
|
||||
all_out = list(while_op.results_)
|
||||
return all_out[cond_nconsts + body_nconsts :]
|
||||
|
||||
|
||||
lowering_rules[lax.while_p] = _while_lowering_rule
|
||||
|
||||
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):
|
||||
|
@ -1751,6 +1751,199 @@ class PallasCallWhileLoopTest(PallasTPUTest):
|
||||
)(*(jnp.array([[x]]) for x in (2, 6)))
|
||||
np.testing.assert_array_equal(r, 4)
|
||||
|
||||
def test_non_range_while_loop(self):
|
||||
"""Tests lowering of a while_loop which cannot reduce to a fori_loop."""
|
||||
|
||||
def kernel(x_ref, r_ref):
|
||||
@pl.when(pl.program_id(0) == 0)
|
||||
def _():
|
||||
pl.store(r_ref, (0, 0), 0)
|
||||
|
||||
def cond(state):
|
||||
i, s = state
|
||||
return jnp.logical_and(i < 1024, s < 1024)
|
||||
|
||||
def body(state):
|
||||
i, s = state
|
||||
sl = sl = jax.lax.div(i, 128)
|
||||
l = jax.lax.rem(i, 128)
|
||||
v = pl.load(x_ref, (0, sl, l))
|
||||
return i + 1, s + v
|
||||
|
||||
i = jnp.int32(0)
|
||||
s = pl.load(r_ref, (0, 0))
|
||||
|
||||
i, s = jax.lax.while_loop(cond, body, (i, s))
|
||||
pl.store(r_ref, (0, 0), s)
|
||||
|
||||
x = jnp.arange(4096)
|
||||
x = jnp.reshape(x, [4, 8, 128])
|
||||
|
||||
r = pl.pallas_call(
|
||||
kernel,
|
||||
grid=(4,),
|
||||
out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
|
||||
out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
|
||||
in_specs=[
|
||||
pl.BlockSpec(
|
||||
lambda i: (i, 0, 0),
|
||||
block_shape=(1, 8, 128),
|
||||
memory_space=pltpu.SMEM,
|
||||
)
|
||||
],
|
||||
)(x)
|
||||
np.testing.assert_array_equal(r, [[1035]])
|
||||
|
||||
def test_vector_carry_while_loop(self):
|
||||
"""Tests lowering of a while_loop which carries a vector quantity."""
|
||||
|
||||
def kernel(x_ref, r_ref):
|
||||
|
||||
def cond(v):
|
||||
return v[0, 0] < 16
|
||||
|
||||
def body(v):
|
||||
return v * 2
|
||||
|
||||
r_ref[:] = jax.lax.while_loop(cond, body, x_ref[:])
|
||||
|
||||
x = jnp.full((8, 128), 3, dtype=jnp.int32)
|
||||
fn = pl.pallas_call(
|
||||
kernel,
|
||||
grid=(1,),
|
||||
in_specs=[pl.BlockSpec(lambda i: (0, 0), (8, 128))],
|
||||
out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)),
|
||||
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
|
||||
)
|
||||
r = fn(x)
|
||||
reduced = jnp.sum(r)
|
||||
# 3 -> 6 -> 12 -> 24
|
||||
np.testing.assert_array_equal(reduced, 1024 * 24)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('1x128', (1, 128)),
|
||||
('2x128', (2, 128)),
|
||||
('4x128', (4, 128)),
|
||||
('8x128', (8, 128)),
|
||||
('8x256', (8, 256)),
|
||||
)
|
||||
def test_while_loop_carry_memref(self, shape):
|
||||
"""Tests a while loop carrying a memref."""
|
||||
|
||||
# TODO(hmckenzie): Investigate further why this occurs.
|
||||
if shape == (1, 128):
|
||||
self.skipTest('memref<1x128> inexplicably doubles to 2x128.')
|
||||
|
||||
def kernel(out_ref, bound):
|
||||
def cond(i):
|
||||
return i < bound
|
||||
|
||||
def body(i):
|
||||
out_ref[0, i] = 2
|
||||
return i + 1
|
||||
|
||||
jax.lax.while_loop(cond, body, 0)
|
||||
|
||||
x = jnp.asarray([1, 1, 1, 1])
|
||||
x = jnp.asarray(x)
|
||||
x = jnp.pad(x, (0, np.prod(shape) - 4), constant_values=0)
|
||||
x = jnp.reshape(x, shape)
|
||||
kernel = partial(kernel, bound=x.shape[1])
|
||||
|
||||
fn = pl.pallas_call(
|
||||
kernel,
|
||||
grid=(1,),
|
||||
out_specs=[
|
||||
pl.BlockSpec(
|
||||
lambda i: (0, 0), block_shape=shape, memory_space=pltpu.SMEM
|
||||
),
|
||||
],
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct(shape, jnp.int32),
|
||||
],
|
||||
)
|
||||
y = fn()[0]
|
||||
np.testing.assert_array_equal(y[0, 0], 2)
|
||||
np.testing.assert_array_equal(y[0, 1], 2)
|
||||
np.testing.assert_array_equal(y[0, 2], 2)
|
||||
np.testing.assert_array_equal(y[0, 3], 2)
|
||||
|
||||
def test_nested_while_loop(self):
|
||||
"""Tests lowering a nested while_loop."""
|
||||
|
||||
def kernel(in_key_ref, out_segment_count, out_size_ref, key_count):
|
||||
# Compute the length of contiguous segments of keys.
|
||||
|
||||
def inner_cond(carry):
|
||||
i, prev_key = carry
|
||||
sl = sl = jax.lax.div(i, 128)
|
||||
l = jax.lax.rem(i, 128)
|
||||
key = jax.lax.cond(
|
||||
i < key_count, lambda i: in_key_ref[sl, l], lambda i: -1, i
|
||||
)
|
||||
return jnp.logical_and(i < key_count, key == prev_key)
|
||||
|
||||
def inner_body(carry):
|
||||
i, key = carry
|
||||
return i + 1, key
|
||||
|
||||
def outer_cond(carry):
|
||||
i, _ = carry
|
||||
return i < key_count
|
||||
|
||||
def outer_body(carry):
|
||||
i, next_out_idx = carry
|
||||
sl = sl = jax.lax.div(i, 128)
|
||||
l = jax.lax.rem(i, 128)
|
||||
key = in_key_ref[sl, l]
|
||||
end, _ = jax.lax.while_loop(inner_cond, inner_body, (i + 1, key))
|
||||
|
||||
sl = sl = jax.lax.div(next_out_idx, 128)
|
||||
l = jax.lax.rem(next_out_idx, 128)
|
||||
out_size_ref[sl, l] = end - i
|
||||
return end, next_out_idx + 1
|
||||
|
||||
_, count = jax.lax.while_loop(outer_cond, outer_body, (0, 0))
|
||||
out_segment_count[0, 0] = count
|
||||
|
||||
keys = [4, 4, 4, 3, 2, 2, 7, 7, 7, 7]
|
||||
keys = jnp.asarray(keys)
|
||||
real_keys = keys.shape[0]
|
||||
key_count = 1024
|
||||
keys = jnp.pad(keys, (0, key_count - real_keys), constant_values=32768)
|
||||
keys = jnp.reshape(keys, (8, 128))
|
||||
kernel_fn = partial(kernel, key_count=key_count)
|
||||
|
||||
fn = pl.pallas_call(
|
||||
kernel_fn,
|
||||
grid=(1,),
|
||||
in_specs=[
|
||||
# keys.
|
||||
pl.BlockSpec(
|
||||
lambda i: (0, 0),
|
||||
block_shape=(8, 128),
|
||||
memory_space=pltpu.SMEM,
|
||||
),
|
||||
],
|
||||
out_specs=[
|
||||
# Segments found.
|
||||
pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
|
||||
# Segment sizes.
|
||||
pl.BlockSpec(block_shape=(8, 128), memory_space=pltpu.SMEM),
|
||||
],
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct((1, 1), jnp.int32),
|
||||
jax.ShapeDtypeStruct((8, 128), jnp.int32),
|
||||
],
|
||||
)
|
||||
count, sizes = fn(keys)
|
||||
np.testing.assert_equal(count[0, 0], jnp.asarray(5))
|
||||
np.testing.assert_equal(sizes[0, 0], jnp.asarray(3))
|
||||
np.testing.assert_equal(sizes[0, 1], jnp.asarray(1))
|
||||
np.testing.assert_equal(sizes[0, 2], jnp.asarray(2))
|
||||
np.testing.assert_equal(sizes[0, 3], jnp.asarray(4))
|
||||
np.testing.assert_equal(sizes[0, 4], jnp.asarray(key_count - real_keys))
|
||||
|
||||
|
||||
class PallasCallPipelineTest(parameterized.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user