Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
Add loop-based vmap lowering for pallas calls

Loop-based vmap is used for cases in which a pipeline-based vmap is currently not feasible:

* Dynamic grid dimensions
* Batched scalar prefetch arguments

PiperOrigin-RevId: 640530524
This commit is contained in:
  parent 9e3f290de3
  commit 621814bd7d
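The fallback added here batches a pallas_call with an explicit XLA loop instead of folding the batch dimension into the pipeline grid: each iteration slices the mapped operands at the current index, runs the unbatched call, and writes the result into a preallocated output. Below is a minimal, standalone sketch of that pattern only; the helper name batch_with_explicit_loop, the toy function f, and the use of jax.eval_shape are illustrative choices, and the real lowering (_batch_with_explicit_loop in the diff that follows) binds pallas_call_p directly rather than calling a Python function.

import jax
import jax.numpy as jnp
import numpy as np


def batch_with_explicit_loop(fn, args, dims):
  """Applies fn once per batch element inside an XLA fori_loop.

  dims[i] is the batch axis of args[i], or None if that argument is unmapped.
  Results are stacked along a new leading axis.
  """
  # The single common batch size across all mapped arguments.
  (axis_size,) = {a.shape[d] for a, d in zip(args, dims) if d is not None}

  # Shape/dtype of one unbatched call, used to preallocate the stacked output.
  sample_args = tuple(
      a if d is None else jnp.take(a, 0, axis=d) for a, d in zip(args, dims)
  )
  out_spec = jax.eval_shape(fn, *sample_args)
  out = jnp.empty((axis_size, *out_spec.shape), dtype=out_spec.dtype)

  def body(i, out):
    # Slice batch element i out of every mapped argument (size-1 slice, then
    # squeeze away the batch axis); unmapped arguments pass through unchanged.
    sliced = [
        a if d is None
        else jnp.squeeze(jax.lax.dynamic_slice_in_dim(a, i, 1, axis=d), axis=d)
        for a, d in zip(args, dims)
    ]
    # Write this element's result into row i of the preallocated output.
    return jax.lax.dynamic_update_index_in_dim(out, fn(*sliced), i, axis=0)

  return jax.lax.fori_loop(0, axis_size, body, out)


# Sanity check against jax.vmap on a toy function.
x = jnp.arange(6.0).reshape(3, 2)   # mapped along axis 0
w = jnp.float32(2.0)                # not mapped
f = lambda xi, wi: xi * wi + 1.0
np.testing.assert_allclose(
    batch_with_explicit_loop(f, (x, w), (0, None)),
    jax.vmap(f, in_axes=(0, None))(x, w),
)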
@@ -39,8 +39,14 @@ from jax._src.pallas import core as pallas_core
 from jax._src.state import discharge as state_discharge
 from jax._src.state import primitives as sp
 from jax._src.util import (
-    split_list, safe_map, safe_zip, weakref_lru_cache,
-    tuple_insert, partition_list, merge_lists)
+    merge_lists,
+    partition_list,
+    safe_map,
+    safe_zip,
+    split_list,
+    tuple_insert,
+    weakref_lru_cache,
+)
 import jax.numpy as jnp
 import numpy as np
 
@@ -367,17 +373,148 @@ def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
   return block_mapping.replace(block_shape=new_block_shape,
                                index_map_jaxpr=jaxpr)
 
-def _pallas_call_batching_rule(args, dims, *,
-                               jaxpr: jax_core.Jaxpr,
-                               name: str,
-                               in_shapes: tuple[jax.ShapeDtypeStruct, ...],
-                               out_shapes: tuple[jax.ShapeDtypeStruct, ...],
-                               grid_mapping: GridMapping,
-                               input_output_aliases: tuple[tuple[int, int], ...],
-                               debug: bool,
-                               interpret: bool,
-                               which_linear: tuple[bool, ...],
-                               compiler_params: Any):
+def _broadcast_input_output_aliases(
+    args: Sequence[jax.Array],
+    dims: Sequence[int | batching.NotMapped],
+    *,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    axis_size: int,
+) -> tuple[tuple[jax.Array, ...], tuple[int | batching.NotMapped, ...]]:
+  """Broadcast input/output operands.
+
+  When we have input/output aliasing, since the output will be mapped, we need
+  to make sure to broadcast the input across that dimension if it is not
+  mapped.
+  """
+
+  args_ = list(args)
+  dims_ = list(dims)
+  for input_index, _ in input_output_aliases:
+    dim = dims_[input_index]
+    if dim is batching.not_mapped:
+      dims_[input_index] = 0
+      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
+
+  return tuple(args_), tuple(dims_)
+
+
+def _batch_with_explicit_loop(
+    args: Sequence[jax.Array],
+    dims: Sequence[int | batching.NotMapped],
+    *,
+    jaxpr: jax_core.Jaxpr,
+    name: str,
+    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
+    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
+    grid_mapping: GridMapping,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    debug: bool,
+    interpret: bool,
+    which_linear: tuple[bool, ...],
+    compiler_params: Any,
+):
+  """Batch the pallas_call by calling it in a loop over the batch size.
+
+  This function provides a fallback implementation of batching a pallas_call
+  for the cases in which adding a batch dimension to the pallas grid is not
+  supported. This is currently the case when the batched dimension corresponds
+  to a dynamic axis or a scalar prefetch argument.
+
+  This implementation builds an HLO loop that dynamic_slices the inputs
+  according to the current iteration index and dynamic_updates an (initially
+  empty) output allocation.
+  """
+
+  if not dims:
+    raise NotImplementedError("vmapping pallas_call with no arguments.")
+
+  (axis_size,) = {
+      arg.shape[dim]
+      for arg, dim in zip(args, dims)
+      if dim is not batching.not_mapped
+  }
+
+  args, dims = _broadcast_input_output_aliases(
+      args,
+      dims,
+      input_output_aliases=input_output_aliases,
+      axis_size=axis_size,
+  )
+
+  # The output arrays are completely overwritten, so we can just initialize
+  # empty arrays.
+  initial_state = [
+      jnp.empty(
+          tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype
+      )
+      for out_shape in out_shapes
+  ]
+
+  def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
+    batch_args = []
+
+    for arg, dim in zip(args, dims):
+      # If the argument is mapped, extract a slice of size 1 in the mapped
+      # dimension at the current index.
+      if dim is batching.not_mapped:
+        batch_args.append(arg)
+      else:
+        batch_args.append(
+            jnp.squeeze(
+                jax.lax.dynamic_slice_in_dim(
+                    operand=arg,
+                    start_index=batch_index,
+                    slice_size=1,
+                    axis=dim,
+                ),
+                axis=dim,
+            )
+        )
+
+    batch_out = pallas_call_p.bind(
+        *batch_args,
+        jaxpr=jaxpr,
+        name=name,
+        in_shapes=in_shapes,
+        out_shapes=out_shapes,
+        which_linear=which_linear,
+        grid_mapping=grid_mapping,
+        input_output_aliases=input_output_aliases,
+        debug=debug,
+        interpret=interpret,
+        compiler_params=compiler_params,
+    )
+    for i, batch_out_array in enumerate(batch_out):
+      state[i] = jax.lax.dynamic_update_index_in_dim(
+          state[i],
+          batch_out_array,
+          batch_index,
+          axis=0,
+      )
+
+    return state
+
+  result = jax.lax.fori_loop(0, axis_size, body, initial_state, unroll=False)
+
+  return result, (0,) * len(result)
+
+
+def _pallas_call_batching_rule(
+    args,
+    dims,
+    *,
+    jaxpr: jax_core.Jaxpr,
+    name: str,
+    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
+    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
+    grid_mapping: GridMapping,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    debug: bool,
+    interpret: bool,
+    which_linear: tuple[bool, ...],
+    compiler_params: Any,
+):
 
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
@@ -386,6 +523,8 @@ def _pallas_call_batching_rule(args, dims, *,
       return x
     return jnp.squeeze(x, axis=bdim)
 
+  # The first num_dynamic_grid_bounds arguments are size-1 arrays that store
+  # the size of the dynamic bounds.
   dynamic_grid_args, args = split_list(
       args, [grid_mapping.num_dynamic_grid_bounds]
   )
@@ -397,10 +536,24 @@ def _pallas_call_batching_rule(args, dims, *,
       for arg, bdim in zip(dynamic_grid_args, dynamic_grid_dims)
   ):
     dynamic_grid_args = safe_map(
-        _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims)
+        _maybe_squeeze_out_bdim, dynamic_grid_args, dynamic_grid_dims
+    )
   elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims):
-    raise NotImplementedError(
-        f"Batched dynamic grid bounds unsupported: {dynamic_grid_dims}"
-    )
+    # TODO(amagni, sharadmv): Explore possibility of batching dynamic grid
+    # bounds.
+    return _batch_with_explicit_loop(
+        args=dynamic_grid_args + args,
+        dims=dynamic_grid_dims + dims,
+        jaxpr=jaxpr,
+        name=name,
+        in_shapes=in_shapes,
+        out_shapes=out_shapes,
+        which_linear=which_linear,
+        grid_mapping=grid_mapping,
+        input_output_aliases=input_output_aliases,
+        debug=debug,
+        interpret=interpret,
+        compiler_params=compiler_params,
+    )
   else:
     pass  # No dynamic grid dimensions
@@ -421,8 +574,23 @@ def _pallas_call_batching_rule(args, dims, *,
     args = (*scalar_args, *args)
     dims = (*scalar_bdims, *bdims)
   else:
-    # TODO(sharadmv,apaszke): enable batching over prefetched scalar args
-    raise NotImplementedError
+    # TODO(amagni,sharadmv,apaszke): enable efficient batching over
+    # prefetched scalar args.
+    return _batch_with_explicit_loop(
+        args=scalar_args + args,
+        dims=scalar_bdims + bdims,
+        jaxpr=jaxpr,
+        name=name,
+        in_shapes=in_shapes,
+        out_shapes=out_shapes,
+        which_linear=which_linear,
+        grid_mapping=grid_mapping,
+        input_output_aliases=input_output_aliases,
+        debug=debug,
+        interpret=interpret,
+        compiler_params=compiler_params,
+    )
 
   if not dims:
     raise NotImplementedError("vmapping pallas_call with no arguments.")
   axis_size, = {x.shape[d] for x, d in zip(args, dims)
@@ -436,18 +604,9 @@ def _pallas_call_batching_rule(args, dims, *,
   # TODO(sharadmv): explore inferring better output dimensions via a heuristic
   # TODO(sharadmv): explore a long term solution to output dim inference
 
-  # When we have input/output aliasing, since the output will be mapped, we need
-  # to make sure to broadcast the input across that dimension if it is not
-  # mapped.
-  dims_ = list(dims)
-  args_ = list(args)
-  for input_index, _ in input_output_aliases:
-    dim = dims_[input_index]
-    if dim is batching.not_mapped:
-      dims_[input_index] = 0
-      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
-  args = tuple(args_)
-  dims = tuple(dims_)
+  args, dims = _broadcast_input_output_aliases(
+      args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
+  )
 
   all_dims = list(dims) + [0] * len(out_shapes)
 
@@ -493,6 +652,8 @@ def _pallas_call_batching_rule(args, dims, *,
       compiler_params=compiler_params,
   )
   return out, (0,) * len(out)
 
 
 batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
 
 
 def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
@@ -286,34 +286,39 @@ class PallasCallScalarPrefetchTest(PallasTPUTest):
       o_ref[...] = x_ref[...]
 
     s = jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)
-    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
+    x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128))
 
     def _x_transform(i, s_ref):
       s = pl.load(s_ref, (i,))
       return (s, 0)
 
     s = jnp.tile(s[None], [2, 1])
-    x = jnp.tile(x[None], [2, 1, 1])
 
-    with self.assertRaises(NotImplementedError):
-      jax.vmap(
-          pl.pallas_call(
-              body,
-              out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
-              grid_spec=pltpu.PrefetchScalarGridSpec(
-                  num_scalar_prefetch=1,
-                  in_specs=[
-                      pl.BlockSpec(_x_transform, (x.shape[1] // 8, x.shape[2])),
-                  ],
-                  out_specs=pl.BlockSpec(
-                      lambda i, _: (i, 0), (x.shape[1] // 8, x.shape[2])
-                  ),
-                  grid=8,
-              ),
-              interpret=self.interpret,
-          )
+    @jax.jit
+    @jax.vmap
+    def kernel(s, x):
+      return pl.pallas_call(
+          body,
+          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+          grid_spec=pltpu.PrefetchScalarGridSpec(
+              num_scalar_prefetch=1,
+              in_specs=[
+                  pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])),
+              ],
+              out_specs=pl.BlockSpec(
+                  lambda i, _: (i, 0), (x.shape[0] // 8, x.shape[1])
+              ),
+              grid=8,
+          ),
+          interpret=self.interpret,
       )(s, x)
 
+    first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:])
+    second = x[1, ...].reshape((1, 8, 8, -1))[:, s[1, ...]].reshape(x.shape[1:])
+
+    expected = jnp.stack([first, second])
+    np.testing.assert_allclose(kernel(s, x), expected)
+
 
 class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
   interpret: bool = True
@@ -434,8 +439,11 @@ class PallasCallDynamicGridTest(PallasTPUTest):
           out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
           out_shape=result_ty,
       )()
-    with self.assertRaises(NotImplementedError):
-      dynamic_kernel(jnp.array([4, 8], jnp.int32))
+    out = dynamic_kernel(jnp.array([4, 8], jnp.int32))
+    first = jnp.full(shape, fill_value=8.0, dtype=jnp.float32)
+    second = jnp.full(shape, fill_value=16.0, dtype=jnp.float32)
+    expected_out = jnp.stack([first, second], axis=0)
+    np.testing.assert_array_equal(out, expected_out)
 
   def test_vmap_dynamic_grid(self):
     shape = (8, 128)
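As a closing illustration, distilled from the new scalar-prefetch test above: after this change, jax.vmap over a pallas_call whose index map reads a prefetched scalar goes through the loop-based fallback instead of raising NotImplementedError. This sketch is not part of the commit; it reuses the test's kernel and block specs, assumes the Pallas TPU API as used by the tests in this diff (including the kernel receiving the scalar-prefetch ref as its first argument), and runs in interpret mode so it does not need a TPU.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def body(s_ref, x_ref, o_ref):
  # Copy the selected input block; s_ref is the prefetched scalar vector.
  o_ref[...] = x_ref[...]


def _x_transform(i, s_ref):
  # Pick the input block indicated by the i-th prefetched scalar.
  return (pl.load(s_ref, (i,)), 0)


s = jnp.tile(jnp.array([4, 3, 2, 5, 3, 5, 2, 7], jnp.int32)[None], [2, 1])
x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128))


@jax.vmap
def kernel(s, x):
  return pl.pallas_call(
      body,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=1,
          in_specs=[pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1]))],
          out_specs=pl.BlockSpec(
              lambda i, _: (i, 0), (x.shape[0] // 8, x.shape[1])
          ),
          grid=8,
      ),
      interpret=True,
  )(s, x)


print(kernel(s, x).shape)  # (2, 64, 128)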