diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 27cde6d31..eb174cc5c 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -89,6 +89,17 @@ def _jumble_flatten(jumble): aval = jumble.aval.replace(elt_ty=elt_ty) return (lengths, jumble.data), aval + +def _ragged_axis_parts(dim: RaggedAxis) -> tuple[int, int, int]: + stacked_axis = dim.stacked_axis + ragged_axes = dim.ragged_axes + if len(ragged_axes) != 1: + raise ValueError('Multiple ragged axes not yet implemented.') + ragged_axis_dim = ragged_axes[0][0] + ragged_axis_length = ragged_axes[0][1] + return stacked_axis, ragged_axis_dim, ragged_axis_length + + def _jumble_unflatten(aval, x): lengths, data = x new_shape = [d.replace(lengths=lengths[d.lengths - 1]) @@ -136,6 +147,7 @@ class RaggedAxis: new_axes = tuple((move_axis(ax), sizes) for ax, sizes in self.ragged_axes) return RaggedAxis(dst, new_axes) + def transpose_ragged_axes(dim: RaggedAxis, perm: tuple[int, ...]) -> RaggedAxis: new_ragged_axes = [] for idx, old_idx in enumerate(perm): @@ -315,6 +327,43 @@ def flatten_fun_for_vmap(in_tree, *args_flat): ans = yield py_args, py_kwargs yield tree_flatten(ans, is_leaf=is_vmappable) +# Propagate ragged masking rules from invars to outvars +# rule([raggedness_per_invar], outvars) -> +# [raggedness_per_invar, raggedness_per_outvar] +RaggedMaskingRule = Callable[ + [list[Any], list[Any]], tuple[list[Any], list[Any]] +] + +ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {} + + +def ragged_mask_elementwise_rule(invar_raggedness, outvars): + # TODO(mvoz): A util for getting the ragged representations + first_invar_raggedness = invar_raggedness[0] + for other_invar_raggedness in invar_raggedness[1:]: + if other_invar_raggedness != first_invar_raggedness: + raise ValueError(f'{other_invar_raggedness} != {first_invar_raggedness}') + + outvar_raggedness = [first_invar_raggedness] * len(outvars) + return invar_raggedness, outvar_raggedness + + +def ragged_mask_assert_no_op_rule(invar_raggedness, outvars): + if any(invar_raggedness): + raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}') + return invar_raggedness, [None] * len(outvars) + + +def ragged_mask_no_op_rule(invar_raggedness, outvars): + return invar_raggedness, [None] * len(outvars) + + +def ragged_mask_transfer_identity(invar_raggedness, outvar_raggedness): + assert len(invar_raggedness) == 1, invar_raggedness + outvar_raggedness = invar_raggedness + return invar_raggedness, outvar_raggedness + + ### tracer # TODO(mattjj): use a special sentinel type rather than None diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index e654ce953..c63414876 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -817,6 +817,7 @@ core.custom_typechecks[cond_p] = partial(_cond_typecheck, False) core.axis_substitution_rules[cond_p] = _cond_axis_substitution pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom pe.dce_rules[cond_p] = _cond_dce_rule +batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule def _cond_lowering(ctx, index, *args, branches): joined_effects = core.join_effects(*(branch.effects for branch in branches)) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 7726e3861..fcf766357 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2330,6 +2330,8 @@ def _sin_lowering(ctx, x): sin_p = standard_unop(_float | _complex, 'sin') 
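The new batching.ragged_prop_rules registry added above (registered for sin_p a few lines below, then for add_p, eq_p, convert_element_type_p, and others later in this file) follows one contract: rule(invar_raggedness, outvars) -> (invar_raggedness, outvar_raggedness). A minimal standalone sketch of that contract, using stand-in raggedness tuples rather than JAX internals (the tuple layout here is illustrative only, not a structure defined by the patch):

def elementwise_rule(invar_raggedness, outvars):
  # Mirror of ragged_mask_elementwise_rule above: all inputs must agree on a
  # single raggedness pattern, and every output inherits it unchanged.
  first = invar_raggedness[0]
  for other in invar_raggedness[1:]:
    if other != first:
      raise ValueError(f'{other} != {first}')
  return invar_raggedness, [first] * len(outvars)

# Stand-in raggedness entry: (stacked_axis, ragged_axis_dim, per-member lengths).
raggedness = (0, 1, (384, 128, 512))
ins, outs = elementwise_rule([raggedness, raggedness], outvars=['z'])
assert outs == [raggedness]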
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) mlir.register_lowering(sin_p, _sin_lowering) +batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule + def _cos_complex(x): # cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x))) @@ -2677,6 +2679,7 @@ add_p: Primitive = standard_naryop([_num, _num], 'add') ad.primitive_jvps[add_p] = _add_jvp ad.primitive_transposes[add_p] = _add_transpose mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add)) +batching.ragged_prop_rules[add_p] = batching.ragged_mask_elementwise_rule def _sub_jvp(primals, tangents): x, y = primals @@ -2834,6 +2837,7 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True) ad.defjvp_zero(eq_p) mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False)) +batching.ragged_prop_rules[eq_p] = batching.ragged_mask_elementwise_rule ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True) ad.defjvp_zero(ne_p) @@ -2970,6 +2974,9 @@ pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule pe.def_trivial_padding(convert_element_type_p) core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule +batching.ragged_prop_rules[convert_element_type_p] = ( + batching.ragged_mask_elementwise_rule +) def _real_dtype(dtype): return np.finfo(dtype).dtype @@ -3459,9 +3466,42 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: (list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch))) return core._pp_eqn(eqn.replace(params=printed_params), context, settings) + +def _dot_general_ragged_prop_rule(invar_raggedness, outvars): + assert len(invar_raggedness) == 2 + assert len(outvars) == 1 + invar_raggedness_lhs = invar_raggedness[0] + invar_raggedness_rhs = invar_raggedness[1] + + stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs + stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs + + if stacked_axis_rhs != 0 or stacked_axis_lhs != 0: + raise NotImplementedError( + 'Dot general ragged prop for non 0 stacked axis, NYI' + ) + + # We only support ragged k atm, that is, lhs is (m, ragged_k) and rhs is + # (ragged_k, n), meaning the output is dense. + if ragged_axis_dim_lhs != 2 or ragged_axis_dim_rhs != 1: + raise NotImplementedError( + 'Dot general ragged prop for non contraction raggedness, NYI' + ) + + assert len(outvars) == 1 + + # TODO(mvoz): A constant on batching.* ? 
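+  # For illustration (editor's note), the only pattern accepted above is a
+  # ragged contraction (k) dimension, e.g. with a stacked batch axis of size b:
+  #   lhs: (b, m, ragged_k)   stacked_axis=0, ragged_axis_dim=2
+  #   rhs: (b, ragged_k, n)   stacked_axis=0, ragged_axis_dim=1
+  #   out: (b, m, n)          dense, hence the None raggedness returned below.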
+ dense_jumble_raggedness = None + # Dense (m, n) - no jumble only atm + return invar_raggedness, [dense_jumble_raggedness] + + dot_general_p = standard_primitive( - _dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general', - sharding_rule=_dot_general_sharding_rule) + _dot_general_shape_rule, + _dot_general_dtype_rule, + 'dot_general', + sharding_rule=_dot_general_sharding_rule, +) def _dot_general_batch_unpack_args(batch_args): @@ -3494,6 +3534,7 @@ _dot_general_batch_rule = functools.partial( batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule pe.padding_rules[dot_general_p] = _dot_general_padding_rule core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule +batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule def precision_attr(precision: Precision) -> ir.ArrayAttr: if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): @@ -4055,6 +4096,13 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions): # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type) + +def _broadcast_in_dim_ragged_prop_rule(invar_raggedness, outvars): + assert len(invar_raggedness) == 1 + assert not isinstance(invar_raggedness[0], core.Var) + return invar_raggedness, [None] * len(outvars) + + broadcast_in_dim_p = standard_primitive( _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim') broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval) @@ -4067,6 +4115,9 @@ pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower) +batching.ragged_prop_rules[broadcast_in_dim_p] = ( + _broadcast_in_dim_ragged_prop_rule +) def _clamp_shape_rule(min, operand, max): @@ -4337,6 +4388,7 @@ squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) +batching.ragged_prop_rules[squeeze_p] = batching.ragged_mask_no_op_rule def _squeeze_lower(ctx, operand, *, dimensions): del dimensions # Implied by the output aval. @@ -4958,6 +5010,7 @@ reduce_and_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_and', weak_type_rule=_strip_weak_type) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) +batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule reduce_xor_p = standard_primitive( diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 6e7aab7a1..4e76d7c30 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1243,6 +1243,9 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice') ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule +# TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries +# or supporting nested jumbles. NYI. +batching.ragged_prop_rules[slice_p] = batching.ragged_mask_no_op_rule # Override the standard impl to defer to dynamic_slice whenever possible. 
# This lets us reuse the same program for many applications of slicing for as diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 44dad819b..800f5336b 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -279,6 +279,10 @@ def _pallas_call_impl_interpret( carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): + for s in scalars: + if isinstance(s.dtype, jax_core.bint): + aval = jax_core.get_aval(s) + s.aval = aval.update(dtype=jnp.int32) start_indices = [ None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) for bm in grid_mapping.block_mappings] @@ -293,10 +297,6 @@ def _pallas_call_impl_interpret( len(blocks), len(scratch_values), ) - for s in scalars: - aval = jax_core.get_aval(s) - if isinstance(aval, jax_core.DShapedArray): - s.aval = aval.update(dtype=jnp.int32) blocks = jax_core.eval_jaxpr( discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch @@ -437,10 +437,11 @@ ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule def _batch_block_mapping( grid_mapping: GridMapping, axis_size: int, + for_ragged: bool, aval: jax_core.ShapedArray, dim: int | batching.NotMapped, block_mapping: BlockMapping, - for_ragged: bool, + ragged_axis_values, ) -> BlockMapping: def _block_map_function(new_idx, *args): if for_ragged: @@ -466,14 +467,7 @@ def _batch_block_mapping( if for_ragged: if isinstance(dim, batching.RaggedAxis): assert for_ragged, "Ragged axis not supported for non-ragged batching." - _, _, ragged_axis_length = _ragged_axis_parts(dim) - aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) - lengths_aval = pallas_core.AbstractMemoryRef( - aval, - pallas_core.MemorySpace.INDEX, - ) + _, _, _, lengths_aval = ragged_axis_values idx_avals = [*idx_avals, lengths_aval] else: i32_aval_memref = pallas_core.AbstractMemoryRef( @@ -538,6 +532,12 @@ def _broadcast_input_output_aliases( for input_index, _ in input_output_aliases: dim = dims_[input_index] dims_[input_index] = 0 + if isinstance(dim, batching.RaggedAxis): + stacked_axis = dim.stacked_axis + if stacked_axis != 0: + raise NotImplementedError("Ragged aliasing on non 0 dim NYI") + return tuple(args_), tuple(dims_) + if dim is batching.not_mapped: args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0) elif dim != 0: @@ -643,16 +643,6 @@ def _batch_with_explicit_loop( return result, (0,) * len(result) -def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]: - stacked_axis = dim.stacked_axis - ragged_axes = dim.ragged_axes - if len(ragged_axes) != 1: - raise ValueError("Multiple ragged axes not yet implemented.") - ragged_axis_dim = ragged_axes[0][0] - ragged_axis_length = ragged_axes[0][1] - return stacked_axis, ragged_axis_dim, ragged_axis_length - - def _pallas_call_batching_rule( args, dims, @@ -674,21 +664,10 @@ def _pallas_call_batching_rule( return x return jnp.squeeze(x, axis=bdim) - all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)] - if len(all_ragged_axes) > 1: - raise ValueError("Multiple ragged dimensions not yet implemented.") - - if all_ragged_axes: - stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts( - all_ragged_axes[0] - ) - else: - stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None - def get_size(i, x, d): if not isinstance(d, 
batching.RaggedAxis): return x.shape[d] - return x.aval.shape[i] + return x.aval.shape[d.stacked_axis] (axis_size,) = { get_size(i=i, x=x, d=d) @@ -793,35 +772,49 @@ def _pallas_call_batching_rule( args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size ) + # Each dim either has data about its ragged axis, or None + ragged_axis_values = [] + for d in dims: + if isinstance(d, batching.RaggedAxis): + stacked_axis, ragged_axis_dim, ragged_axis_length = ( + batching._ragged_axis_parts(d) + ) + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + # TODO(mvoz): Give this its own type + ragged_axis_values.append( + (stacked_axis, ragged_axis_dim, ragged_axis_length, lengths_aval) + ) + else: + ragged_axis_values.append(None) # type: ignore[arg-type] + all_dims = list(dims) + [0] * grid_mapping.num_outputs + ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs # type: ignore[list-item] num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands - lengths_aval = None - if ragged_axis_length is not None: - aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) - lengths_aval = pallas_core.AbstractMemoryRef( - aval, - pallas_core.MemorySpace.INDEX, - ) - # Only add a batch dimension for the avals that actually have a grid mapping. # This excludes scalar prefetch inputs (the first in the list) and scratch # operands (the last in the list). avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] + batched_block_mappings = map( partial( _batch_block_mapping, grid_mapping, axis_size, - for_ragged=lengths_aval is not None, + any(ragged_axis_values), ), avals_to_batch, all_dims[num_index_operands:], block_mappings, + ragged_axis_values[num_index_operands:], ) index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten( @@ -829,6 +822,17 @@ def _pallas_call_batching_rule( assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args + lengths_aval = None # type: ignore[assignment] + + # Check all the ragged axis values, ensure their raggedness pattern + # is identical (consider moving this check up!) + for rav in ragged_axis_values: + if rav is not None: + if lengths_aval is None: + lengths_aval = rav[3] + else: + assert lengths_aval == rav[3], "NYI - different lengths in ragged batch" + if lengths_aval: batched_index_map_args = batched_index_map_args + (lengths_aval,) num_index_operands += 1 @@ -854,6 +858,14 @@ def _pallas_call_batching_rule( else: batched_cost_estimate = None + # Start the ragged handling code + # Here, we: + # - Rewrite the indexer to save memory (skip indices outside the ragged bounds) + # - Rewrite the kernel to save compute (skip elements outside the ragged bounds) + # - Update various internal structures/metadata to account for the new + # block spec. + # - Set the hacky flag of ragged_originating on the mapping, to signal to + # the lowering code to treat mapped dimensions as part of the user grid. 
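To make the kernel-masking step listed above concrete, here is a small standalone sketch (editor's illustration; lengths, block_size, and grid_cols are made-up example values, not names from this patch) of the predicate the wrapped kernel evaluates: a grid step along the ragged dimension only does work while its starting element index is below the current batch member's length.

import numpy as np

lengths = np.array([384, 128, 512])   # per-batch-member ragged sizes, in elements
block_size = 128                      # block width along the ragged dimension
grid_cols = 5                         # grid steps along the ragged dimension

for b, b_len in enumerate(lengths):
  live = [col * block_size < b_len for col in range(grid_cols)]
  print(b, live)
  # b=0 -> [True, True, True, False, False]: 3 of 5 blocks run, the rest skip.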
if lengths_aval: batched_grid_mapping = batched_grid_mapping.replace( get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, @@ -868,53 +880,102 @@ def _pallas_call_batching_rule( # we write 0s. This is useful for debugging, as certain lowering paths # like mosaic will write the last data as passthrough, leading to # potentially confusing results. - debug_zero_fill_counterfactual = debug - - first_block_mapping = batched_grid_mapping.block_mappings[0] + block_mapped_dim_idxs = [] for block_mapping in batched_grid_mapping.block_mappings: - # This invariant may already be checked elsewhere, but lets reaffirm it - assert block_mapping.block_shape == first_block_mapping.block_shape, ( - f"block_mapping.block_shape: {block_mapping.block_shape}, " - f"first_block_mapping.block_shape: {first_block_mapping.block_shape}" - ) - assert ( - block_mapping.array_shape_dtype - == first_block_mapping.array_shape_dtype - ), ( - f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype}," - " first_block_mapping.array_shape_dtype:" - f" {first_block_mapping.array_shape_dtype}" - ) + mapped_dim_idxs = [] + for i, d in enumerate(block_mapping.block_shape): + if d is pallas_core.mapped: + mapped_dim_idxs.append(i) + else: + mapped_dim_idxs.append(None) # type: ignore[arg-type] + block_mapped_dim_idxs.append(mapped_dim_idxs) - mapped_dim_idxs = [ - i - for i, d in enumerate(first_block_mapping.block_shape) - if d is pallas_core.mapped - ] - assert len(mapped_dim_idxs) == 1 - mapped_dim_idx = mapped_dim_idxs[0] - if stacked_axis != mapped_dim_idx: - raise ValueError( - f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}" - ) + mapped_dim_idx = None + for rav, mapped_dim_idxs in zip(ragged_axis_values, block_mapped_dim_idxs): + if rav is not None: + stacked_axis = rav[0] + if mapped_dim_idx is None: + mapped_dim_idx = mapped_dim_idxs[stacked_axis] + if mapped_dim_idxs[stacked_axis] is None: + raise ValueError( + f"Expected mapped dim to be {stacked_axis}, but got" + f" {mapped_dim_idxs[stacked_axis]}" + ) + else: + assert mapped_dim_idx == mapped_dim_idxs[stacked_axis], ( + f"Different mapped dims - expected {mapped_dim_idx}, but got" + f" {mapped_dim_idxs[stacked_axis]}" + ) - assert ragged_axis_dim is not None, "Invariant violation" # This is the blockspec size of the dimension - val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] + block_shapes = [b.block_shape for b in batched_grid_mapping.block_mappings] + + # Parse out the operations from the jaxpr to determine how to mask the output + # NOTE! while this *could* be a default dict of None, and None is sound, as + # it denotes that there is no raggedness for the given var, we explicitly + # do not do this, so as to get better signal on implementation of rules + # A misimplemented rule that does not account for new vars being introduced + # will result in an error on the next op using the new var. The benefit of + # of forcing implementers to account for all outputs and intermediaries is + # a very nice one. 
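+  # Editor's note, a concrete example of the walk below: get transfers a ref's
+  # raggedness to the value read, elementwise ops such as sin carry it through
+  # unchanged, dot_general over two operands ragged in the contracted k
+  # dimension produces a dense value recorded as None, and any equation whose
+  # primitive has no entry in batching.ragged_prop_rules raises
+  # NotImplementedError rather than silently dropping raggedness.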
+ + var_to_raggedness = {} + for invar, rav in zip(jaxpr.invars, ragged_axis_values): + var_to_raggedness[invar] = rav + + for eqn in jaxpr.eqns: + prim = eqn.primitive + if prim not in batching.ragged_prop_rules: + raise NotImplementedError(f"Not implemented - ragged prop for {prim}") + rule = batching.ragged_prop_rules[prim] + + invar_raggedness = [ + ( + var_to_raggedness.get(invar, None) + if isinstance(invar, jax_core.Var) + else None + ) + for invar in eqn.invars + ] + invar_raggedness, outvar_raggedness = rule(invar_raggedness, eqn.outvars) + + for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment] + if isinstance(invar, jax_core.Var): + var_to_raggedness[invar] = rav + for outvar, rav in zip(eqn.outvars, outvar_raggedness): + if isinstance(outvar, jax_core.Var): + var_to_raggedness[outvar] = rav + + for pos, invar in enumerate(jaxpr.invars): + ragged_axis_values[pos] = var_to_raggedness[invar] + + per_input_ragged_axis_dim = [] + for rav in ragged_axis_values: + if rav is not None: + per_input_ragged_axis_dim.append(rav[1]) + else: + per_input_ragged_axis_dim.append(None) def when_wrapped_kernel(lengths_ref, *args, **kwargs): - b_idx = primitives.program_id(stacked_axis) - i_idx = ( - primitives.program_id(ragged_axis_dim) - * val_at_ragged_dim - ) + b_idx = primitives.program_id(mapped_dim_idx) + b_len = lengths_ref[b_idx] + run_kernel = jnp.array(True) + for i, _ in enumerate(args): + ragged_axis_dim = per_input_ragged_axis_dim[i] + if ragged_axis_dim is None: + continue + arg_i_idx = ( + primitives.program_id(ragged_axis_dim) + * block_shapes[i][ragged_axis_dim] + ) + run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len) # TODO(mvoz): Unimplemented primitive in pallas # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") - @pallas_utils.when(i_idx < b_len) + @pallas_utils.when(run_kernel) def f(): # Important! This allows us to trace the inner kernel with the correct # grid to preserve user program_id semantics. Ex: program_id(0) will @@ -922,20 +983,108 @@ def _pallas_call_batching_rule( with pallas_core.tracing_grid_env(grid_mapping.grid, ()): jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) - if debug_zero_fill_counterfactual: - - @pallas_utils.when(i_idx >= b_len) - def g(): - for arg_ref in args: - arg_ref[...] = jnp.zeros_like(arg_ref) - kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( list(kernel_avals) ) + + def _rewrite_index_jaxpr(enumerate_batched_block_mapping): + arg_pos, batched_block_mapping = enumerate_batched_block_mapping + indexer_avals = [ + v.aval for v in batched_block_mapping.index_map_jaxpr.jaxpr.invars + ] + flat_indexer_avals, indexer_in_tree = tree_util.tree_flatten( + list(indexer_avals) + ) + + def index_rewrite_kernel(*indexer_args): + ragged_axis_dim = per_input_ragged_axis_dim[arg_pos] + + # the problem here seems to be that we are rnning this for all inputs, per input, because they each have an indexer - which means + # that the indexer for output isnt getting written - before, it always was + + lengths_ref = indexer_args[-1] + rest_indexer_args = indexer_args[:-1] + # Lengths are always the last argument of the indexer. + # lengths_ref = args[-1] + # Invariant: Stacked axis is enforced to be the mapped axis above. 
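+      # Worked example of the remapping below (editor's note): with
+      # lengths = [384, 128, 512], 128-wide blocks and 5 grid steps along the
+      # ragged dimension, batch member 0 fetches blocks 0..2; at step 3 the
+      # start index is past b_len, so the index map points at (b_idx + 1, 0),
+      # the next block that will actually do work. The last batch member has
+      # no successor, so it is clamped to its last in-bounds block instead.
+      # Either way, the kernel body for the skipped step is disabled by the
+      # run_kernel predicate in when_wrapped_kernel above.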
+ b_idx = indexer_args[mapped_dim_idx] + + nargs = list(rest_indexer_args) + + if ragged_axis_dim is not None: + val_at_ragged_dim = batched_block_mapping.block_shape[ragged_axis_dim] + + # The current index into the ragged dimension. + # Invariant: There is only one ragged dimension, enforced above. + i_idx = indexer_args[ragged_axis_dim] + + # grid space -> element space + i_len = i_idx * val_at_ragged_dim + + # The length of the current batch. + b_len = lengths_ref[b_idx] + + # Have we reached the end of the current batch? + not_done = i_len < b_len + + am_last_batch = b_idx == axis_size - 1 + last_good_block = lax.div(b_len, val_at_ragged_dim) - 1 + + # The logic below can be thought of as: + # if index_oob_ragged: + # if not last_batch: + # batch_idx += 1 + # ragged_idx = 0 + # else: + # ragged_idx = last_good_block + # + # wherein we find the next good block by incrementing the batch index + # and setting the ragged index to 0 if we are not in the last batch. + # Otherwise, we set the ragged index to the last good block. + b_next = jnp.where( + not_done, b_idx, jnp.where(am_last_batch, b_idx, b_idx + 1) + ) + i_next = jnp.where( + not_done, i_idx, jnp.where(am_last_batch, last_good_block, 0) + ) + nargs[ragged_axis_dim] = i_next + nargs[mapped_dim_idx] = b_next + + nargs = nargs + [lengths_ref] + return jax_core.eval_jaxpr( + batched_block_mapping.index_map_jaxpr.jaxpr, + batched_block_mapping.index_map_jaxpr.consts, + *nargs, + ) + + index_jaxpr, _ = _trace_kernel_to_jaxpr( + index_rewrite_kernel, + "index_rewrite_kernel", + batched_grid_mapping, + tuple(flat_indexer_avals), + indexer_in_tree, + tuple(() for _ in flat_indexer_avals), + interpret=interpret, + indexer=True, + ) + + batched_block_mapping = batched_block_mapping.replace( + index_map_jaxpr=pe.close_jaxpr(index_jaxpr) + ) + return batched_block_mapping + # Important! This allows us to trace the outer kernel with the correct grid # to enable accessing the batch program_id. with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): + batched_block_mappings = map( + _rewrite_index_jaxpr, enumerate(batched_block_mappings) + ) + + batched_grid_mapping = batched_grid_mapping.replace( + block_mappings=tuple(batched_block_mappings), + ) + kernel_src_info: pallas_core.SrcInfoStr = "" jaxpr, consts = _trace_kernel_to_jaxpr( @@ -950,7 +1099,15 @@ def _pallas_call_batching_rule( if consts: raise NotImplementedError("consts not supported in pallas_call") - assert ragged_axis_length is not None + # We need to rewrite the input_output_aliases here, the initial call + # to broadcast is done, and we have inseted a new input (lengths), so + # there's an off-by-one here now. + new_input_output_aliases = [] + for k, v in input_output_aliases: + new_input_output_aliases.append((k + 1, v)) + input_output_aliases = tuple(new_input_output_aliases) + + # assert ragged_axis_length is not None args = (ragged_axis_length, *args) assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals) batched_out_avals = tuple( @@ -1225,7 +1382,6 @@ def pallas_call_checkify_rule(error: checkify.Error, return new_error, results checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule - # All of those shenanigans are because we can't make TransformedRef a PyTree, # because they should appear as atomic JAX values to the users. 
@lu.transformation @@ -1247,6 +1403,7 @@ def _trace_kernel_to_jaxpr( kernel_in_tree: tree_util.PyTreeDef, kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...], interpret: bool, + indexer: bool = False, ) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]: if interpret: kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval, @@ -1268,7 +1425,7 @@ def _trace_kernel_to_jaxpr( "You should pass them as inputs") kernel_out_tree = out_tree_thunk() - if kernel_out_tree != tree_util.tree_structure(None): + if not indexer and kernel_out_tree != tree_util.tree_structure(None): raise ValueError( f"The kernel function in the pallas_call {name_and_src_info} " f"should return None. It returns a PyTree: {kernel_out_tree}") @@ -1673,6 +1830,18 @@ def pallas_call( f"[0, {len(flat_out_avals)})") in_aval = flat_in_avals[i_idx] out_aval = flat_out_avals[o_idx] + if isinstance(in_aval, jax_core.DShapedArray): + new_shape = [] + for d in in_aval.shape: + if isinstance(d, int): + new_shape.append(d) + else: + new_shape.append(d.dtype.bound) + + in_aval = jax_core.ShapedArray( + tuple(new_shape), in_aval.dtype, in_aval.weak_type + ) + if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype: raise ValueError( f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 3bf815cd3..9e446917b 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -51,6 +51,7 @@ map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip program_id_p = jax_core.Primitive("program_id") +batching.ragged_prop_rules[program_id_p] = batching.ragged_mask_no_op_rule def program_id(axis: int) -> jax.Array: """Returns the kernel execution position along the given axis of the grid. diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index b91c2a13c..112399d6a 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -63,6 +63,7 @@ traceback_util.register_exclusion(__file__) # a:f32[3] <- x[] get_p = core.Primitive("get") get_p.def_impl(partial(dispatch.apply_primitive, get_p)) +batching.ragged_prop_rules[get_p] = batching.ragged_mask_transfer_identity Indexer = Union[int, slice, Array, types.EllipsisType] @@ -122,6 +123,16 @@ swap_p = core.Primitive("swap") swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) +def swap_ragged_prop_rule(invar_raggedness, outvars): + assert len(invar_raggedness) == 2 + invar_raggedness_lhs = invar_raggedness[0] + invar_raggedness_rhs = invar_raggedness[1] + + return [invar_raggedness_rhs, invar_raggedness_lhs], [None] + + +batching.ragged_prop_rules[swap_p] = swap_ragged_prop_rule + def ref_swap( ref_or_view: AbstractRef | TransformedRef, idx: Indexer | tuple[Indexer, ...] 
| None, diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py index 5ed15fe96..f26352da0 100644 --- a/tests/pallas/pallas_jumble_test.py +++ b/tests/pallas/pallas_jumble_test.py @@ -111,8 +111,81 @@ class PallasCallRaggedVmapTest(PallasBaseTest): ragged_total = 0 for dim in ragged_shape: ragged_total += row_count * dim * 128 - # See note - on zero filling counterfactuals - self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total) + + def correct(v): + return np.count_nonzero(v == jnp.sin(1.0)) + + for b, batch in enumerate(res): + ragged_val = ragged_shape[b] + for r, row in enumerate(batch): + row_total = ragged_val * 128 + self.assertEqual(correct(row), row_total, msg=f"row {r}, : {row}") + + self.assertEqual(correct(res), ragged_total) + + def test_vmap_jumble_over_add_kernel(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + y = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + + def invoke_kernel(x, y): + return pl.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec((8, 128), lambda j, k: (j, k)), + pl.BlockSpec((8, 128), lambda j, k: (j, k)), + ], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct( + (8, col_grid_size * 128), dtype=jnp.float32 + ), + grid=(1, col_grid_size), + interpret=False, + )(x, y) + + res = jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y) + + res = res.data + total = len(ragged_shape) * row_count * col_grid_size * 128 + res_total = np.prod(res.shape) + self.assertEqual(res_total, total) + ragged_total = 0 + for dim in ragged_shape: + ragged_total += row_count * dim * 128 + + def correct(v): + return np.count_nonzero(v == 2.0) + + for r, row in enumerate(res): + ragged_val = ragged_shape[r] + row_total = ragged_val * 128 * row_count + self.assertEqual(correct(row), row_total) + for col in row: + col_total = ragged_val * 128 + self.assertEqual(correct(col), col_total) + + self.assertEqual(np.count_nonzero(res == 2.0), ragged_total) def test_vmap_jumble_over_sin_kernel_grid_remapping(self): if not jtu.test_device_matches(["tpu"]): @@ -150,6 +223,97 @@ class PallasCallRaggedVmapTest(PallasBaseTest): axis_size=3, )(x) + def test_vmap_jumble_over_matmul_kernel(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + m = 128 + k = 640 + n = 640 + + def matmul_kernel(x_ref, y_ref, x_sentinel, z_ref): + # weird little once-only reset + @pl.when(x_sentinel[...][0][0] == 1.0) + def _(): + z_ref[...] = jnp.zeros_like(z_ref) + x_sentinel[...] = jnp.zeros_like(x_sentinel) + + z_ref[...] += x_ref[...] @ y_ref[...] + + def matmul( + x: jax.Array, + y: jax.Array, + x_sentinel: jax.Array, + *, + bm: int = 128, + bk: int = 128, + bn: int = 640, + ): + # m, k = x.shape + # _, n = y.shape + # a (1, 5) grid + # TODO(mvoz): parameterize this grid? 
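+      # With n=640, bn=640, k=640 and bk=128 this is (640 // 640, 640 // 128)
+      # == (1, 5): one grid step over n and five over the ragged k dimension.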
+ grid = (n // bn, k // bk) + return pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + in_specs=[ + pl.BlockSpec( + (bm, bk), + lambda j, k: (0, k), + ), + pl.BlockSpec( + (bk, bn), + lambda j, k: (k, j), + ), + pl.BlockSpec( + (bm, bn), + lambda j, k: (0, j), + ), + ], + out_specs=pl.BlockSpec( + (bm, bn), + lambda j, k: (0, j), + ), + grid=grid, + input_output_aliases={2: 0}, + interpret=False, + )(x, y, x_sentinel) + + # TODO(mvoz): parameterize this shape? + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(k), + ) + x = jax.vmap(lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis)( + sizes + ) + x_sentinel = jax.vmap( + lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis + )(sizes) + y = jax.vmap(lambda k_: jnp.ones((k_, n)), out_axes=batching.jumble_axis)( + sizes + ) + + res = jax.vmap( + matmul, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y, x_sentinel) + + ref = jax.vmap( + jnp.dot, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y) + + ref = ref.data + res = res.data + np.testing.assert_allclose(ref, res) + def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Only tested on TPU")
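For primitive authors extending this machinery, the registration pattern used throughout the patch is uniform. A hedged sketch follows: it assumes this patch is applied, uses the internal jax._src module paths the patch itself uses, and my_unary_p / my_custom_p are hypothetical primitives introduced only for illustration.

from jax._src import core
from jax._src.interpreters import batching

# Hypothetical elementwise primitive; any real primitive object works the same.
my_unary_p = core.Primitive('my_unary')

# Elementwise ops can reuse the stock rule: raggedness flows input -> output.
batching.ragged_prop_rules[my_unary_p] = batching.ragged_mask_elementwise_rule

# Primitives that neither consume nor produce ragged values can instead use
# batching.ragged_mask_no_op_rule, or provide a custom rule with the signature
#   rule(invar_raggedness, outvars) -> (invar_raggedness, outvar_raggedness)
def my_custom_rule(invar_raggedness, outvars):
  # Leave the inputs' raggedness untouched and mark every output as dense.
  return invar_raggedness, [None] * len(outvars)

my_custom_p = core.Primitive('my_custom')  # also hypothetical
batching.ragged_prop_rules[my_custom_p] = my_custom_rule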