Implements an alternate version of ragged_attention wherein the actual attention kernel itself is dense. This means the kernel does not have the compute saving (@when wrapped kernel) or prefetch/index skipping (via index rewriting) as part of the kernel. Rather, the kernel is invoked with a Jumble (a ragged type representation) and pallas takes care of applying the correct work skipping and index rewriting.

Performance-wise, we should be at parity, although this has not yet been tested.

Authoring-wise, the new kernel is significantly smaller and simpler to write.

A major known limitation of this approach, which we have a plan to fix, is the invariant that `seq_len % grid_size == 0` - we plan to relax this limitation in following CLs.

PiperOrigin-RevId: 689868468
This commit is contained in:
jax authors 2024-10-25 12:06:59 -07:00
parent 5afdbcbae7
commit 6f371212d9
5 changed files with 65 additions and 12 deletions

View File

@ -264,10 +264,17 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
return (BatchTracer(trace, x, spec, source_info_util.current())
if spec is not None else x)
else:
if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis):
# TODO(mvoz): A vaguely questionable assumption that it is always
# sound to have a 0 axis here. This is true for the current use cases
# and comes from how we handle intermediary products of jumbles in
# vmap.
return BatchTracer(trace, x, 0, source_info_util.current())
# TODO(mvoz): This is a terrible place to fall into if you pass
# a non jumble type in, make it clearer what went wrong.
assert False, f'Unexpected type in ELT? {type(x)}'
to_elt_handlers: dict[type, ToEltHandler] = {}
def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
@ -328,16 +335,16 @@ def flatten_fun_for_vmap(in_tree, *args_flat):
yield tree_flatten(ans, is_leaf=is_vmappable)
# Propagate ragged masking rules from invars to outvars
# rule([raggedness_per_invar], outvars) ->
# rule([params], [raggedness_per_invar], outvars) ->
# [raggedness_per_invar, raggedness_per_outvar]
RaggedMaskingRule = Callable[
[list[Any], list[Any]], tuple[list[Any], list[Any]]
[list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]]
]
ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {}
def ragged_mask_elementwise_rule(invar_raggedness, outvars):
def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars):
# TODO(mvoz): A util for getting the ragged representations
first_invar_raggedness = invar_raggedness[0]
for other_invar_raggedness in invar_raggedness[1:]:
@ -348,17 +355,19 @@ def ragged_mask_elementwise_rule(invar_raggedness, outvars):
return invar_raggedness, outvar_raggedness
def ragged_mask_assert_no_op_rule(invar_raggedness, outvars):
def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars):
  """Ragged-prop rule that requires every input to be dense (non-ragged).

  Raises ValueError if any entry of `invar_raggedness` is truthy; otherwise
  forwards the (all-falsy) input raggedness and marks every outvar dense.
  """
  del eqn_params  # Unused: this rule does not inspect the equation's params.
  ragged_inputs = [r for r in invar_raggedness if r]
  if ragged_inputs:
    raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}')
  return invar_raggedness, [None for _ in outvars]
def ragged_mask_no_op_rule(invar_raggedness, outvars):
def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars):
  """Ragged-prop rule that forwards input raggedness and emits dense outvars."""
  del eqn_params  # Unused.
  dense_outputs = [None for _ in outvars]
  return invar_raggedness, dense_outputs
def ragged_mask_transfer_identity(invar_raggedness, outvar_raggedness):
def ragged_mask_transfer_identity(
    eqn_params, invar_raggedness, outvar_raggedness
):
  """Transfers the single invar's raggedness straight onto the outvar.

  Note: the incoming `outvar_raggedness` argument is intentionally ignored;
  the invar's raggedness is returned in its place for both positions.
  """
  del eqn_params, outvar_raggedness  # Unused.
  assert len(invar_raggedness) == 1, invar_raggedness
  return invar_raggedness, invar_raggedness

View File

@ -2296,6 +2296,7 @@ mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule
exp2_p = standard_unop(_float | _complex, 'exp2')
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
@ -2746,6 +2747,7 @@ sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract))
batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule
def _mul_transpose(ct, x, y):
@ -2767,6 +2769,7 @@ ad.defjvp(mul_p,
lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply))
batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule
def _div_transpose_rule(cotangent, x, y):
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@ -2780,6 +2783,7 @@ ad.defjvp(div_p,
lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)))
ad.primitive_transposes[div_p] = _div_transpose_rule
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide))
batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule
rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
ad.defjvp(
@ -2803,12 +2807,14 @@ ad.defjvp2(max_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo))
batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule
min_p: core.Primitive = standard_naryop([_any, _any], 'min')
ad.defjvp2(min_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)
@ -2895,6 +2901,7 @@ mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE", False))
lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt')
ad.defjvp_zero(lt_p)
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False))
batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule
eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to')
ad.defjvp_zero(eq_to_p)
@ -3536,12 +3543,37 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
def _dot_general_ragged_prop_rule(invar_raggedness, outvars):
def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
assert len(invar_raggedness) == 2
assert len(outvars) == 1
invar_raggedness_lhs = invar_raggedness[0]
invar_raggedness_rhs = invar_raggedness[1]
dimension_numbers = eqn_params['dimension_numbers']
(lhs_contracting, rhs_contracting), (_, _) = dimension_numbers
if not invar_raggedness_lhs and not invar_raggedness_rhs:
# Both are dense - it is valid to reach here, because dense operations
# are legal in code running under ragged prop.
return invar_raggedness, [None]
if not invar_raggedness_lhs or not invar_raggedness_rhs:
# One ragged, one dense
if not invar_raggedness_lhs:
# left is dense, right is ragged
_, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs
if rhs_contracting != ragged_axis_dim_rhs:
# Contraction is on a dense dimension, this is valid!
return invar_raggedness, [None]
if not invar_raggedness_rhs:
# left is ragged, right is dense
_, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
if lhs_contracting != ragged_axis_dim_lhs:
# Contraction is on a dense dimension, this is valid!
return invar_raggedness, [None]
raise NotImplementedError('NYI - dense and ragged dim contraction')
stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs
@ -3560,9 +3592,8 @@ def _dot_general_ragged_prop_rule(invar_raggedness, outvars):
assert len(outvars) == 1
# TODO(mvoz): A constant on batching.* ?
dense_jumble_raggedness = None
# Dense (m, n) - no jumble only atm
return invar_raggedness, [dense_jumble_raggedness]
return invar_raggedness, [None]
dot_general_p = standard_primitive(
@ -4205,7 +4236,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type)
def _broadcast_in_dim_ragged_prop_rule(invar_raggedness, outvars):
def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
  """Ragged-prop rule for broadcast_in_dim: pass input raggedness through.

  Accepts exactly one input, whose raggedness must already be resolved
  (i.e. not still a symbolic core.Var). All outputs are marked dense.
  """
  del eqn_params  # Unused.
  assert len(invar_raggedness) == 1
  assert not isinstance(invar_raggedness[0], core.Var)
  return invar_raggedness, [None for _ in outvars]
@ -5040,6 +5071,7 @@ ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
_get_sum_identity)
batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
reducer = lambda x, y: [mul(x, y)]
@ -5074,6 +5106,7 @@ ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p, _get_max_identity)
pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
_get_max_identity)
batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule
reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
@ -5854,9 +5887,11 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding):
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)
iota_p = Primitive('iota')
iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule
def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding):
params = dict(dtype=dtype, shape=shape, dimension=dimension,

View File

@ -945,7 +945,15 @@ def _pallas_call_batching_rule(
)
for invar in eqn.invars
]
invar_raggedness, outvar_raggedness = rule(invar_raggedness, eqn.outvars)
try:
invar_raggedness, outvar_raggedness = rule(
eqn.params, invar_raggedness, eqn.outvars # type: ignore[arg-type]
)
except Exception as e:
raise RuntimeError(
f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:"
f" {eqn.outvars}. Underlying reason: {e}"
) from e
for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment]
if isinstance(invar, jax_core.Var):

View File

@ -1984,6 +1984,7 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type,
batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None)
batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule
def _pjit_batcher_for_sharding(
s: sharding.Sharding | UnspecifiedValue,

View File

@ -127,7 +127,7 @@ swap_p = core.Primitive("swap")
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))
def swap_ragged_prop_rule(invar_raggedness, outvars):
def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
assert len(invar_raggedness) == 2
invar_raggedness_lhs = invar_raggedness[0]
invar_raggedness_rhs = invar_raggedness[1]