mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Implements an alternate version of ragged_attention in which the attention kernel itself is dense: the kernel does not perform the compute saving (via a @when-wrapped kernel) or the prefetch/index skipping (via index rewriting) itself. Instead, the kernel is invoked with a Jumble (a ragged type representation), and Pallas takes care of applying the correct work skipping and index rewriting.
Performance-wise we should be at parity, although this has not yet been tested. Authoring-wise, the new kernel is significantly smaller and simpler to write. A major known limitation of this approach is the invariant that `seq_len % grid_size == 0`; we plan to relax this limitation in follow-up CLs. PiperOrigin-RevId: 689868468
This commit is contained in:
parent 5afdbcbae7
commit 6f371212d9
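For context, a minimal illustrative sketch (not the kernel from this change; the names and shapes are assumptions) of what a purely dense attention body looks like when it contains no masking or index-skipping logic of its own; under the Jumble-based approach described above, the surrounding batching machinery decides how much work each ragged batch member actually runs:

import jax
import jax.numpy as jnp

def dense_attention(q, k, v):
  # q, k, v: [seq_len, head_dim]. No pl.when guards and no index rewriting
  # here; with the Jumble-based approach, work skipping happens outside the
  # kernel.
  scores = q @ k.T / jnp.sqrt(q.shape[-1])
  return jax.nn.softmax(scores, axis=-1) @ v

# Ordinary dense invocation (illustrative shapes only).
q = k = v = jnp.ones((128, 64))
out = dense_attention(q, k, v)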
@@ -264,10 +264,17 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
    return (BatchTracer(trace, x, spec, source_info_util.current())
            if spec is not None else x)
  else:
    if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis):
      # TODO(mvoz): A vaguely questionable assumption that it is always
      # sound to have a 0 axis here. This is true for the current use cases
      # and comes from how we handle intermediary products of jumbles in
      # vmap.
      return BatchTracer(trace, x, 0, source_info_util.current())
    # TODO(mvoz): This is a terrible place to fall into if you pass
    # a non jumble type in, make it clearer what went wrong.
    assert False, f'Unexpected type in ELT? {type(x)}'

to_elt_handlers: dict[type, ToEltHandler] = {}

def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
@@ -328,16 +335,16 @@ def flatten_fun_for_vmap(in_tree, *args_flat):
  yield tree_flatten(ans, is_leaf=is_vmappable)


# Propagate ragged masking rules from invars to outvars
-# rule([raggedness_per_invar], outvars) ->
+# rule([params], [raggedness_per_invar], outvars) ->
#   [raggedness_per_invar, raggedness_per_outvar]
RaggedMaskingRule = Callable[
-    [list[Any], list[Any]], tuple[list[Any], list[Any]]
+    [list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]]
]

ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {}


-def ragged_mask_elementwise_rule(invar_raggedness, outvars):
+def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars):
  # TODO(mvoz): A util for getting the ragged representations
  first_invar_raggedness = invar_raggedness[0]
  for other_invar_raggedness in invar_raggedness[1:]:
@@ -348,17 +355,19 @@ def ragged_mask_elementwise_rule(invar_raggedness, outvars):
  return invar_raggedness, outvar_raggedness


-def ragged_mask_assert_no_op_rule(invar_raggedness, outvars):
+def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars):
  if any(invar_raggedness):
    raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}')
  return invar_raggedness, [None] * len(outvars)


-def ragged_mask_no_op_rule(invar_raggedness, outvars):
+def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars):
  return invar_raggedness, [None] * len(outvars)


-def ragged_mask_transfer_identity(invar_raggedness, outvar_raggedness):
+def ragged_mask_transfer_identity(
+    eqn_params, invar_raggedness, outvar_raggedness
+):
  assert len(invar_raggedness) == 1, invar_raggedness
  outvar_raggedness = invar_raggedness
  return invar_raggedness, outvar_raggedness

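As the hunks above show, ragged masking rules now receive the equation params as a leading argument. A minimal sketch of a custom rule written against the new three-argument signature (the rule name, primitive, and registration line are hypothetical, not part of this change):

# Hypothetical example: an elementwise-style rule that ignores the params and
# forwards raggedness from inputs to outputs.
def my_ragged_mask_rule(eqn_params, invar_raggedness, outvars):
  del eqn_params  # unused here
  return invar_raggedness, [invar_raggedness[0]] * len(outvars)

# Registration would follow the same pattern used throughout this change, e.g.
# batching.ragged_prop_rules[my_primitive_p] = my_ragged_mask_rule
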
@@ -2296,6 +2296,7 @@ mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
+batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule

exp2_p = standard_unop(_float | _complex, 'exp2')
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
@@ -2746,6 +2747,7 @@ sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract))
+batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule


def _mul_transpose(ct, x, y):
@@ -2767,6 +2769,7 @@ ad.defjvp(mul_p,
          lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply))
+batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule

def _div_transpose_rule(cotangent, x, y):
  assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@@ -2780,6 +2783,7 @@ ad.defjvp(div_p,
          lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)))
ad.primitive_transposes[div_p] = _div_transpose_rule
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide))
+batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule

rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
ad.defjvp(
@@ -2803,12 +2807,14 @@ ad.defjvp2(max_p,
           lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
           lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo))
+batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule

min_p: core.Primitive = standard_naryop([_any, _any], 'min')
ad.defjvp2(min_p,
           lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
           lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
+batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule

shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)
@@ -2895,6 +2901,7 @@ mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE", False))
lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt')
ad.defjvp_zero(lt_p)
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False))
+batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule

eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to')
ad.defjvp_zero(eq_to_p)
@@ -3536,12 +3543,37 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
  return core._pp_eqn(eqn.replace(params=printed_params), context, settings)


-def _dot_general_ragged_prop_rule(invar_raggedness, outvars):
+def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
  assert len(invar_raggedness) == 2
  assert len(outvars) == 1
  invar_raggedness_lhs = invar_raggedness[0]
  invar_raggedness_rhs = invar_raggedness[1]

  dimension_numbers = eqn_params['dimension_numbers']
  (lhs_contracting, rhs_contracting), (_, _) = dimension_numbers

  if not invar_raggedness_lhs and not invar_raggedness_rhs:
    # Both are dense - it is valid to reach here, because dense operations
    # are legal in code running under ragged prop.
    return invar_raggedness, [None]

  if not invar_raggedness_lhs or not invar_raggedness_rhs:
    # One ragged, one dense
    if not invar_raggedness_lhs:
      # left is dense, right is ragged
      _, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs
      if rhs_contracting != ragged_axis_dim_rhs:
        # Contraction is on a dense dimension, this is valid!
        return invar_raggedness, [None]
    if not invar_raggedness_rhs:
      # left is ragged, right is dense
      _, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
      if lhs_contracting != ragged_axis_dim_lhs:
        # Contraction is on a dense dimension, this is valid!
        return invar_raggedness, [None]

    raise NotImplementedError('NYI - dense and ragged dim contraction')

  stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
  stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs

@@ -3560,9 +3592,8 @@ def _dot_general_ragged_prop_rule(invar_raggedness, outvars):
  assert len(outvars) == 1

  # TODO(mvoz): A constant on batching.* ?
-  dense_jumble_raggedness = None
  # Dense (m, n) - no jumble only atm
-  return invar_raggedness, [dense_jumble_raggedness]
+  return invar_raggedness, [None]


dot_general_p = standard_primitive(
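A rough illustration of the decision the dot_general rule above encodes, with simplified, assumed inputs (a single contracting dimension per operand and a ragged-dimension index, or None for a dense operand); this is a sketch of the branch structure, not the actual rule:

# Illustrative only: contraction over a dense dimension is allowed; contracting
# a dense operand against a ragged dimension is not yet supported.
def contraction_is_supported(lhs_ragged_dim, rhs_ragged_dim,
                             lhs_contracting_dim, rhs_contracting_dim):
  if lhs_ragged_dim is None and rhs_ragged_dim is None:
    return True  # both dense: always fine
  if lhs_ragged_dim is None or rhs_ragged_dim is None:
    # One ragged, one dense: fine as long as the contraction does not hit the
    # ragged dimension.
    if rhs_ragged_dim is not None:
      return rhs_contracting_dim != rhs_ragged_dim
    return lhs_contracting_dim != lhs_ragged_dim
  return True  # both ragged: handled by the stacked/ragged-axis checks instead
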
@@ -4205,7 +4236,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
  return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type)


-def _broadcast_in_dim_ragged_prop_rule(invar_raggedness, outvars):
+def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
  assert len(invar_raggedness) == 1
  assert not isinstance(invar_raggedness[0], core.Var)
  return invar_raggedness, [None] * len(outvars)
@@ -5040,6 +5071,7 @@ ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
                                         _get_sum_identity)
+batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule

def _reduce_prod_jvp_rule(primals, tangents, *, axes):
  reducer = lambda x, y: [mul(x, y)]
@@ -5074,6 +5106,7 @@ ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p, _get_max_identity)
pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
                                         _get_max_identity)
+batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule


reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
@@ -5854,9 +5887,11 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding):
  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
  return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)


iota_p = Primitive('iota')
iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
+batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule

def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding):
  params = dict(dtype=dtype, shape=shape, dimension=dimension,

@@ -945,7 +945,15 @@ def _pallas_call_batching_rule(
          )
          for invar in eqn.invars
      ]
-      invar_raggedness, outvar_raggedness = rule(invar_raggedness, eqn.outvars)
+      try:
+        invar_raggedness, outvar_raggedness = rule(
+            eqn.params, invar_raggedness, eqn.outvars  # type: ignore[arg-type]
+        )
+      except Exception as e:
+        raise RuntimeError(
+            f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:"
+            f" {eqn.outvars}. Underlying reason: {e}"
+        ) from e

      for invar, rav in zip(eqn.invars, invar_raggedness):  # type: ignore[assignment]
        if isinstance(invar, jax_core.Var):

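The hunk above follows a per-primitive rule-dispatch pattern, invoking each rule with the equation's params and wrapping the call so failures report which primitive and variables were involved. A generic sketch of that pattern in isolation (the names here are illustrative, not JAX internals):

# Illustrative dispatch loop, not the actual pallas_call batching code.
# `rules` maps a primitive to a rule with the (params, invar_raggedness,
# outvars) signature introduced in this change.
def apply_ragged_rules(rules, equations):
  results = []
  for prim, params, invars, outvars in equations:
    rule = rules.get(prim)
    if rule is None:
      raise NotImplementedError(f'no ragged prop rule for {prim}')
    try:
      invar_info, outvar_info = rule(params, invars, outvars)
    except Exception as e:
      # Re-raise with enough context to identify the failing equation.
      raise RuntimeError(
          f'Failed to run rule for {prim}. invars: {invars}, outvars:'
          f' {outvars}. Underlying reason: {e}'
      ) from e
    results.append((invar_info, outvar_info))
  return results
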
@@ -1984,6 +1984,7 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type,

batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None)
+batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule

def _pjit_batcher_for_sharding(
    s: sharding.Sharding | UnspecifiedValue,

@@ -127,7 +127,7 @@ swap_p = core.Primitive("swap")
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))


-def swap_ragged_prop_rule(invar_raggedness, outvars):
+def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
  assert len(invar_raggedness) == 2
  invar_raggedness_lhs = invar_raggedness[0]
  invar_raggedness_rhs = invar_raggedness[1]