[Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.

The existing lowering path supports only while_loops that can be converted to fori_loop.
That path makes such loops significantly easier to optimize and unroll, but it cannot support a large class of interesting loop formulations.

This patch draws from the Pallas -> Triton while_loop lowering rule to support such loops in Pallas.
Matching against fori_loop is still performed first, so that loops which fit that pattern are lowered via the existing mechanism -- it is likely easier to optimize than a general "while".

PiperOrigin-RevId: 626089180
This commit is contained in:
jax authors 2024-04-18 11:03:01 -07:00
parent 6ca69f3824
commit 9c9e805e82
2 changed files with 274 additions and 8 deletions

View File

@ -1833,20 +1833,15 @@ lowering_rules[lax.scan_p] = _scan_lowering_rule
skip_mlir_conversions.add(lax.scan_p)
def _while_lowering_rule(
def _lower_while_via_fori(
ctx: LoweringRuleContext,
*args,
fori_jaxpr,
cond_nconsts,
cond_jaxpr,
body_nconsts,
body_jaxpr,
):
jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
)
if jaxpr is None:
raise NotImplementedError(err)
_, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
(lb, ub), args = carry[:2], carry[2:]
for_out = _lower_jaxpr_to_for_loop(
@ -1854,7 +1849,7 @@ def _while_lowering_rule(
block_shapes=ctx.block_shapes[: body_nconsts + 1]
+ ctx.block_shapes[body_nconsts + 2 :],
),
jaxpr,
fori_jaxpr,
lb,
arith.subi(ub, lb),
body_consts,
@ -1865,6 +1860,84 @@ def _while_lowering_rule(
return [ub, ub, *for_out]
def _while_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    cond_nconsts,
    cond_jaxpr,
    body_nconsts,
    body_jaxpr,
):
  """Lowers `lax.while_p`, preferring a fori-style loop over `scf.while`.

  The cond/body pair is first pattern-matched against a fori_loop. When the
  match succeeds, lowering is delegated to `_lower_while_via_fori`. Otherwise
  an `scf.WhileOp` is built whose loop-carried values are, in order, the cond
  constants, the body constants and the carry.

  Args:
    ctx: Lowering rule context; `ctx.block_shapes` is partitioned the same way
      as `args`.
    *args: Cond constants, then body constants, then the initial carry.
    cond_nconsts: Number of leading `args` that are cond constants.
    cond_jaxpr: Jaxpr of the loop predicate; it must yield exactly one value.
    body_nconsts: Number of `args` following the cond constants that are body
      constants.
    body_jaxpr: Jaxpr of the loop body.

  Returns:
    On the fori path, whatever `_lower_while_via_fori` returns. On the general
    path, the `scf.while` results with the leading
    `cond_nconsts + body_nconsts` constant slots removed.
  """
  # First try to lower via a simpler fori loop, which may optimize better.
  # The second element of the match result is not needed here.
  fori_jaxpr, _ = pallas_utils.pattern_match_while_to_fori_loop(
      cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
  )
  if fori_jaxpr is not None:
    return _lower_while_via_fori(
        ctx,
        *args,
        fori_jaxpr=fori_jaxpr,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr,
    )

  # If we fail conversion to fori, fallback to an ordinary while loop.
  split_sizes = [cond_nconsts, body_nconsts]
  num_consts = cond_nconsts + body_nconsts

  cond_shapes, body_shapes, carry_shapes = split_list(
      ctx.block_shapes, split_sizes
  )
  cond_block_shapes = [*cond_shapes, *carry_shapes]
  body_block_shapes = [*body_shapes, *carry_shapes]

  # Every operand (constants included) is threaded through both regions, so
  # the operand, block-argument and result types are all the same list.
  loop_types = [operand.type for operand in args]
  while_op = scf.WhileOp(loop_types, args)

  # "before" region: evaluate the predicate on the cond constants + carry and
  # forward all block arguments unchanged.
  before_block = while_op.before.blocks.append(*loop_types)
  before_cond_consts, _, before_carry = split_list(
      before_block.arguments, split_sizes
  )
  with ir.InsertionPoint.at_block_begin(before_block):
    [pred] = jaxpr_subcomp(
        ctx.lowering_context.replace(block_shapes=cond_block_shapes),
        cond_jaxpr.jaxpr,
        *before_cond_consts,
        *before_carry,
    )
    scf.condition(pred, before_block.arguments)

  # "after" region: run the body on the body constants + carry, then yield
  # the untouched constants followed by the updated carry.
  after_block = while_op.after.blocks.append(*loop_types)
  after_cond_consts, after_body_consts, after_carry = split_list(
      after_block.arguments, split_sizes
  )
  with ir.InsertionPoint.at_block_begin(after_block):
    new_carry = jaxpr_subcomp(
        ctx.lowering_context.replace(block_shapes=body_block_shapes),
        body_jaxpr.jaxpr,
        *after_body_consts,
        *after_carry,
    )
    yielded = [*after_cond_consts, *after_body_consts, *new_carry]
    if yielded:
      scf.yield_(yielded)

  # Callers only expect the carry; strip the constant slots.
  return list(while_op.results_)[num_consts:]


lowering_rules[lax.while_p] = _while_lowering_rule
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):

View File

@ -1751,6 +1751,199 @@ class PallasCallWhileLoopTest(PallasTPUTest):
)(*(jnp.array([[x]]) for x in (2, 6)))
np.testing.assert_array_equal(r, 4)
def test_non_range_while_loop(self):
  """Tests lowering of a while_loop which cannot reduce to a fori_loop."""

  def kernel(x_ref, r_ref):
    # Zero the accumulator once, on the first grid step only.
    @pl.when(pl.program_id(0) == 0)
    def _():
      pl.store(r_ref, (0, 0), 0)

    def cond(state):
      # Two-part predicate (index bound AND data-dependent bound), so the
      # loop cannot be expressed as a fixed-trip-count fori_loop.
      i, s = state
      return jnp.logical_and(i < 1024, s < 1024)

    def body(state):
      i, s = state
      # Map the flat index onto the (8, 128) trailing dims of the block.
      sl = jax.lax.div(i, 128)
      l = jax.lax.rem(i, 128)
      v = pl.load(x_ref, (0, sl, l))
      return i + 1, s + v

    i = jnp.int32(0)
    s = pl.load(r_ref, (0, 0))
    # Only the accumulated sum is needed after the loop.
    _, s = jax.lax.while_loop(cond, body, (i, s))
    pl.store(r_ref, (0, 0), s)

  x = jnp.arange(4096)
  x = jnp.reshape(x, [4, 8, 128])

  r = pl.pallas_call(
      kernel,
      grid=(4,),
      out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
      out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
      in_specs=[
          pl.BlockSpec(
              lambda i: (i, 0, 0),
              block_shape=(1, 8, 128),
              memory_space=pltpu.SMEM,
          )
      ],
  )(x)
  # 0 + 1 + ... + 45 = 1035 is the first partial sum that reaches 1024; once
  # s >= 1024 the predicate is false, so no further values are added.
  np.testing.assert_array_equal(r, [[1035]])
def test_vector_carry_while_loop(self):
  """Checks lowering of a while_loop whose carry is a vector quantity."""

  def kernel(x_ref, r_ref):
    def keep_going(vec):
      # The predicate only inspects a single element of the carried vector.
      return vec[0, 0] < 16

    def double(vec):
      return vec * 2

    r_ref[:] = jax.lax.while_loop(keep_going, double, x_ref[:])

  inputs = jnp.full((8, 128), 3, dtype=jnp.int32)
  result = pl.pallas_call(
      kernel,
      grid=(1,),
      in_specs=[pl.BlockSpec(lambda i: (0, 0), (8, 128))],
      out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)),
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
  )(inputs)

  # Each element goes 3 -> 6 -> 12 -> 24, and there are 8 * 128 = 1024 of them.
  np.testing.assert_array_equal(jnp.sum(result), 1024 * 24)
@parameterized.named_parameters(
    ('1x128', (1, 128)),
    ('2x128', (2, 128)),
    ('4x128', (4, 128)),
    ('8x128', (8, 128)),
    ('8x256', (8, 256)),
)
def test_while_loop_carry_memref(self, shape):
  """Tests a while loop whose body stores into an output ref (memref)."""

  # TODO(hmckenzie): Investigate further why this occurs.
  if shape == (1, 128):
    self.skipTest('memref<1x128> inexplicably doubles to 2x128.')

  def kernel(out_ref, bound):
    def cond(i):
      return i < bound

    def body(i):
      out_ref[0, i] = 2
      return i + 1

    jax.lax.while_loop(cond, body, 0)

  # The loop runs over the full second dimension of the output block. The
  # kernel takes no array inputs, so the bound comes straight from `shape`.
  kernel = partial(kernel, bound=shape[1])
  fn = pl.pallas_call(
      kernel,
      grid=(1,),
      out_specs=[
          pl.BlockSpec(
              lambda i: (0, 0), block_shape=shape, memory_space=pltpu.SMEM
          ),
      ],
      out_shape=[
          jax.ShapeDtypeStruct(shape, jnp.int32),
      ],
  )
  y = fn()[0]
  # Spot-check the first four entries of the row written by the loop.
  np.testing.assert_array_equal(y[0, :4], 2)
def test_nested_while_loop(self):
  """Tests lowering a nested while_loop."""

  def kernel(in_key_ref, out_segment_count, out_size_ref, key_count):
    # Compute the length of contiguous segments of keys.

    def inner_cond(carry):
      i, prev_key = carry
      sl = jax.lax.div(i, 128)
      l = jax.lax.rem(i, 128)
      # Guard the load so that no key is read once i has reached key_count.
      key = jax.lax.cond(
          i < key_count, lambda i: in_key_ref[sl, l], lambda i: -1, i
      )
      return jnp.logical_and(i < key_count, key == prev_key)

    def inner_body(carry):
      # Advance through the current run; the run's key stays unchanged.
      i, key = carry
      return i + 1, key

    def outer_cond(carry):
      i, _ = carry
      return i < key_count

    def outer_body(carry):
      i, next_out_idx = carry
      sl = jax.lax.div(i, 128)
      l = jax.lax.rem(i, 128)
      key = in_key_ref[sl, l]
      # `end` is the first index past the run of `key` that starts at i.
      end, _ = jax.lax.while_loop(inner_cond, inner_body, (i + 1, key))

      sl = jax.lax.div(next_out_idx, 128)
      l = jax.lax.rem(next_out_idx, 128)
      out_size_ref[sl, l] = end - i
      return end, next_out_idx + 1

    _, count = jax.lax.while_loop(outer_cond, outer_body, (0, 0))
    out_segment_count[0, 0] = count

  keys = [4, 4, 4, 3, 2, 2, 7, 7, 7, 7]
  keys = jnp.asarray(keys)
  real_keys = keys.shape[0]
  key_count = 1024
  # The padding value differs from every real key, so the padding forms one
  # final segment of its own.
  keys = jnp.pad(keys, (0, key_count - real_keys), constant_values=32768)
  keys = jnp.reshape(keys, (8, 128))
  kernel_fn = partial(kernel, key_count=key_count)

  fn = pl.pallas_call(
      kernel_fn,
      grid=(1,),
      in_specs=[
          # keys.
          pl.BlockSpec(
              lambda i: (0, 0),
              block_shape=(8, 128),
              memory_space=pltpu.SMEM,
          ),
      ],
      out_specs=[
          # Segments found.
          pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
          # Segment sizes.
          pl.BlockSpec(block_shape=(8, 128), memory_space=pltpu.SMEM),
      ],
      out_shape=[
          jax.ShapeDtypeStruct((1, 1), jnp.int32),
          jax.ShapeDtypeStruct((8, 128), jnp.int32),
      ],
  )
  count, sizes = fn(keys)
  # Runs: 4 4 4 | 3 | 2 2 | 7 7 7 7 | padding.
  np.testing.assert_equal(count[0, 0], jnp.asarray(5))
  np.testing.assert_equal(sizes[0, 0], jnp.asarray(3))
  np.testing.assert_equal(sizes[0, 1], jnp.asarray(1))
  np.testing.assert_equal(sizes[0, 2], jnp.asarray(2))
  np.testing.assert_equal(sizes[0, 3], jnp.asarray(4))
  np.testing.assert_equal(sizes[0, 4], jnp.asarray(key_count - real_keys))
class PallasCallPipelineTest(parameterized.TestCase):