From 621814bd7dd0fbde9afc426c36354df9e208fe7c Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 5 Jun 2024 08:14:39 -0700 Subject: [PATCH] Add loop-based vmap lowering for pallas calls Loop-based vmap is used for cases in which a pipeline-based vmap is currently not feasible: * Dynamic grid dimensions * Batched scalar prefetch arguments PiperOrigin-RevId: 640530524 --- jax/_src/pallas/pallas_call.py | 221 +++++++++++++++++++++++---- tests/pallas/pallas_call_tpu_test.py | 48 +++--- 2 files changed, 219 insertions(+), 50 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 49015be5d..603f400e9 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -39,8 +39,14 @@ from jax._src.pallas import core as pallas_core from jax._src.state import discharge as state_discharge from jax._src.state import primitives as sp from jax._src.util import ( - split_list, safe_map, safe_zip, weakref_lru_cache, - tuple_insert, partition_list, merge_lists) + merge_lists, + partition_list, + safe_map, + safe_zip, + split_list, + tuple_insert, + weakref_lru_cache, +) import jax.numpy as jnp import numpy as np @@ -367,17 +373,148 @@ def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray, return block_mapping.replace(block_shape=new_block_shape, index_map_jaxpr=jaxpr) -def _pallas_call_batching_rule(args, dims, *, - jaxpr: jax_core.Jaxpr, - name: str, - in_shapes: tuple[jax.ShapeDtypeStruct, ...], - out_shapes: tuple[jax.ShapeDtypeStruct, ...], - grid_mapping: GridMapping, - input_output_aliases: tuple[tuple[int, int], ...], - debug: bool, - interpret: bool, - which_linear: tuple[bool, ...], - compiler_params: Any): + +def _broadcast_input_output_aliases( + args: Sequence[jax.Array], + dims: Sequence[int | batching.NotMapped], + *, + input_output_aliases: tuple[tuple[int, int], ...], + axis_size: int, +) -> tuple[tuple[jax.Array, ...], tuple[int | batching.NotMapped, ...]]: + """Broadcast input/output operands. + + When we have input/output aliasing, since the output will be mapped, we need + to make sure to broadcast the input across that dimension if it is not + mapped. + """ + + args_ = list(args) + dims_ = list(dims) + for input_index, _ in input_output_aliases: + dim = dims_[input_index] + if dim is batching.not_mapped: + dims_[input_index] = 0 + args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) + + return tuple(args_), tuple(dims_) + + +def _batch_with_explicit_loop( + args: Sequence[jax.Array], + dims: Sequence[int | batching.NotMapped], + *, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], + grid_mapping: GridMapping, + input_output_aliases: tuple[tuple[int, int], ...], + debug: bool, + interpret: bool, + which_linear: tuple[bool, ...], + compiler_params: Any, +): + """Batch the pallas_call by calling it in loop over the batch size. + + This function provides a fallback implementation of batching a pallas_call + for the cases in which adding a batch dimension to the pallas grid is not + supported. This is currently the case when the batched dimension corresponds + to a dynamic axis or a scalar prefetch argument. + + This implementation builds a HLO loop that dynamic_slices the inputs according + to the current iteration index and dynamic_updates an (initially empty) output + allocation. + """ + + if not dims: + raise NotImplementedError("vmapping pallas_call with no arguments.") + + (axis_size,) = { + arg.shape[dim] + for arg, dim in zip(args, dims) + if dim is not batching.not_mapped + } + + args, dims = _broadcast_input_output_aliases( + args, + dims, + input_output_aliases=input_output_aliases, + axis_size=axis_size, + ) + + # The output arrays are completelly overwritten, so we can just initialize + # empty arrays. + initial_state = [ + jnp.empty( + tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype + ) + for out_shape in out_shapes + ] + + def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: + batch_args = [] + + for arg, dim in zip(args, dims): + # If the argument is mapped, extract a slice of size 1 in the mapped + # dimension at the current index. + if dim is batching.not_mapped: + batch_args.append(arg) + else: + batch_args.append( + jnp.squeeze( + jax.lax.dynamic_slice_in_dim( + operand=arg, + start_index=batch_index, + slice_size=1, + axis=dim, + ), + axis=dim, + ) + ) + + batch_out = pallas_call_p.bind( + *batch_args, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + which_linear=which_linear, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + for i, batch_out_array in enumerate(batch_out): + state[i] = jax.lax.dynamic_update_index_in_dim( + state[i], + batch_out_array, + batch_index, + axis=0, + ) + + return state + + result = jax.lax.fori_loop(0, axis_size, body, initial_state, unroll=False) + + return result, (0,) * len(result) + + +def _pallas_call_batching_rule( + args, + dims, + *, + jaxpr: jax_core.Jaxpr, + name: str, + in_shapes: tuple[jax.ShapeDtypeStruct, ...], + out_shapes: tuple[jax.ShapeDtypeStruct, ...], + grid_mapping: GridMapping, + input_output_aliases: tuple[tuple[int, int], ...], + debug: bool, + interpret: bool, + which_linear: tuple[bool, ...], + compiler_params: Any, +): def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped @@ -386,6 +523,8 @@ def _pallas_call_batching_rule(args, dims, *, return x return jnp.squeeze(x, axis=bdim) + # The first num_dynamic_grid_bounds arguments are size-1 arrays that store + # the size of the dynamic bounds. dynamic_grid_args, args = split_list( args, [grid_mapping.num_dynamic_grid_bounds] ) @@ -397,10 +536,24 @@ def _pallas_call_batching_rule(args, dims, *, for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims) ): dynamic_grid_args = safe_map( - _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims) + _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims + ) elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims): - raise NotImplementedError( - f"Batched dynamic grid bounds unsupported: {dynamic_grid_dims}" + # TODO(amagni, sharadmv): Explore possibility of batching dynamic grid + # bounds. + return _batch_with_explicit_loop( + args=dynamic_grid_args + args, + dims=dynamic_grid_dims + dims, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + which_linear=which_linear, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, ) else: pass # No dynamic grid dimensions @@ -421,8 +574,23 @@ def _pallas_call_batching_rule(args, dims, *, args = (*scalar_args, *args) dims = (*scalar_bdims, *bdims) else: - # TODO(sharadmv,apaszke): enable batching over prefetched scalar args - raise NotImplementedError + # TODO(amagni,sharadmv,apaszke): enable efficient batching over + # prefetched scalar args. + return _batch_with_explicit_loop( + args=scalar_args + args, + dims=scalar_bdims + bdims, + jaxpr=jaxpr, + name=name, + in_shapes=in_shapes, + out_shapes=out_shapes, + which_linear=which_linear, + grid_mapping=grid_mapping, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + compiler_params=compiler_params, + ) + if not dims: raise NotImplementedError("vmapping pallas_call with no arguments.") axis_size, = {x.shape[d] for x, d in zip(args, dims) @@ -436,18 +604,9 @@ def _pallas_call_batching_rule(args, dims, *, # TODO(sharadmv): explore inferring better output dimensions via a heuristic # TODO(sharadmv): explore a long term solution to output dim inference - # When we have input/output aliasing, since the output will be mapped, we need - # to make sure to broadcast the input across that dimension if it is not - # mapped. - dims_ = list(dims) - args_ = list(args) - for input_index, _ in input_output_aliases: - dim = dims_[input_index] - if dim is batching.not_mapped: - dims_[input_index] = 0 - args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) - args = tuple(args_) - dims = tuple(dims_) + args, dims = _broadcast_input_output_aliases( + args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size + ) all_dims = list(dims) + [0] * len(out_shapes) @@ -493,6 +652,8 @@ def _pallas_call_batching_rule(args, dims, *, compiler_params=compiler_params, ) return out, (0,) * len(out) + + batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr: diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index d093f991a..ab753599b 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -286,34 +286,39 @@ class PallasCallScalarPrefetchTest(PallasTPUTest): o_ref[...] = x_ref[...] s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32) - x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) + x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): s = pl.load(s_ref, (i,)) return (s, 0) s = jnp.tile(s[None], [2, 1]) - x = jnp.tile(x[None], [2, 1, 1]) - with self.assertRaises(NotImplementedError): - jax.vmap( - pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - in_specs=[ - pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])), - ], - out_specs=pl.BlockSpec( - lambda i, _: (i, 0), (x.shape[1] // 8, x.shape[2]) - ), - grid=8, + @jax.jit + @jax.vmap + def kernel(s, x): + return pl.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])), + ], + out_specs=pl.BlockSpec( + lambda i, _: (i, 0), (x.shape[0] // 8, x.shape[1]) ), - interpret=self.interpret, - ) + grid=8, + ), + interpret=self.interpret, )(s, x) + first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:]) + second = x[1, ...].reshape((1, 8, 8, -1))[:, s[1, ...]].reshape(x.shape[1:]) + + expected = jnp.stack([first, second]) + np.testing.assert_allclose(kernel(s, x), expected) + class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): interpret: bool = True @@ -434,8 +439,11 @@ class PallasCallDynamicGridTest(PallasTPUTest): out_specs=pl.BlockSpec(lambda i: (0, 0), shape), out_shape=result_ty, )() - with self.assertRaises(NotImplementedError): - dynamic_kernel(jnp.array([4, 8], jnp.int32)) + out = dynamic_kernel(jnp.array([4, 8], jnp.int32)) + first = jnp.full(shape, fill_value=8.0, dtype=jnp.float32) + second = jnp.full(shape, fill_value=16.0, dtype=jnp.float32) + expected_out = jnp.stack([first, second], axis=0) + np.testing.assert_array_equal(out, expected_out) def test_vmap_dynamic_grid(self): shape = (8, 128)