Add a memory-saving index rewrite step to vmap with ragged inputs over pallas_call.

The approach here is to add a new notion to jax: ragged_prop. A ragged_prop rule computes the dynamism/raggedness of a primitive's outputs given the raggedness of its inputs. In the limit, if we decide that this is a useful property to have in jax as a first-class citizen, we could fold raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.
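For illustration, a ragged_prop rule maps per-input raggedness to per-output raggedness and is registered on batching.ragged_prop_rules. The sketch below mirrors the elementwise rule added in this change; my_elementwise_p is a hypothetical primitive standing in for any elementwise op.

from jax._src import core
from jax._src.interpreters import batching

# Hypothetical elementwise primitive, used only for illustration.
my_elementwise_p = core.Primitive('my_elementwise')

def _my_elementwise_ragged_prop_rule(invar_raggedness, outvars):
  # All ragged inputs to an elementwise op must share the same raggedness,
  # and every output inherits it unchanged.
  first = invar_raggedness[0]
  for other in invar_raggedness[1:]:
    if other != first:
      raise ValueError(f'{other} != {first}')
  return invar_raggedness, [first] * len(outvars)

batching.ragged_prop_rules[my_elementwise_p] = _my_elementwise_ragged_prop_rule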

PiperOrigin-RevId: 685827096
jax authors 2024-10-14 14:00:58 -07:00
parent fff3b8747f
commit 1f0b5728a4
8 changed files with 546 additions and 95 deletions
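For context, a condensed sketch of the user-facing pattern this change exercises: vmapping a pallas_call over jumble (ragged) inputs. It is adapted from the add-kernel test added below; the shapes and grid are the same illustrative values used there, and it assumes a TPU backend, as the tests do.

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import core
from jax._src.interpreters import batching
from jax.experimental import pallas as pl

row_count, col_grid_size = 8, 5
ragged_shape = [3, 1, 4]

# Ragged sizes, bounded by col_grid_size * 128.
sizes = lax.convert_element_type(
    jnp.array([128 * x for x in ragged_shape]),
    core.bint(col_grid_size * 128),
)

# Jumbles: stacks of (8, n) arrays where n varies per element.
x = jax.vmap(lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis)(sizes)
y = jax.vmap(lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis)(sizes)

def kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def invoke_kernel(x, y):
  return pl.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k)),
                pl.BlockSpec((8, 128), lambda j, k: (j, k))],
      out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
      out_shape=jax.ShapeDtypeStruct((8, col_grid_size * 128), jnp.float32),
      grid=(1, col_grid_size),
  )(x, y)

# The batching rule added here rewrites the index maps so that blocks past
# each element's ragged bound are skipped rather than computed and stored.
res = jax.vmap(
    invoke_kernel,
    in_axes=batching.jumble_axis,
    out_axes=batching.jumble_axis,
    axis_size=3,
)(x, y)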


@ -89,6 +89,17 @@ def _jumble_flatten(jumble):
aval = jumble.aval.replace(elt_ty=elt_ty)
return (lengths, jumble.data), aval
def _ragged_axis_parts(dim: RaggedAxis) -> tuple[int, int, int]:
stacked_axis = dim.stacked_axis
ragged_axes = dim.ragged_axes
if len(ragged_axes) != 1:
raise ValueError('Multiple ragged axes not yet implemented.')
ragged_axis_dim = ragged_axes[0][0]
ragged_axis_length = ragged_axes[0][1]
return stacked_axis, ragged_axis_dim, ragged_axis_length
def _jumble_unflatten(aval, x):
lengths, data = x
new_shape = [d.replace(lengths=lengths[d.lengths - 1])
@ -136,6 +147,7 @@ class RaggedAxis:
new_axes = tuple((move_axis(ax), sizes) for ax, sizes in self.ragged_axes)
return RaggedAxis(dst, new_axes)
def transpose_ragged_axes(dim: RaggedAxis, perm: tuple[int, ...]) -> RaggedAxis:
new_ragged_axes = []
for idx, old_idx in enumerate(perm):
@ -315,6 +327,43 @@ def flatten_fun_for_vmap(in_tree, *args_flat):
ans = yield py_args, py_kwargs
yield tree_flatten(ans, is_leaf=is_vmappable)
# Propagate ragged masking rules from invars to outvars
# rule([raggedness_per_invar], outvars) ->
# [raggedness_per_invar, raggedness_per_outvar]
RaggedMaskingRule = Callable[
[list[Any], list[Any]], tuple[list[Any], list[Any]]
]
ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {}
def ragged_mask_elementwise_rule(invar_raggedness, outvars):
# TODO(mvoz): A util for getting the ragged representations
first_invar_raggedness = invar_raggedness[0]
for other_invar_raggedness in invar_raggedness[1:]:
if other_invar_raggedness != first_invar_raggedness:
raise ValueError(f'{other_invar_raggedness} != {first_invar_raggedness}')
outvar_raggedness = [first_invar_raggedness] * len(outvars)
return invar_raggedness, outvar_raggedness
def ragged_mask_assert_no_op_rule(invar_raggedness, outvars):
if any(invar_raggedness):
raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}')
return invar_raggedness, [None] * len(outvars)
def ragged_mask_no_op_rule(invar_raggedness, outvars):
return invar_raggedness, [None] * len(outvars)
def ragged_mask_transfer_identity(invar_raggedness, outvar_raggedness):
assert len(invar_raggedness) == 1, invar_raggedness
outvar_raggedness = invar_raggedness
return invar_raggedness, outvar_raggedness
### tracer
# TODO(mattjj): use a special sentinel type rather than None


@ -817,6 +817,7 @@ core.custom_typechecks[cond_p] = partial(_cond_typecheck, False)
core.axis_substitution_rules[cond_p] = _cond_axis_substitution
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule
batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule
def _cond_lowering(ctx, index, *args, branches):
joined_effects = core.join_effects(*(branch.effects for branch in branches))


@ -2330,6 +2330,8 @@ def _sin_lowering(ctx, x):
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
mlir.register_lowering(sin_p, _sin_lowering)
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
def _cos_complex(x):
# cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
@ -2677,6 +2679,7 @@ add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add))
batching.ragged_prop_rules[add_p] = batching.ragged_mask_elementwise_rule
def _sub_jvp(primals, tangents):
x, y = primals
@ -2834,6 +2837,7 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True)
ad.defjvp_zero(eq_p)
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False))
batching.ragged_prop_rules[eq_p] = batching.ragged_mask_elementwise_rule
ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True)
ad.defjvp_zero(ne_p)
@ -2970,6 +2974,9 @@ pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
pe.def_trivial_padding(convert_element_type_p)
core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule
batching.ragged_prop_rules[convert_element_type_p] = (
batching.ragged_mask_elementwise_rule
)
def _real_dtype(dtype): return np.finfo(dtype).dtype
@ -3459,9 +3466,42 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
(list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch)))
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
def _dot_general_ragged_prop_rule(invar_raggedness, outvars):
assert len(invar_raggedness) == 2
assert len(outvars) == 1
invar_raggedness_lhs = invar_raggedness[0]
invar_raggedness_rhs = invar_raggedness[1]
stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs
if stacked_axis_rhs != 0 or stacked_axis_lhs != 0:
raise NotImplementedError(
'Dot general ragged prop for non 0 stacked axis, NYI'
)
# We only support a ragged k at the moment, that is, lhs is (m, ragged_k) and
# rhs is (ragged_k, n), meaning the output is dense.
if ragged_axis_dim_lhs != 2 or ragged_axis_dim_rhs != 1:
raise NotImplementedError(
'Dot general ragged prop for non contraction raggedness, NYI'
)
assert len(outvars) == 1
# TODO(mvoz): A constant on batching.* ?
dense_jumble_raggedness = None
# Dense (m, n) - no jumble only atm
return invar_raggedness, [dense_jumble_raggedness]
dot_general_p = standard_primitive(
_dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general',
sharding_rule=_dot_general_sharding_rule)
_dot_general_shape_rule,
_dot_general_dtype_rule,
'dot_general',
sharding_rule=_dot_general_sharding_rule,
)
def _dot_general_batch_unpack_args(batch_args):
@ -3494,6 +3534,7 @@ _dot_general_batch_rule = functools.partial(
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule
def precision_attr(precision: Precision) -> ir.ArrayAttr:
if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
@ -4055,6 +4096,13 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type)
def _broadcast_in_dim_ragged_prop_rule(invar_raggedness, outvars):
assert len(invar_raggedness) == 1
assert not isinstance(invar_raggedness[0], core.Var)
return invar_raggedness, [None] * len(outvars)
broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
@ -4067,6 +4115,9 @@ pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule
mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)
batching.ragged_prop_rules[broadcast_in_dim_p] = (
_broadcast_in_dim_ragged_prop_rule
)
def _clamp_shape_rule(min, operand, max):
@ -4337,6 +4388,7 @@ squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
pe.def_trivial_padding(squeeze_p)
batching.ragged_prop_rules[squeeze_p] = batching.ragged_mask_no_op_rule
def _squeeze_lower(ctx, operand, *, dimensions):
del dimensions # Implied by the output aval.
@ -4958,6 +5010,7 @@ reduce_and_p = standard_primitive(
_reduce_logical_shape_rule, _input_dtype, 'reduce_and',
weak_type_rule=_strip_weak_type)
batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule
reduce_xor_p = standard_primitive(


@ -1243,6 +1243,9 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice')
ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
# TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries
# or supporting nested jumbles. NYI.
batching.ragged_prop_rules[slice_p] = batching.ragged_mask_no_op_rule
# Override the standard impl to defer to dynamic_slice whenever possible.
# This lets us reuse the same program for many applications of slicing for as


@ -279,6 +279,10 @@ def _pallas_call_impl_interpret(
carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks])
with pallas_core.grid_env(local_grid_env):
for s in scalars:
if isinstance(s.dtype, jax_core.bint):
aval = jax_core.get_aval(s)
s.aval = aval.update(dtype=jnp.int32)
start_indices = [
None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
@ -293,10 +297,6 @@ def _pallas_call_impl_interpret(
len(blocks),
len(scratch_values),
)
for s in scalars:
aval = jax_core.get_aval(s)
if isinstance(aval, jax_core.DShapedArray):
s.aval = aval.update(dtype=jnp.int32)
blocks = jax_core.eval_jaxpr(
discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch
@ -437,10 +437,11 @@ ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
def _batch_block_mapping(
grid_mapping: GridMapping,
axis_size: int,
for_ragged: bool,
aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping,
for_ragged: bool,
ragged_axis_values,
) -> BlockMapping:
def _block_map_function(new_idx, *args):
if for_ragged:
@ -466,14 +467,7 @@ def _batch_block_mapping(
if for_ragged:
if isinstance(dim, batching.RaggedAxis):
assert for_ragged, "Ragged axis not supported for non-ragged batching."
_, _, ragged_axis_length = _ragged_axis_parts(dim)
aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
if isinstance(aval, jax_core.DShapedArray):
aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
lengths_aval = pallas_core.AbstractMemoryRef(
aval,
pallas_core.MemorySpace.INDEX,
)
_, _, _, lengths_aval = ragged_axis_values
idx_avals = [*idx_avals, lengths_aval]
else:
i32_aval_memref = pallas_core.AbstractMemoryRef(
@ -538,6 +532,12 @@ def _broadcast_input_output_aliases(
for input_index, _ in input_output_aliases:
dim = dims_[input_index]
dims_[input_index] = 0
if isinstance(dim, batching.RaggedAxis):
stacked_axis = dim.stacked_axis
if stacked_axis != 0:
raise NotImplementedError("Ragged aliasing on non 0 dim NYI")
return tuple(args_), tuple(dims_)
if dim is batching.not_mapped:
args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
elif dim != 0:
@ -643,16 +643,6 @@ def _batch_with_explicit_loop(
return result, (0,) * len(result)
def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]:
stacked_axis = dim.stacked_axis
ragged_axes = dim.ragged_axes
if len(ragged_axes) != 1:
raise ValueError("Multiple ragged axes not yet implemented.")
ragged_axis_dim = ragged_axes[0][0]
ragged_axis_length = ragged_axes[0][1]
return stacked_axis, ragged_axis_dim, ragged_axis_length
def _pallas_call_batching_rule(
args,
dims,
@ -674,21 +664,10 @@ def _pallas_call_batching_rule(
return x
return jnp.squeeze(x, axis=bdim)
all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)]
if len(all_ragged_axes) > 1:
raise ValueError("Multiple ragged dimensions not yet implemented.")
if all_ragged_axes:
stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts(
all_ragged_axes[0]
)
else:
stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None
def get_size(i, x, d):
if not isinstance(d, batching.RaggedAxis):
return x.shape[d]
return x.aval.shape[i]
return x.aval.shape[d.stacked_axis]
(axis_size,) = {
get_size(i=i, x=x, d=d)
@ -793,35 +772,49 @@ def _pallas_call_batching_rule(
args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
)
# Each dim either has data about its ragged axis, or None
ragged_axis_values = []
for d in dims:
if isinstance(d, batching.RaggedAxis):
stacked_axis, ragged_axis_dim, ragged_axis_length = (
batching._ragged_axis_parts(d)
)
aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
if isinstance(aval, jax_core.DShapedArray):
aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
lengths_aval = pallas_core.AbstractMemoryRef(
aval,
pallas_core.MemorySpace.INDEX,
)
# TODO(mvoz): Give this its own type
ragged_axis_values.append(
(stacked_axis, ragged_axis_dim, ragged_axis_length, lengths_aval)
)
else:
ragged_axis_values.append(None) # type: ignore[arg-type]
all_dims = list(dims) + [0] * grid_mapping.num_outputs
ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs # type: ignore[list-item]
num_index_operands = grid_mapping.num_index_operands
num_scratch_operands = grid_mapping.num_scratch_operands
lengths_aval = None
if ragged_axis_length is not None:
aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
if isinstance(aval, jax_core.DShapedArray):
aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
lengths_aval = pallas_core.AbstractMemoryRef(
aval,
pallas_core.MemorySpace.INDEX,
)
# Only add a batch dimension for the avals that actually have a grid mapping.
# This excludes scalar prefetch inputs (the first in the list) and scratch
# operands (the last in the list).
avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
batched_block_mappings = map(
partial(
_batch_block_mapping,
grid_mapping,
axis_size,
for_ragged=lengths_aval is not None,
any(ragged_axis_values),
),
avals_to_batch,
all_dims[num_index_operands:],
block_mappings,
ragged_axis_values[num_index_operands:],
)
index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(
@ -829,6 +822,17 @@ def _pallas_call_batching_rule(
assert not index_map_tree_kwargs
batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args
lengths_aval = None # type: ignore[assignment]
# Check all the ragged axis values, ensure their raggedness pattern
# is identical (consider moving this check up!)
for rav in ragged_axis_values:
if rav is not None:
if lengths_aval is None:
lengths_aval = rav[3]
else:
assert lengths_aval == rav[3], "NYI - different lengths in ragged batch"
if lengths_aval:
batched_index_map_args = batched_index_map_args + (lengths_aval,)
num_index_operands += 1
@ -854,6 +858,14 @@ def _pallas_call_batching_rule(
else:
batched_cost_estimate = None
# Start the ragged handling code
# Here, we:
# - Rewrite the indexer to save memory (skip indices outside the ragged bounds)
# - Rewrite the kernel to save compute (skip elements outside the ragged bounds)
# - Update various internal structures/metadata to account for the new
# block spec.
# - Set the hacky flag of ragged_originating on the mapping, to signal to
# the lowering code to treat mapped dimensions as part of the user grid.
if lengths_aval:
batched_grid_mapping = batched_grid_mapping.replace(
get_grid_indices=lambda indices, maybe_include_mapped_dims: indices,
@ -868,53 +880,102 @@ def _pallas_call_batching_rule(
# we write 0s. This is useful for debugging, as certain lowering paths
# like mosaic will write the last data as passthrough, leading to
# potentially confusing results.
debug_zero_fill_counterfactual = debug
first_block_mapping = batched_grid_mapping.block_mappings[0]
block_mapped_dim_idxs = []
for block_mapping in batched_grid_mapping.block_mappings:
# This invariant may already be checked elsewhere, but let's reaffirm it.
assert block_mapping.block_shape == first_block_mapping.block_shape, (
f"block_mapping.block_shape: {block_mapping.block_shape}, "
f"first_block_mapping.block_shape: {first_block_mapping.block_shape}"
)
assert (
block_mapping.array_shape_dtype
== first_block_mapping.array_shape_dtype
), (
f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype},"
" first_block_mapping.array_shape_dtype:"
f" {first_block_mapping.array_shape_dtype}"
)
mapped_dim_idxs = []
for i, d in enumerate(block_mapping.block_shape):
if d is pallas_core.mapped:
mapped_dim_idxs.append(i)
else:
mapped_dim_idxs.append(None) # type: ignore[arg-type]
block_mapped_dim_idxs.append(mapped_dim_idxs)
mapped_dim_idxs = [
i
for i, d in enumerate(first_block_mapping.block_shape)
if d is pallas_core.mapped
]
assert len(mapped_dim_idxs) == 1
mapped_dim_idx = mapped_dim_idxs[0]
if stacked_axis != mapped_dim_idx:
raise ValueError(
f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}"
)
mapped_dim_idx = None
for rav, mapped_dim_idxs in zip(ragged_axis_values, block_mapped_dim_idxs):
if rav is not None:
stacked_axis = rav[0]
if mapped_dim_idx is None:
mapped_dim_idx = mapped_dim_idxs[stacked_axis]
if mapped_dim_idxs[stacked_axis] is None:
raise ValueError(
f"Expected mapped dim to be {stacked_axis}, but got"
f" {mapped_dim_idxs[stacked_axis]}"
)
else:
assert mapped_dim_idx == mapped_dim_idxs[stacked_axis], (
f"Different mapped dims - expected {mapped_dim_idx}, but got"
f" {mapped_dim_idxs[stacked_axis]}"
)
assert ragged_axis_dim is not None, "Invariant violation"
# This is the blockspec size of the dimension
val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim]
block_shapes = [b.block_shape for b in batched_grid_mapping.block_mappings]
# Parse out the operations from the jaxpr to determine how to mask the output
# NOTE! While this *could* be a defaultdict of None, and None is sound, as
# it denotes that there is no raggedness for the given var, we explicitly
# do not do this, so as to get better signal on the implementation of rules.
# A misimplemented rule that does not account for new vars being introduced
# will result in an error on the next op using the new var. The benefit of
# forcing implementers to account for all outputs and intermediaries is
# a very nice one.
var_to_raggedness = {}
for invar, rav in zip(jaxpr.invars, ragged_axis_values):
var_to_raggedness[invar] = rav
for eqn in jaxpr.eqns:
prim = eqn.primitive
if prim not in batching.ragged_prop_rules:
raise NotImplementedError(f"Not implemented - ragged prop for {prim}")
rule = batching.ragged_prop_rules[prim]
invar_raggedness = [
(
var_to_raggedness.get(invar, None)
if isinstance(invar, jax_core.Var)
else None
)
for invar in eqn.invars
]
invar_raggedness, outvar_raggedness = rule(invar_raggedness, eqn.outvars)
for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment]
if isinstance(invar, jax_core.Var):
var_to_raggedness[invar] = rav
for outvar, rav in zip(eqn.outvars, outvar_raggedness):
if isinstance(outvar, jax_core.Var):
var_to_raggedness[outvar] = rav
for pos, invar in enumerate(jaxpr.invars):
ragged_axis_values[pos] = var_to_raggedness[invar]
per_input_ragged_axis_dim = []
for rav in ragged_axis_values:
if rav is not None:
per_input_ragged_axis_dim.append(rav[1])
else:
per_input_ragged_axis_dim.append(None)
def when_wrapped_kernel(lengths_ref, *args, **kwargs):
b_idx = primitives.program_id(stacked_axis)
i_idx = (
primitives.program_id(ragged_axis_dim)
* val_at_ragged_dim
)
b_idx = primitives.program_id(mapped_dim_idx)
b_len = lengths_ref[b_idx]
run_kernel = jnp.array(True)
for i, _ in enumerate(args):
ragged_axis_dim = per_input_ragged_axis_dim[i]
if ragged_axis_dim is None:
continue
arg_i_idx = (
primitives.program_id(ragged_axis_dim)
* block_shapes[i][ragged_axis_dim]
)
run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len)
# TODO(mvoz): Unimplemented primitive in pallas
# b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
# checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")
@pallas_utils.when(i_idx < b_len)
@pallas_utils.when(run_kernel)
def f():
# Important! This allows us to trace the inner kernel with the correct
# grid to preserve user program_id semantics. Ex: program_id(0) will
@ -922,20 +983,108 @@ def _pallas_call_batching_rule(
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs)
if debug_zero_fill_counterfactual:
@pallas_utils.when(i_idx >= b_len)
def g():
for arg_ref in args:
arg_ref[...] = jnp.zeros_like(arg_ref)
kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars]
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(
list(kernel_avals)
)
def _rewrite_index_jaxpr(enumerate_batched_block_mapping):
arg_pos, batched_block_mapping = enumerate_batched_block_mapping
indexer_avals = [
v.aval for v in batched_block_mapping.index_map_jaxpr.jaxpr.invars
]
flat_indexer_avals, indexer_in_tree = tree_util.tree_flatten(
list(indexer_avals)
)
def index_rewrite_kernel(*indexer_args):
ragged_axis_dim = per_input_ragged_axis_dim[arg_pos]
# The problem here seems to be that we are running this for all inputs, per
# input, because they each have an indexer - which means that the indexer for
# the output isn't getting written - before, it always was.
lengths_ref = indexer_args[-1]
rest_indexer_args = indexer_args[:-1]
# Lengths are always the last argument of the indexer.
# Invariant: Stacked axis is enforced to be the mapped axis above.
b_idx = indexer_args[mapped_dim_idx]
nargs = list(rest_indexer_args)
if ragged_axis_dim is not None:
val_at_ragged_dim = batched_block_mapping.block_shape[ragged_axis_dim]
# The current index into the ragged dimension.
# Invariant: There is only one ragged dimension, enforced above.
i_idx = indexer_args[ragged_axis_dim]
# grid space -> element space
i_len = i_idx * val_at_ragged_dim
# The length of the current batch.
b_len = lengths_ref[b_idx]
# Have we reached the end of the current batch?
not_done = i_len < b_len
am_last_batch = b_idx == axis_size - 1
last_good_block = lax.div(b_len, val_at_ragged_dim) - 1
# The logic below can be thought of as:
# if index_oob_ragged:
# if not last_batch:
# batch_idx += 1
# ragged_idx = 0
# else:
# ragged_idx = last_good_block
#
# wherein we find the next good block by incrementing the batch index
# and setting the ragged index to 0 if we are not in the last batch.
# Otherwise, we set the ragged index to the last good block.
b_next = jnp.where(
not_done, b_idx, jnp.where(am_last_batch, b_idx, b_idx + 1)
)
i_next = jnp.where(
not_done, i_idx, jnp.where(am_last_batch, last_good_block, 0)
)
nargs[ragged_axis_dim] = i_next
nargs[mapped_dim_idx] = b_next
nargs = nargs + [lengths_ref]
return jax_core.eval_jaxpr(
batched_block_mapping.index_map_jaxpr.jaxpr,
batched_block_mapping.index_map_jaxpr.consts,
*nargs,
)
index_jaxpr, _ = _trace_kernel_to_jaxpr(
index_rewrite_kernel,
"index_rewrite_kernel",
batched_grid_mapping,
tuple(flat_indexer_avals),
indexer_in_tree,
tuple(() for _ in flat_indexer_avals),
interpret=interpret,
indexer=True,
)
batched_block_mapping = batched_block_mapping.replace(
index_map_jaxpr=pe.close_jaxpr(index_jaxpr)
)
return batched_block_mapping
# Important! This allows us to trace the outer kernel with the correct grid
# to enable accessing the batch program_id.
with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
batched_block_mappings = map(
_rewrite_index_jaxpr, enumerate(batched_block_mappings)
)
batched_grid_mapping = batched_grid_mapping.replace(
block_mappings=tuple(batched_block_mappings),
)
kernel_src_info: pallas_core.SrcInfoStr = "<Wrapped outer kernel>"
jaxpr, consts = _trace_kernel_to_jaxpr(
@ -950,7 +1099,15 @@ def _pallas_call_batching_rule(
if consts:
raise NotImplementedError("consts not supported in pallas_call")
assert ragged_axis_length is not None
# We need to rewrite the input_output_aliases here: the initial call
# to broadcast is done, and we have inserted a new input (lengths), so
# there's an off-by-one here now.
new_input_output_aliases = []
for k, v in input_output_aliases:
new_input_output_aliases.append((k + 1, v))
input_output_aliases = tuple(new_input_output_aliases)
# assert ragged_axis_length is not None
args = (ragged_axis_length, *args)
assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals)
batched_out_avals = tuple(
@ -1225,7 +1382,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
return new_error, results
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
# All of those shenanigans are because we can't make TransformedRef a PyTree,
# because they should appear as atomic JAX values to the users.
@lu.transformation
@ -1247,6 +1403,7 @@ def _trace_kernel_to_jaxpr(
kernel_in_tree: tree_util.PyTreeDef,
kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
interpret: bool,
indexer: bool = False,
) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
if interpret:
kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
@ -1268,7 +1425,7 @@ def _trace_kernel_to_jaxpr(
"You should pass them as inputs")
kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
if not indexer and kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
f"The kernel function in the pallas_call {name_and_src_info} "
f"should return None. It returns a PyTree: {kernel_out_tree}")
@ -1673,6 +1830,18 @@ def pallas_call(
f"[0, {len(flat_out_avals)})")
in_aval = flat_in_avals[i_idx]
out_aval = flat_out_avals[o_idx]
if isinstance(in_aval, jax_core.DShapedArray):
new_shape = []
for d in in_aval.shape:
if isinstance(d, int):
new_shape.append(d)
else:
new_shape.append(d.dtype.bound)
in_aval = jax_core.ShapedArray(
tuple(new_shape), in_aval.dtype, in_aval.weak_type
)
if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype:
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "


@ -51,6 +51,7 @@ map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
program_id_p = jax_core.Primitive("program_id")
batching.ragged_prop_rules[program_id_p] = batching.ragged_mask_no_op_rule
def program_id(axis: int) -> jax.Array:
"""Returns the kernel execution position along the given axis of the grid.


@ -63,6 +63,7 @@ traceback_util.register_exclusion(__file__)
# a:f32[3] <- x[]
get_p = core.Primitive("get")
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
batching.ragged_prop_rules[get_p] = batching.ragged_mask_transfer_identity
Indexer = Union[int, slice, Array, types.EllipsisType]
@ -122,6 +123,16 @@ swap_p = core.Primitive("swap")
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))
def swap_ragged_prop_rule(invar_raggedness, outvars):
assert len(invar_raggedness) == 2
invar_raggedness_lhs = invar_raggedness[0]
invar_raggedness_rhs = invar_raggedness[1]
return [invar_raggedness_rhs, invar_raggedness_lhs], [None]
batching.ragged_prop_rules[swap_p] = swap_ragged_prop_rule
def ref_swap(
ref_or_view: AbstractRef | TransformedRef,
idx: Indexer | tuple[Indexer, ...] | None,


@ -111,8 +111,81 @@ class PallasCallRaggedVmapTest(PallasBaseTest):
ragged_total = 0
for dim in ragged_shape:
ragged_total += row_count * dim * 128
# See note - on zero filling counterfactuals
self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total)
def correct(v):
return np.count_nonzero(v == jnp.sin(1.0))
for b, batch in enumerate(res):
ragged_val = ragged_shape[b]
for r, row in enumerate(batch):
row_total = ragged_val * 128
self.assertEqual(correct(row), row_total, msg=f"row {r}, : {row}")
self.assertEqual(correct(res), ragged_total)
def test_vmap_jumble_over_add_kernel(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Only tested on TPU")
row_count = 8
col_grid_size = 5
ragged_shape = [3, 1, 4]
sizes = lax.convert_element_type(
jnp.array([128 * x for x in ragged_shape]),
core.bint(col_grid_size * 128),
)
x = jax.vmap(
lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
)(sizes)
y = jax.vmap(
lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
)(sizes)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[...]
def invoke_kernel(x, y):
return pl.pallas_call(
kernel,
in_specs=[
pl.BlockSpec((8, 128), lambda j, k: (j, k)),
pl.BlockSpec((8, 128), lambda j, k: (j, k)),
],
out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
out_shape=jax.ShapeDtypeStruct(
(8, col_grid_size * 128), dtype=jnp.float32
),
grid=(1, col_grid_size),
interpret=False,
)(x, y)
res = jax.vmap(
invoke_kernel,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x, y)
res = res.data
total = len(ragged_shape) * row_count * col_grid_size * 128
res_total = np.prod(res.shape)
self.assertEqual(res_total, total)
ragged_total = 0
for dim in ragged_shape:
ragged_total += row_count * dim * 128
def correct(v):
return np.count_nonzero(v == 2.0)
for r, row in enumerate(res):
ragged_val = ragged_shape[r]
row_total = ragged_val * 128 * row_count
self.assertEqual(correct(row), row_total)
for col in row:
col_total = ragged_val * 128
self.assertEqual(correct(col), col_total)
self.assertEqual(np.count_nonzero(res == 2.0), ragged_total)
def test_vmap_jumble_over_sin_kernel_grid_remapping(self):
if not jtu.test_device_matches(["tpu"]):
@ -150,6 +223,97 @@ class PallasCallRaggedVmapTest(PallasBaseTest):
axis_size=3,
)(x)
def test_vmap_jumble_over_matmul_kernel(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Only tested on TPU")
m = 128
k = 640
n = 640
def matmul_kernel(x_ref, y_ref, x_sentinel, z_ref):
# weird little once-only reset
@pl.when(x_sentinel[...][0][0] == 1.0)
def _():
z_ref[...] = jnp.zeros_like(z_ref)
x_sentinel[...] = jnp.zeros_like(x_sentinel)
z_ref[...] += x_ref[...] @ y_ref[...]
def matmul(
x: jax.Array,
y: jax.Array,
x_sentinel: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 640,
):
# m, k = x.shape
# _, n = y.shape
# a (1, 5) grid
# TODO(mvoz): parameterize this grid?
grid = (n // bn, k // bk)
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
in_specs=[
pl.BlockSpec(
(bm, bk),
lambda j, k: (0, k),
),
pl.BlockSpec(
(bk, bn),
lambda j, k: (k, j),
),
pl.BlockSpec(
(bm, bn),
lambda j, k: (0, j),
),
],
out_specs=pl.BlockSpec(
(bm, bn),
lambda j, k: (0, j),
),
grid=grid,
input_output_aliases={2: 0},
interpret=False,
)(x, y, x_sentinel)
# TODO(mvoz): parameterize this shape?
ragged_shape = [3, 1, 4]
sizes = lax.convert_element_type(
jnp.array([128 * x for x in ragged_shape]),
core.bint(k),
)
x = jax.vmap(lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis)(
sizes
)
x_sentinel = jax.vmap(
lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis
)(sizes)
y = jax.vmap(lambda k_: jnp.ones((k_, n)), out_axes=batching.jumble_axis)(
sizes
)
res = jax.vmap(
matmul,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x, y, x_sentinel)
ref = jax.vmap(
jnp.dot,
out_axes=batching.jumble_axis,
in_axes=batching.jumble_axis,
axis_size=3,
)(x, y)
ref = ref.data
res = res.data
np.testing.assert_allclose(ref, res)
def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Only tested on TPU")