From 6f371212d972a2017fb58e621268e446d33e3235 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 25 Oct 2024 12:06:59 -0700 Subject: [PATCH] Implements an alternate version of ragged_attention, wherein, the actual attention kernel itself is dense. Meaning, this kernel does not have the compute saving (@when wrapped kernel) or prefetch/index skipping (via index rewriting) as part of the kernel. Rather, the kernel is invoked with a Jumble (A ragged type representation) and pallas takes care of applying the correct work skipping and index rewriting. Performance wise, we should be at parity, although this has not yet been tested. Authoring wise, the new kernel is significantly smaller and simpler to write. A major known limitation of this approach, which we have a plan to fix, is the invariant that the `seq_len % grid_size == 0` - we plan to relax this limitation in following CLs. PiperOrigin-RevId: 689868468 --- jax/_src/interpreters/batching.py | 21 ++++++++++----- jax/_src/lax/lax.py | 43 ++++++++++++++++++++++++++++--- jax/_src/pallas/pallas_call.py | 10 ++++++- jax/_src/pjit.py | 1 + jax/_src/state/primitives.py | 2 +- 5 files changed, 65 insertions(+), 12 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index eb174cc5c..b40a3807d 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -264,10 +264,17 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: + if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis): + # TODO(mvoz): A vaguely questionable assumption that it is always + # sound to have a 0 axis here. This is true for the current use cases + # and comes from how we handle intermediary products of jumbles in + # vmap. + return BatchTracer(trace, x, 0, source_info_util.current()) # TODO(mvoz): This is a terrible place to fall into if you pass # a non jumble type in, make it clearer what went wrong. assert False, f'Unexpected type in ELT? {type(x)}' + to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, @@ -328,16 +335,16 @@ def flatten_fun_for_vmap(in_tree, *args_flat): yield tree_flatten(ans, is_leaf=is_vmappable) # Propagate ragged masking rules from invars to outvars -# rule([raggedness_per_invar], outvars) -> +# rule([params], [raggedness_per_invar], outvars) -> # [raggedness_per_invar, raggedness_per_outvar] RaggedMaskingRule = Callable[ - [list[Any], list[Any]], tuple[list[Any], list[Any]] + [list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]] ] ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {} -def ragged_mask_elementwise_rule(invar_raggedness, outvars): +def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars): # TODO(mvoz): A util for getting the ragged representations first_invar_raggedness = invar_raggedness[0] for other_invar_raggedness in invar_raggedness[1:]: @@ -348,17 +355,19 @@ def ragged_mask_elementwise_rule(invar_raggedness, outvars): return invar_raggedness, outvar_raggedness -def ragged_mask_assert_no_op_rule(invar_raggedness, outvars): +def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars): if any(invar_raggedness): raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}') return invar_raggedness, [None] * len(outvars) -def ragged_mask_no_op_rule(invar_raggedness, outvars): +def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars): return invar_raggedness, [None] * len(outvars) -def ragged_mask_transfer_identity(invar_raggedness, outvar_raggedness): +def ragged_mask_transfer_identity( + eqn_params, invar_raggedness, outvar_raggedness +): assert len(invar_raggedness) == 1, invar_raggedness outvar_raggedness = invar_raggedness return invar_raggedness, outvar_raggedness diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ebcc5aac4..9ed1d55cd 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2296,6 +2296,7 @@ mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite)) exp_p = standard_unop(_float | _complex, 'exp') ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) +batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule exp2_p = standard_unop(_float | _complex, 'exp2') ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans))) @@ -2746,6 +2747,7 @@ sub_p = standard_naryop([_num, _num], 'sub') ad.primitive_jvps[sub_p] = _sub_jvp ad.primitive_transposes[sub_p] = _sub_transpose mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract)) +batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule def _mul_transpose(ct, x, y): @@ -2767,6 +2769,7 @@ ad.defjvp(mul_p, lambda ydot, x, y: mul(x, ydot)) ad.primitive_transposes[mul_p] = _mul_transpose mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply)) +batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule def _div_transpose_rule(cotangent, x, y): assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y) @@ -2780,6 +2783,7 @@ ad.defjvp(div_p, lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2))) ad.primitive_transposes[div_p] = _div_transpose_rule mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide)) +batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule rem_p = standard_naryop([_int | _float, _int | _float], 'rem') ad.defjvp( @@ -2803,12 +2807,14 @@ ad.defjvp2(max_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo)) +batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule min_p: core.Primitive = standard_naryop([_any, _any], 'min') ad.defjvp2(min_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo)) +batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule shift_left_p = standard_naryop([_int, _int], 'shift_left') ad.defjvp_zero(shift_left_p) @@ -2895,6 +2901,7 @@ mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE", False)) lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt') ad.defjvp_zero(lt_p) mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False)) +batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to') ad.defjvp_zero(eq_to_p) @@ -3536,12 +3543,37 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: return core._pp_eqn(eqn.replace(params=printed_params), context, settings) -def _dot_general_ragged_prop_rule(invar_raggedness, outvars): +def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 2 assert len(outvars) == 1 invar_raggedness_lhs = invar_raggedness[0] invar_raggedness_rhs = invar_raggedness[1] + dimension_numbers = eqn_params['dimension_numbers'] + (lhs_contracting, rhs_contracting), (_, _) = dimension_numbers + + if not invar_raggedness_lhs and not invar_raggedness_rhs: + # Both are dense - it is valid to reach here, because dense operations + # are legal in code running under ragged prop. + return invar_raggedness, [None] + + if not invar_raggedness_lhs or not invar_raggedness_rhs: + # One ragged, one dense + if not invar_raggedness_lhs: + # left is dense, right is ragged + _, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs + if rhs_contracting != ragged_axis_dim_rhs: + # Contraction is on a dense dimension, this is valid! + return invar_raggedness, [None] + if not invar_raggedness_rhs: + # left is ragged, right is dense + _, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs + if lhs_contracting != ragged_axis_dim_lhs: + # Contraction is on a dense dimension, this is valid! + return invar_raggedness, [None] + + raise NotImplementedError('NYI - dense and ragged dim contraction') + stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs @@ -3560,9 +3592,8 @@ def _dot_general_ragged_prop_rule(invar_raggedness, outvars): assert len(outvars) == 1 # TODO(mvoz): A constant on batching.* ? - dense_jumble_raggedness = None # Dense (m, n) - no jumble only atm - return invar_raggedness, [dense_jumble_raggedness] + return invar_raggedness, [None] dot_general_p = standard_primitive( @@ -4205,7 +4236,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type) -def _broadcast_in_dim_ragged_prop_rule(invar_raggedness, outvars): +def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 1 assert not isinstance(invar_raggedness[0], core.Var) return invar_raggedness, [None] * len(outvars) @@ -5040,6 +5071,7 @@ ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum, _get_sum_identity) +batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule def _reduce_prod_jvp_rule(primals, tangents, *, axes): reducer = lambda x, y: [mul(x, y)] @@ -5074,6 +5106,7 @@ ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max, _get_max_identity) +batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype, @@ -5854,9 +5887,11 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) + iota_p = Primitive('iota') iota_p.def_impl(partial(dispatch.apply_primitive, iota_p)) iota_p.def_abstract_eval(_iota_abstract_eval) +batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding): params = dict(dtype=dtype, shape=shape, dimension=dimension, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 2bed4a083..e20c77834 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -945,7 +945,15 @@ def _pallas_call_batching_rule( ) for invar in eqn.invars ] - invar_raggedness, outvar_raggedness = rule(invar_raggedness, eqn.outvars) + try: + invar_raggedness, outvar_raggedness = rule( + eqn.params, invar_raggedness, eqn.outvars # type: ignore[arg-type] + ) + except Exception as e: + raise RuntimeError( + f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:" + f" {eqn.outvars}. Underlying reason: {e}" + ) from e for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment] if isinstance(invar, jax_core.Var): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 2abf81f26..5b8856aa6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1984,6 +1984,7 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None) +batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule def _pjit_batcher_for_sharding( s: sharding.Sharding | UnspecifiedValue, diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index a0f70a126..7724466d3 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -127,7 +127,7 @@ swap_p = core.Primitive("swap") swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) -def swap_ragged_prop_rule(invar_raggedness, outvars): +def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 2 invar_raggedness_lhs = invar_raggedness[0] invar_raggedness_rhs = invar_raggedness[1]