diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py
index 5edeef5db..96d4f32b0 100644
--- a/jax/_src/pallas/mosaic/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -45,10 +45,6 @@ def pallas_call_tpu_lowering_rule(
     **compiler_params: Any):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
   if interpret:
-    if grid_mapping.num_dynamic_grid_bounds:
-      raise NotImplementedError(
-          "Dynamic grid bounds not supported in interpret mode."
-      )
     return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
         ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
         in_shapes=in_shapes,
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 99d748407..33a25597d 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -16,7 +16,7 @@ from __future__ import annotations
 
 from functools import partial
-import itertools as it
+from functools import reduce
 from typing import Any, Callable
 from collections.abc import Sequence
 
@@ -89,24 +89,37 @@ def _uninitialized_value(shape, dtype):
     return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
   raise NotImplementedError(dtype)
 
+def _get_next_indices(grid, indices):
+  next_indices = []
+  carry = True
+  for dim_size, index in reversed(list(zip(grid, indices))):
+    i = jnp.where(carry, index + 1, index)
+    carry = dim_size == i
+    next_indices.append(jnp.where(carry, 0, i))
+  return tuple(reversed(next_indices))
+
 def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
                       interpret, debug: bool,
                       in_shapes,
                       input_output_aliases: tuple[tuple[int, int], ...],
                       grid_mapping: GridMapping,
                       **compiler_params: Any):
-  if grid_mapping.num_dynamic_grid_bounds:
-    raise NotImplementedError("interpret with dynamic grid bounds unsupported")
+  dynamic_grid_args, args = split_list(  # type: ignore
+      args, [grid_mapping.num_dynamic_grid_bounds]
+  )
   if interpret:
     # If we're in interpreter mode, we *scan* over the grid and eval the
     # discharged jaxpr. This should reproduce exactly what compiling to Triton
     # will do.
-    grid = grid_mapping.static_grid
+    dynamic_grid_args_iter = iter(dynamic_grid_args)
+    grid = tuple(
+        a if a is not None else next(dynamic_grid_args_iter)
+        for a in grid_mapping.grid
+    )
+    assert next(dynamic_grid_args_iter, None) is None
     discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
     if debug:
       print(discharged_jaxpr)
-    loop_indices = jnp.array(list(it.product(*(range(g) for g in grid))),
-                             dtype=jnp.int32)
     oi_map = {v: k for k, v in input_output_aliases}
     out = []
     for i, out_shape in enumerate(out_shapes):
@@ -135,12 +148,18 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
                       for a in scratch_avals]
     carry = [*args, *out, *scratch_values]
     num_carry = len(args) + len(out)
+    grid_start_indices = (jnp.int32(0),) * len(grid)
+    if grid:
+      num_iterations = reduce(jnp.multiply, grid)
+    else:
+      # Base case is always one iteration when grid is ()
+      num_iterations = 1
     def cond(carry):
-      return carry[0] < loop_indices.shape[0]
+      i, *_ = carry
+      return i < num_iterations
     def body(carry):
-      i, *carry = carry
+      i, loop_idx, *carry = carry
       carry, scratch = split_list(carry, [num_carry])
-      loop_idx = loop_indices[i]
       start_indices = [
           None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
           for bm in grid_mapping.block_mappings]
@@ -177,8 +196,10 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
       blocks, out_scratch = split_list(blocks, [num_carry])
       carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                   carry, blocks, is_indexing_dim)
-      return (i + 1, *carry, *out_scratch)
-    (_, *carry) = lax.while_loop(cond, body, (jnp.int32(0), *carry))
+      return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
+    (_, _, *carry) = lax.while_loop(
+        cond, body, (jnp.int32(0), grid_start_indices, *carry)
+    )
     _, out, _ = split_list(carry, [len(args), len(out)])
     return out
   return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py
index 55707ff4d..47b1da0b0 100644
--- a/tests/pallas/pallas_call_tpu_test.py
+++ b/tests/pallas/pallas_call_tpu_test.py
@@ -44,15 +44,17 @@ partial = functools.partial
 class PallasTPUTest(jtu.JaxTestCase):
   interpret: bool = False
 
-
-class PallasCallScalarPrefetchTest(PallasTPUTest):
-  interpret: bool = False
-
   def setUp(self):
     super().setUp()
     if not self.interpret and jtu.device_under_test() != 'tpu':
       self.skipTest('Only interpret mode supported on non-TPU')
 
+  def pallas_call(self, *args, **kwargs):
+    return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
+
+
+class PallasCallScalarPrefetchTest(PallasTPUTest):
+
   def test_trivial_scalar_prefetch(self):
     def body(_, x_ref, o_ref):
       o_ref[...] = x_ref[...]
@@ -288,6 +290,107 @@ class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
   interpret: bool = True
 
 
+class PallasCallDynamicGridTest(PallasTPUTest):
+
+  def test_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+
+    @jax.jit
+    def dynamic_kernel(steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )()
+    np.testing.assert_array_equal(
+        dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32)
+    )
+
+  def test_vmap_trivial_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps, x):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+    x = jnp.arange(8 * 128., dtype=jnp.float32).reshape((1, *shape))
+    np.testing.assert_array_equal(
+        dynamic_kernel(jnp.array([4], jnp.int32), x), x + 8.0
+    )
+
+  def test_vmap_nontrivial_dynamic_grid(self):
+    # Dynamic grid doesn't support vmapping over multiple distinct grid values
+    # at the moment.
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )()
+    with self.assertRaises(NotImplementedError):
+      dynamic_kernel(jnp.array([4, 8], jnp.int32))
+
+  def test_vmap_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += jnp.float32(1.)
+
+    @jax.jit
+    def dynamic_kernel(x, steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+    x = jnp.arange(4 * 8 * 128., dtype=jnp.float32).reshape((4, *shape))
+    np.testing.assert_array_equal(
+        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, jnp.int32(4)),
+        x + 8,
+    )
+
+
+class PallasCallInterpretDynamicGridTest(PallasCallDynamicGridTest):
+  interpret: bool = True
+
+
 class PallasCallDMATest(parameterized.TestCase):
 
   def setUp(self):
@@ -1091,98 +1194,6 @@ class PallasCallTest(PallasTPUTest):
         kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=int(2**18))
     )(x)
 
-  def test_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = jnp.zeros_like(y_ref)
-      y_ref[...] += 1
-
-    @jax.jit
-    def dynamic_kernel(steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )()
-    np.testing.assert_array_equal(
-        dynamic_kernel(4), np.full(shape, 8.0, np.float32)
-    )
-
-  def test_vmap_trivial_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(x_ref, y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = x_ref[...]
-      y_ref[...] += 1
-
-    @jax.jit
-    @jax.vmap
-    def dynamic_kernel(steps, x):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )(x)
-    x = jnp.arange(8 * 128.).reshape((1, *shape))
-    np.testing.assert_array_equal(dynamic_kernel(jnp.array([4]), x), x + 8.0)
-
-  def test_vmap_nontrivial_dynamic_grid(self):
-    # Dynamic grid doesn't support vmapping over multiple distinct grid values
-    # at the moment.
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = jnp.zeros_like(y_ref)
-      y_ref[...] += 1
-
-    @jax.jit
-    @jax.vmap
-    def dynamic_kernel(steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )()
-    with self.assertRaises(NotImplementedError):
-      dynamic_kernel(jnp.array([4, 8]))
-
-  def test_vmap_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(x_ref, y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = x_ref[...]
-      y_ref[...] += 1.
-
-    @jax.jit
-    def dynamic_kernel(x, steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )(x)
-    x = jnp.arange(4 * 8 * 128.).reshape((4, *shape))
-    np.testing.assert_array_equal(
-        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, 4),
-        x + 8,
-    )
-
 
 class PallasUXTest(PallasTPUTest):
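
Reviewer note: the core of this change is `_get_next_indices`, which replaces the old precomputed `it.product(...)` index table (whose size had to be known statically) with an odometer-style carry update, so the interpret-mode `while_loop` can run `reduce(jnp.multiply, grid)` iterations even when some grid bounds are runtime values. Below is a minimal standalone sketch of that stepping logic, checked against the table the old code built; the `grid` value and the `get_next_indices` name are illustrative, not part of the patch.

import itertools as it
import jax.numpy as jnp

def get_next_indices(grid, indices):
  # Increment the fastest-varying (last) grid index; when a dimension
  # wraps back to zero, carry into the next-slower dimension, like a
  # car odometer. Mirrors the _get_next_indices helper in this diff.
  next_indices = []
  carry = True
  for dim_size, index in reversed(list(zip(grid, indices))):
    i = jnp.where(carry, index + 1, index)
    carry = dim_size == i
    next_indices.append(jnp.where(carry, 0, i))
  return tuple(reversed(next_indices))

grid = (2, 3)
idx = (jnp.int32(0), jnp.int32(0))
visited = []
for _ in range(grid[0] * grid[1]):
  visited.append(tuple(int(i) for i in idx))
  idx = get_next_indices(grid, idx)
# Stepping one index at a time visits the same grid points, in the same
# order, as the eagerly materialized it.product table the old code used.
assert visited == list(it.product(range(2), range(3)))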