[Pallas] Add interpret mode support for dynamic grid

commit a7a6b40b55 (parent c4b6266049)
Author: Sharad Vikram, committed by jax authors
Date: 2024-02-02 16:37:16 -08:00
PiperOrigin-RevId: 603818776

3 changed files with 139 additions and 111 deletions
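With this change, interpret mode accepts a grid whose bounds are traced values rather than Python ints. A minimal sketch of the newly supported pattern, adapted from the tests added below (pl.BlockSpec takes the index map first, per the API at the time of this commit):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

shape = (8, 128)

def kernel(y_ref):
  # Zero the output on the first grid step, then accumulate once per step.
  @pl.when(pl.program_id(0) == 0)
  def _init():
    y_ref[...] = jnp.zeros_like(y_ref)
  y_ref[...] += 1

@jax.jit
def dynamic_kernel(steps):
  return pl.pallas_call(
      kernel,
      grid=(steps * 2,),  # `steps` is traced under jit, so the grid is dynamic
      out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
      out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
      interpret=True,  # previously raised NotImplementedError for dynamic grids
  )()

# Eight grid steps (4 * 2), so every element ends up at 8.
np.testing.assert_array_equal(
    dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32))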

File: jax/_src/pallas/mosaic/pallas_call_registration.py

@@ -45,10 +45,6 @@ def pallas_call_tpu_lowering_rule(
     **compiler_params: Any):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
   if interpret:
-    if grid_mapping.num_dynamic_grid_bounds:
-      raise NotImplementedError(
-          "Dynamic grid bounds not supported in interpret mode."
-      )
     return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
         ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
         in_shapes=in_shapes,
File: jax/_src/pallas/pallas_call.py

@@ -16,7 +16,7 @@
 from __future__ import annotations
 
 from functools import partial
-import itertools as it
+from functools import reduce
 from typing import Any, Callable
 from collections.abc import Sequence
@@ -89,24 +89,37 @@ def _uninitialized_value(shape, dtype):
     return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
   raise NotImplementedError(dtype)
 
+def _get_next_indices(grid, indices):
+  next_indices = []
+  carry = True
+  for dim_size, index in reversed(list(zip(grid, indices))):
+    i = jnp.where(carry, index + 1, index)
+    carry = dim_size == i
+    next_indices.append(jnp.where(carry, 0, i))
+  return tuple(reversed(next_indices))
+
 def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
                       interpret, debug: bool,
                       in_shapes,
                       input_output_aliases: tuple[tuple[int, int], ...],
                       grid_mapping: GridMapping,
                       **compiler_params: Any):
-  if grid_mapping.num_dynamic_grid_bounds:
-    raise NotImplementedError("interpret with dynamic grid bounds unsupported")
+  dynamic_grid_args, args = split_list(  # type: ignore
+      args, [grid_mapping.num_dynamic_grid_bounds]
+  )
   if interpret:
     # If we're in interpreter mode, we *scan* over the grid and eval the
     # discharged jaxpr. This should reproduce exactly what compiling to Triton
     # will do.
-    grid = grid_mapping.static_grid
+    dynamic_grid_args_iter = iter(dynamic_grid_args)
+    grid = tuple(
+        a if a is not None else next(dynamic_grid_args_iter)
+        for a in grid_mapping.grid
+    )
+    assert next(dynamic_grid_args_iter, None) is None
     discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
     if debug:
       print(discharged_jaxpr)
-    loop_indices = jnp.array(list(it.product(*(range(g) for g in grid))),
-                             dtype=jnp.int32)
     oi_map = {v: k for k, v in input_output_aliases}
     out = []
     for i, out_shape in enumerate(out_shapes):
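The _get_next_indices helper added above replaces the precomputed it.product index table, which required the grid size to be known at trace time. It advances the grid indices one step in row-major ("odometer") order using only jnp.where, so it stays traceable inside lax.while_loop even when the bounds are traced. A standalone sanity check of the stepping behavior, using a verbatim copy of the helper:

import jax.numpy as jnp

def _get_next_indices(grid, indices):  # copied from the hunk above
  next_indices = []
  carry = True
  for dim_size, index in reversed(list(zip(grid, indices))):
    i = jnp.where(carry, index + 1, index)
    carry = dim_size == i
    next_indices.append(jnp.where(carry, 0, i))
  return tuple(reversed(next_indices))

# For grid (2, 3) the indices cycle in row-major order:
# (0, 0) -> (0, 1) -> (0, 2) -> (1, 0) -> (1, 1) -> (1, 2) -> (0, 0)
idx = (jnp.int32(0), jnp.int32(2))
print(_get_next_indices((2, 3), idx))  # values (1, 0)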
@@ -135,12 +148,18 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
                       for a in scratch_avals]
     carry = [*args, *out, *scratch_values]
     num_carry = len(args) + len(out)
+    grid_start_indices = (jnp.int32(0),) * len(grid)
+    if grid:
+      num_iterations = reduce(jnp.multiply, grid)
+    else:
+      # Base case is always one iteration when grid is ()
+      num_iterations = 1
     def cond(carry):
-      return carry[0] < loop_indices.shape[0]
+      i, *_ = carry
+      return i < num_iterations
     def body(carry):
-      i, *carry = carry
+      i, loop_idx, *carry = carry
       carry, scratch = split_list(carry, [num_carry])
-      loop_idx = loop_indices[i]
       start_indices = [
           None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
           for bm in grid_mapping.block_mappings]
@@ -177,8 +196,10 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
       blocks, out_scratch = split_list(blocks, [num_carry])
       carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                   carry, blocks, is_indexing_dim)
-      return (i + 1, *carry, *out_scratch)
-    (_, *carry) = lax.while_loop(cond, body, (jnp.int32(0), *carry))
+      return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
+    (_, _, *carry) = lax.while_loop(
+        cond, body, (jnp.int32(0), grid_start_indices, *carry)
+    )
     _, out, _ = split_list(carry, [len(args), len(out)])
     return out
   return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
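Net effect on the interpreter loop: the carry changes from (i, *carry) to (i, loop_idx, *carry), i.e. a flat counter bounded by num_iterations = reduce(jnp.multiply, grid) plus the odometer indices, so nothing in the loop depends on the grid size at trace time. A runnable toy that mirrors just this loop skeleton, not the actual implementation; it assumes the _get_next_indices copy from the sketch above and fixes a 2-D grid (the real code also handles grid == () as a single iteration):

from functools import reduce
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def count_grid_steps(m, n):
  # Grid bounds are traced scalars, as with a dynamic pallas_call grid.
  grid = (m, n)
  num_iterations = reduce(jnp.multiply, grid)

  def cond(carry):
    i, _ = carry
    return i < num_iterations

  def body(carry):
    i, loop_idx = carry
    # A real interpreter would evaluate one grid step at loop_idx here.
    return i + 1, _get_next_indices(grid, loop_idx)

  init = (jnp.int32(0), (jnp.int32(0), jnp.int32(0)))
  i, _ = lax.while_loop(cond, body, init)
  return i

print(count_grid_steps(jnp.int32(2), jnp.int32(3)))  # 6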

File: tests/pallas/pallas_call_tpu_test.py

@@ -44,15 +44,17 @@ partial = functools.partial
 class PallasTPUTest(jtu.JaxTestCase):
   interpret: bool = False
 
-class PallasCallScalarPrefetchTest(PallasTPUTest):
-  interpret: bool = False
-
   def setUp(self):
     super().setUp()
     if not self.interpret and jtu.device_under_test() != 'tpu':
       self.skipTest('Only interpret mode supported on non-TPU')
 
   def pallas_call(self, *args, **kwargs):
     return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
 
 
+class PallasCallScalarPrefetchTest(PallasTPUTest):
+
   def test_trivial_scalar_prefetch(self):
@@ -288,6 +290,107 @@ class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
   interpret: bool = True
 
 
+class PallasCallDynamicGridTest(PallasTPUTest):
+
+  def test_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+
+    @jax.jit
+    def dynamic_kernel(steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )()
+    np.testing.assert_array_equal(
+        dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32)
+    )
+
+  def test_vmap_trivial_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps, x):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+    x = jnp.arange(8 * 128., dtype=jnp.float32).reshape((1, *shape))
+    np.testing.assert_array_equal(
+        dynamic_kernel(jnp.array([4], jnp.int32), x), x + 8.0
+    )
+
+  def test_vmap_nontrivial_dynamic_grid(self):
+    # Dynamic grid doesn't support vmapping over multiple distinct grid values
+    # at the moment.
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )()
+    with self.assertRaises(NotImplementedError):
+      dynamic_kernel(jnp.array([4, 8], jnp.int32))
+
+  def test_vmap_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += jnp.float32(1.)
+
+    @jax.jit
+    def dynamic_kernel(x, steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+    x = jnp.arange(4 * 8 * 128., dtype=jnp.float32).reshape((4, *shape))
+    np.testing.assert_array_equal(
+        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, jnp.int32(4)),
+        x + 8,
+    )
+
+
+class PallasCallInterpretDynamicGridTest(PallasCallDynamicGridTest):
+  interpret: bool = True
+
+
 class PallasCallDMATest(parameterized.TestCase):
 
   def setUp(self):
@@ -1091,98 +1194,6 @@ class PallasCallTest(PallasTPUTest):
         kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=int(2**18))
     )(x)
 
-  def test_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = jnp.zeros_like(y_ref)
-      y_ref[...] += 1
-
-    @jax.jit
-    def dynamic_kernel(steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )()
-    np.testing.assert_array_equal(
-        dynamic_kernel(4), np.full(shape, 8.0, np.float32)
-    )
-
-  def test_vmap_trivial_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(x_ref, y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = x_ref[...]
-      y_ref[...] += 1
-
-    @jax.jit
-    @jax.vmap
-    def dynamic_kernel(steps, x):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )(x)
-    x = jnp.arange(8 * 128.).reshape((1, *shape))
-    np.testing.assert_array_equal(dynamic_kernel(jnp.array([4]), x), x + 8.0)
-
-  def test_vmap_nontrivial_dynamic_grid(self):
-    # Dynamic grid doesn't support vmapping over multiple distinct grid values
-    # at the moment.
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = jnp.zeros_like(y_ref)
-      y_ref[...] += 1
-
-    @jax.jit
-    @jax.vmap
-    def dynamic_kernel(steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )()
-    with self.assertRaises(NotImplementedError):
-      dynamic_kernel(jnp.array([4, 8]))
-
-  def test_vmap_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(x_ref, y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = x_ref[...]
-      y_ref[...] += 1.
-
-    @jax.jit
-    def dynamic_kernel(x, steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )(x)
-    x = jnp.arange(4 * 8 * 128.).reshape((4, *shape))
-    np.testing.assert_array_equal(
-        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, 4),
-        x + 8,
-    )
-
 
 class PallasUXTest(PallasTPUTest):