Add loop-based vmap lowering for pallas calls

Loop-based vmap is used for cases in which a pipeline-based vmap is currently not feasible:
* Dynamic grid dimensions
* Batched scalar prefetch arguments
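
As a concrete illustration, the sketch below is adapted from the updated scalar-prefetch test in this commit (interpret=True is added here so it does not require a TPU): vmapping a pallas_call whose scalar prefetch argument is batched now lowers through the loop-based fallback instead of raising NotImplementedError.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def body(s_ref, x_ref, o_ref):
  o_ref[...] = x_ref[...]

def _x_transform(i, s_ref):
  return (pl.load(s_ref, (i,)), 0)

@jax.vmap  # s and x both carry a leading batch dimension of size 2.
def kernel(s, x):
  return pl.pallas_call(
      body,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=1,
          in_specs=[pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1]))],
          out_specs=pl.BlockSpec(lambda i, _: (i, 0), (x.shape[0] // 8, x.shape[1])),
          grid=8,
      ),
      interpret=True,  # run in interpret mode so this sketch works off-TPU
  )(s, x)

s = jnp.tile(jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)[None], [2, 1])
x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128))
out = kernel(s, x)  # shape (2, 8 * 8, 128); previously raised NotImplementedError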

PiperOrigin-RevId: 640530524
Author: jax authors, 2024-06-05 08:14:39 -07:00 (committed by jax authors)
Parent: 9e3f290de3
Commit: 621814bd7d
2 changed files with 219 additions and 50 deletions


@@ -39,8 +39,14 @@ from jax._src.pallas import core as pallas_core
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as sp
from jax._src.util import (
split_list, safe_map, safe_zip, weakref_lru_cache,
tuple_insert, partition_list, merge_lists)
merge_lists,
partition_list,
safe_map,
safe_zip,
split_list,
tuple_insert,
weakref_lru_cache,
)
import jax.numpy as jnp
import numpy as np
@@ -367,17 +373,148 @@ def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
return block_mapping.replace(block_shape=new_block_shape,
index_map_jaxpr=jaxpr)
def _pallas_call_batching_rule(args, dims, *,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
grid_mapping: GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
which_linear: tuple[bool, ...],
compiler_params: Any):
def _broadcast_input_output_aliases(
args: Sequence[jax.Array],
dims: Sequence[int | batching.NotMapped],
*,
input_output_aliases: tuple[tuple[int, int], ...],
axis_size: int,
) -> tuple[tuple[jax.Array, ...], tuple[int | batching.NotMapped, ...]]:
"""Broadcast input/output operands.
When we have input/output aliasing, since the output will be mapped, we need
to make sure to broadcast the input across that dimension if it is not
mapped.
"""
args_ = list(args)
dims_ = list(dims)
for input_index, _ in input_output_aliases:
dim = dims_[input_index]
if dim is batching.not_mapped:
dims_[input_index] = 0
args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
return tuple(args_), tuple(dims_)
def _batch_with_explicit_loop(
args: Sequence[jax.Array],
dims: Sequence[int | batching.NotMapped],
*,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
grid_mapping: GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
which_linear: tuple[bool, ...],
compiler_params: Any,
):
"""Batch the pallas_call by calling it in loop over the batch size.
This function provides a fallback implementation of batching a pallas_call
for the cases in which adding a batch dimension to the pallas grid is not
supported. This is currently the case when the batched dimension corresponds
to a dynamic axis or a scalar prefetch argument.
This implementation builds a HLO loop that dynamic_slices the inputs according
to the current iteration index and dynamic_updates an (initially empty) output
allocation.
"""
if not dims:
raise NotImplementedError("vmapping pallas_call with no arguments.")
(axis_size,) = {
arg.shape[dim]
for arg, dim in zip(args, dims)
if dim is not batching.not_mapped
}
args, dims = _broadcast_input_output_aliases(
args,
dims,
input_output_aliases=input_output_aliases,
axis_size=axis_size,
)
# The output arrays are completely overwritten, so we can just initialize
# empty arrays.
initial_state = [
jnp.empty(
tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype
)
for out_shape in out_shapes
]
def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
batch_args = []
for arg, dim in zip(args, dims):
# If the argument is mapped, extract a slice of size 1 in the mapped
# dimension at the current index.
if dim is batching.not_mapped:
batch_args.append(arg)
else:
batch_args.append(
jnp.squeeze(
jax.lax.dynamic_slice_in_dim(
operand=arg,
start_index=batch_index,
slice_size=1,
axis=dim,
),
axis=dim,
)
)
batch_out = pallas_call_p.bind(
*batch_args,
jaxpr=jaxpr,
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
which_linear=which_linear,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
compiler_params=compiler_params,
)
for i, batch_out_array in enumerate(batch_out):
state[i] = jax.lax.dynamic_update_index_in_dim(
state[i],
batch_out_array,
batch_index,
axis=0,
)
return state
result = jax.lax.fori_loop(0, axis_size, body, initial_state, unroll=False)
return result, (0,) * len(result)
def _pallas_call_batching_rule(
args,
dims,
*,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
grid_mapping: GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
which_linear: tuple[bool, ...],
compiler_params: Any,
):
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
@@ -386,6 +523,8 @@ def _pallas_call_batching_rule(args, dims, *,
return x
return jnp.squeeze(x, axis=bdim)
# The first num_dynamic_grid_bounds arguments are size-1 arrays that store
# the size of the dynamic bounds.
dynamic_grid_args, args = split_list(
args, [grid_mapping.num_dynamic_grid_bounds]
)
@@ -397,10 +536,24 @@ def _pallas_call_batching_rule(args, dims, *,
for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims)
):
dynamic_grid_args = safe_map(
_maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims)
_maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims
)
elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims):
raise NotImplementedError(
f"Batched dynamic grid bounds unsupported: {dynamic_grid_dims}"
# TODO(amagni, sharadmv): Explore possibility of batching dynamic grid
# bounds.
return _batch_with_explicit_loop(
args=dynamic_grid_args + args,
dims=dynamic_grid_dims + dims,
jaxpr=jaxpr,
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
which_linear=which_linear,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
compiler_params=compiler_params,
)
else:
pass # No dynamic grid dimensions
@@ -421,8 +574,23 @@ def _pallas_call_batching_rule(args, dims, *,
args = (*scalar_args, *args)
dims = (*scalar_bdims, *bdims)
else:
# TODO(sharadmv,apaszke): enable batching over prefetched scalar args
raise NotImplementedError
# TODO(amagni,sharadmv,apaszke): enable efficient batching over
# prefetched scalar args.
return _batch_with_explicit_loop(
args=scalar_args + args,
dims=scalar_bdims + bdims,
jaxpr=jaxpr,
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
which_linear=which_linear,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
compiler_params=compiler_params,
)
if not dims:
raise NotImplementedError("vmapping pallas_call with no arguments.")
axis_size, = {x.shape[d] for x, d in zip(args, dims)
@@ -436,18 +604,9 @@ def _pallas_call_batching_rule(args, dims, *,
# TODO(sharadmv): explore inferring better output dimensions via a heuristic
# TODO(sharadmv): explore a long term solution to output dim inference
# When we have input/output aliasing, since the output will be mapped, we need
# to make sure to broadcast the input across that dimension if it is not
# mapped.
dims_ = list(dims)
args_ = list(args)
for input_index, _ in input_output_aliases:
dim = dims_[input_index]
if dim is batching.not_mapped:
dims_[input_index] = 0
args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
args = tuple(args_)
dims = tuple(dims_)
args, dims = _broadcast_input_output_aliases(
args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
)
all_dims = list(dims) + [0] * len(out_shapes)
@@ -493,6 +652,8 @@ def _pallas_call_batching_rule(args, dims, *,
compiler_params=compiler_params,
)
return out, (0,) * len(out)
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
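
The following self-contained sketch (a hypothetical loop_batch helper, not part of this change) illustrates the pattern that _batch_with_explicit_loop implements above: preallocate the stacked output, then fori_loop over the batch dimension, dynamic-slicing one element of each mapped input per iteration and dynamic-updating the result into the output.

import jax
import jax.numpy as jnp

def loop_batch(fn, xs, axis_size, out_shape, out_dtype):
  # Preallocate the stacked output; every slice is fully overwritten in the loop.
  init = jnp.empty((axis_size, *out_shape), dtype=out_dtype)

  def body(i, acc):
    # Extract the i-th element of every mapped input along its leading axis.
    x_i = [jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, axis=0), axis=0)
           for x in xs]
    # Call the unbatched function and write its result back at position i.
    return jax.lax.dynamic_update_index_in_dim(acc, fn(*x_i), i, axis=0)

  return jax.lax.fori_loop(0, axis_size, body, init)

# Example: loop-batching an elementwise add over a leading axis of size 4.
a = jnp.arange(4 * 8, dtype=jnp.float32).reshape(4, 8)
b = jnp.ones((4, 8), dtype=jnp.float32)
out = loop_batch(lambda x, y: x + y, [a, b], axis_size=4,
                 out_shape=(8,), out_dtype=jnp.float32)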


@@ -286,34 +286,39 @@ class PallasCallScalarPrefetchTest(PallasTPUTest):
o_ref[...] = x_ref[...]
s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128))
def _x_transform(i, s_ref):
s = pl.load(s_ref, (i,))
return (s, 0)
s = jnp.tile(s[None], [2, 1])
x = jnp.tile(x[None], [2, 1, 1])
with self.assertRaises(NotImplementedError):
jax.vmap(
pl.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
in_specs=[
pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])),
],
out_specs=pl.BlockSpec(
lambda i, _: (i, 0), (x.shape[1] // 8, x.shape[2])
),
grid=8,
@jax.jit
@jax.vmap
def kernel(s, x):
return pl.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
in_specs=[
pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])),
],
out_specs=pl.BlockSpec(
lambda i, _: (i, 0), (x.shape[0] // 8, x.shape[1])
),
interpret=self.interpret,
)
grid=8,
),
interpret=self.interpret,
)(s, x)
first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:])
second = x[1, ...].reshape((1, 8, 8, -1))[:, s[1, ...]].reshape(x.shape[1:])
expected = jnp.stack([first, second])
np.testing.assert_allclose(kernel(s, x), expected)
class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
interpret: bool = True
@@ -434,8 +439,11 @@ class PallasCallDynamicGridTest(PallasTPUTest):
out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
out_shape=result_ty,
)()
with self.assertRaises(NotImplementedError):
dynamic_kernel(jnp.array([4, 8], jnp.int32))
out = dynamic_kernel(jnp.array([4, 8], jnp.int32))
first = jnp.full(shape, fill_value=8.0, dtype=jnp.float32)
second = jnp.full(shape, fill_value=16.0, dtype=jnp.float32)
expected_out = jnp.stack([first, second], axis=0)
np.testing.assert_array_equal(out, expected_out)
def test_vmap_dynamic_grid(self):
shape = (8, 128)