[pallas] Fix the handling of captured consts

There was an attempt to handle consts captured by the kernel,
but it was incomplete and with errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).

The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.

I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
This commit is contained in:
George Necula 2024-07-19 20:22:21 +03:00
parent b3469a61d1
commit b7105ccd19
11 changed files with 183 additions and 32 deletions

View File

@ -245,7 +245,7 @@ class BlockMapping:
index_map_jaxpr: jax_core.ClosedJaxpr
indexing_mode: IndexingMode
def compute_start_indices(self, loop_idx, *args):
def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
)
@ -344,6 +344,10 @@ def _convert_block_spec_to_block_mapping(
f"{len(aval.shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
if consts:
raise NotImplementedError(
f"Index map for {what}{tree_util.keystr(path)} captures constants: "
f"{consts}")
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
)

View File

@ -274,7 +274,7 @@ class MosaicGridMapping:
self.mapped_dims = grid_mapping.mapped_dims
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch]
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
num_operands = (
len(self.jaxpr.invars)
- num_scalar_prefetch
@ -411,8 +411,12 @@ def lower_jaxpr_to_module(
grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
else:
invars = invars[grid_mapping.num_index_operands:]
# invars now = *consts, *ins, *outs
avals = tuple(v.aval for v in invars)
block_operand_shapes = (
*[jax.ShapeDtypeStruct(v.aval.shape,
v.aval.dtype)
for v in invars[:grid_mapping.num_constant_operands]],
*in_shapes[grid_mapping.num_index_operands:],
*out_shapes,
)
@ -425,8 +429,6 @@ def lower_jaxpr_to_module(
raise NotImplementedError(
"BlockSpecs are required on TPU when grid is specified"
)
if bm.index_map_jaxpr.consts:
raise NotImplementedError("Index map jaxpr with consts not supported.")
# ANY operands don't support windowing and require empty window_params.
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
# We may not require windowing if our block_shape matches the original

View File

@ -76,6 +76,7 @@ def pallas_call_tpu_lowering_rule(
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,

View File

@ -42,6 +42,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*args,
@ -68,7 +69,10 @@ def pallas_call_lowering(
if debug:
print(jaxpr)
print(grid_mapping)
if grid_mapping.num_constant_operands:
raise NotImplementedError(
"captured consts not supported in the Mosaic GPU backend"
)
lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
in_shapes,

View File

@ -187,12 +187,13 @@ def _pallas_call_impl_interpret(
)
assert next(dynamic_grid_args_iter, None) is None
with grid_mapping.trace_env():
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
# args now contains: *consts, *inputs, *outputs
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
@ -243,6 +244,10 @@ def _pallas_call_impl_interpret(
else:
# Base case is always one iteration when grid is ()
num_iterations = 1
# The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
# i:int32 is the interation index
# loop_idx: tuple[int32] are the program ids for each grid axis
def cond(carry):
i, *_ = carry
return i < num_iterations
@ -256,7 +261,7 @@ def _pallas_call_impl_interpret(
carry, scratch = split_list(carry, [num_inout])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
@ -269,13 +274,14 @@ def _pallas_call_impl_interpret(
len(blocks),
len(scratch_values),
)
blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
@ -604,13 +610,13 @@ def _pallas_call_batching_rule(
# Ordinarily, adding support for scalar prefetch in vmap would involve
# modifying the block specs in a nontrivial way. However, if we are only
# vmapping over 1-sized dimensions, we can just get rid of the dimensions
# and pretend we were never vmapping over them at all.
# and pretend we were never vmapped over them at all.
if all(
bdim is batching.not_mapped or arg.shape[bdim] == 1
for arg, bdim in zip(scalar_args, scalar_bdims)
):
scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
scalar_bdims = [None] * len(scalar_args)
scalar_bdims = [batching.not_mapped] * len(scalar_args)
args = (*scalar_args, *args)
dims = (*scalar_bdims, *bdims)
else:
@ -648,6 +654,7 @@ def _pallas_call_batching_rule(
all_dims = list(dims) + [0] * len(out_shapes)
num_index_operands = grid_mapping.num_index_operands
num_constant_operands = grid_mapping.num_constant_operands
num_scratch_operands = grid_mapping.num_scratch_operands
# Only add a batch dimension for the avals that actually have a grid mapping.
@ -661,11 +668,16 @@ def _pallas_call_batching_rule(
block_mappings,
)
# TODO(necula): should fix in_shapes to include the consts
dims_no_consts = (
dims[:num_index_operands] +
dims[num_index_operands + num_constant_operands:]
)
batched_in_shapes = tuple(
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
tuple_insert(x.shape, dim, axis_size),
x.dtype)
for x, dim in zip(in_shapes, dims))
for x, dim in zip(in_shapes, dims_no_consts))
batched_out_shapes = tuple(
jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
for x in out_shapes)
@ -900,11 +912,37 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
if consts:
jaxpr = state_utils.hoist_consts_to_refs(jaxpr)
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
# the real inputs and outputs.
jaxpr = state_utils.hoist_consts_to_refs(
jaxpr,
index=grid_mapping.num_index_operands,
make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
num_constant_operands = len(consts)
# TODO(necula): refactor grid_mapping to remove this code duplication
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
if grid_mapping.num_index_operands:
grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
grid_avals,
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
aval=jax_core.ShapedArray(c.shape, c.dtype),
in_tree=grid_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
)
const_block_mappings.append(const_block_mapping)
grid_mapping = grid_mapping.replace(
block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
num_constant_operands=len(consts),
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
return grid_mapping, jaxpr, consts, out_tree_thunk()
@ -1105,8 +1143,9 @@ def pallas_call(
f"and to output{tree_util.keystr(out_paths[o_idx])} with "
f"a different abstract value {out_aval}.")
index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
jaxpr=jaxpr, name=name,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),

View File

@ -251,7 +251,7 @@ def _new_ir_context() -> ir.Context:
def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
in_shapes,
in_out_shapes,
grid_mapping: GridMapping,
name: str,
platform: str
@ -301,6 +301,10 @@ def lower_jaxpr_to_triton_module(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
consts_shapes = [
jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)
for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands]
]
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
@ -310,7 +314,7 @@ def lower_jaxpr_to_triton_module(
if block_mapping is not None
else None
for shape_dtype, block_mapping, start_idx in zip(
(*in_shapes, *[()] * grid_mapping.num_constant_operands),
(*consts_shapes, *in_out_shapes),
grid_mapping.block_mappings,
start_indices,
)

View File

@ -54,6 +54,7 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*in_nodes,
@ -72,6 +73,10 @@ def pallas_call_lowering(
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.pop("num_warps", 4)
[lowering_platform] = ctx.platforms or ctx.module_context.platforms

View File

@ -13,6 +13,8 @@
# limitations under the License.
"""Utilities for tracing stateful functions."""
from typing import Callable
from jax._src.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
@ -24,13 +26,20 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
def hoist_consts_to_refs(
jaxpr: core.Jaxpr,
*,
index: int = 0,
make_abstract_ref: Callable[[core.AbstractValue], AbstractRef] = lambda aval: AbstractRef(aval)
) -> core.Jaxpr:
"""Hoists the constants in the given jaxpr into invars.
Args:
jaxpr: The jaxpr.
index: The index where the invars for the constants should be inserted.
By default, the new invars are inserted *before* any existing invars.
make_abstract_ref: a callable to construct an AbstractRef, or subtype
thereof, from a constant AbstractValue.
Returns:
A new jaxpr where the constants were hoisted into invars as ``Ref``s.
@ -42,7 +51,7 @@ def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
isinstance(var.aval, AbstractRef) for var in jaxpr.constvars
]
const_avals = [
var.aval if is_ref else AbstractRef(var.aval)
var.aval if is_ref else make_abstract_ref(var.aval)
for is_ref, var in zip(is_const_ref, jaxpr.constvars)
]
in_avals = [var.aval for var in jaxpr.invars]

View File

@ -150,7 +150,6 @@ class PallasBaseTest(jtu.JaxTestCase):
class PallasCallTest(PallasBaseTest):
def test_add_one(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
@ -468,19 +467,22 @@ class PallasCallTest(PallasBaseTest):
def test_hoisted_consts(self):
# See https://github.com/google/jax/issues/21557.
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
x = jnp.zeros(32)
indices = jnp.arange(4).reshape((2, 2))
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid=(2,),
in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
)
def kernel(src, dst):
dst[indices] = src[indices]
dst[0:1] = to_store
jax.block_until_ready(kernel(x))
res = kernel(x)
self.assertAllClose(res[0:1], to_store)
def test_vector_slicing(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
@ -744,6 +746,17 @@ class ApiErrorTest(PallasBaseTest):
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)
def test_pallas_call_index_map_captures_consts(self):
a = np.arange(256, dtype=np.int32)
index_map_result = np.array([0], dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
with self.assertRaisesRegex(
NotImplementedError,
"Index map for input\\[0\\] captures constants"):
f(a)
def test_pallas_call_out_specs_mismatch_shape(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,

View File

@ -37,7 +37,7 @@ config.parse_flags_with_absl()
@jtu.with_config(jax_traceback_filtering="off")
class PallasTest(jtu.JaxTestCase):
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False
def setUp(self):
@ -58,7 +58,7 @@ class PallasTest(jtu.JaxTestCase):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
class PallasCallVmapTest(PallasTest):
class PallasCallVmapTest(PallasBaseTest):
def setUp(self):
super().setUp()
@ -130,6 +130,26 @@ class PallasCallVmapTest(PallasTest):
out_ref = jnp.arange(1, 9).reshape((4, 2))
np.testing.assert_allclose(out, out_ref)
def test_vmap_with_hoisted_consts(self):
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(4 * 16 * 128, dtype=np.float32).reshape((4, 16, 128))
@jax.vmap
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid=(2,),
in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
)
def kernel(src, dst):
dst[0:1] = to_store
res = kernel(x)
for i in range(x.shape[0]):
self.assertAllClose(res[i, 0:1], to_store)
def test_vmap_of_kernel_with_input_output_aliases(self):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),

View File

@ -17,6 +17,7 @@
import contextlib
import functools
import io
import math
import re
import sys
from absl.testing import absltest
@ -71,7 +72,6 @@ class PallasBaseTest(jtu.JaxTestCase):
class PallasCallScalarPrefetchTest(PallasBaseTest):
def test_trivial_scalar_prefetch(self):
def body(_, x_ref, o_ref):
o_ref[...] = x_ref[...]
@ -115,6 +115,56 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
)(s, x)
np.testing.assert_array_equal(out, x)
@jtu.parameterized_filterable(
kwargs=[
dict(scratch=scratch, vmap=vmap)
for scratch in [True, False]
for vmap in [True, False]
]
)
def test_scalar_prefetch_hoisted_const(self, *, scratch: bool, vmap: bool):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")
# to_store will be hoisted as constants. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
if vmap:
x_shape = (4, 16, 128)
else:
x_shape = (16, 128)
x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)
def f(x):
s = jnp.array([1, 0], jnp.int32) # iteration 0 -> 1, iteration 1 -> 0
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(2,),
in_specs=[pl.BlockSpec((8, 128),
lambda i, s_ref: (pl.load(s_ref, (i,)), 0))],
out_specs=pl.BlockSpec((32, 128),
lambda i, s_ref: (pl.load(s_ref, i), 0)),
scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
else []),
),
)
def kernel(s_ref, src, dst, *scratch_refs):
store_idx = s_ref[pl.program_id(0)]
pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store)
return kernel(s, x)
if vmap:
f = jax.vmap(f)
res = f(x)
if vmap:
for i in range(x.shape[0]):
self.assertAllClose(res[i, 0:1], to_store)
self.assertAllClose(res[i, 33:34], to_store)
else:
self.assertAllClose(res[0:1], to_store)
self.assertAllClose(res[33:34], to_store)
def test_block_spec_with_wrong_block_shape_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]