Add a memory-saving index rewrite step to vmap with ragged inputs over pallas_call.

The approach here is to add a new notion to jax: ragged_prop. A ragged_prop rule computes the dynamism/raggedness of an op's outputs, given the raggedness of its inputs. In the limit, if we decide that this is a useful property to have in jax as a first-class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.

PiperOrigin-RevId: 685827096
Parent: fff3b8747f
Commit: 1f0b5728a4
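For orientation before the diff: a minimal sketch of how a primitive opts into the ragged_prop mechanism this change introduces. The rule body mirrors the elementwise rule added to batching.py below; the `my_elementwise_p` primitive and its registration are hypothetical, and the internal import paths are assumptions based on the files this diff touches (the real registrations below use sin_p, add_p, eq_p, dot_general_p, etc.).

# Hedged illustrative sketch only -- not part of the commit.
from jax._src import core
from jax._src.interpreters import batching  # ragged_prop_rules lives here after this change

my_elementwise_p = core.Primitive("my_elementwise")  # hypothetical primitive

def my_elementwise_ragged_prop_rule(invar_raggedness, outvars):
  # Elementwise ops preserve raggedness: all inputs must agree on their
  # raggedness, and every output inherits it unchanged.
  first = invar_raggedness[0]
  for other in invar_raggedness[1:]:
    if other != first:
      raise ValueError(f"{other} != {first}")
  return invar_raggedness, [first] * len(outvars)

batching.ragged_prop_rules[my_elementwise_p] = my_elementwise_ragged_prop_rule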
@@ -89,6 +89,17 @@ def _jumble_flatten(jumble):
  aval = jumble.aval.replace(elt_ty=elt_ty)
  return (lengths, jumble.data), aval


def _ragged_axis_parts(dim: RaggedAxis) -> tuple[int, int, int]:
  stacked_axis = dim.stacked_axis
  ragged_axes = dim.ragged_axes
  if len(ragged_axes) != 1:
    raise ValueError('Multiple ragged axes not yet implemented.')
  ragged_axis_dim = ragged_axes[0][0]
  ragged_axis_length = ragged_axes[0][1]
  return stacked_axis, ragged_axis_dim, ragged_axis_length


def _jumble_unflatten(aval, x):
  lengths, data = x
  new_shape = [d.replace(lengths=lengths[d.lengths - 1])
@@ -136,6 +147,7 @@ class RaggedAxis:
    new_axes = tuple((move_axis(ax), sizes) for ax, sizes in self.ragged_axes)
    return RaggedAxis(dst, new_axes)


def transpose_ragged_axes(dim: RaggedAxis, perm: tuple[int, ...]) -> RaggedAxis:
  new_ragged_axes = []
  for idx, old_idx in enumerate(perm):
@@ -315,6 +327,43 @@ def flatten_fun_for_vmap(in_tree, *args_flat):
  ans = yield py_args, py_kwargs
  yield tree_flatten(ans, is_leaf=is_vmappable)

# Propagate ragged masking rules from invars to outvars
# rule([raggedness_per_invar], outvars) ->
# [raggedness_per_invar, raggedness_per_outvar]
RaggedMaskingRule = Callable[
    [list[Any], list[Any]], tuple[list[Any], list[Any]]
]

ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {}


def ragged_mask_elementwise_rule(invar_raggedness, outvars):
  # TODO(mvoz): A util for getting the ragged representations
  first_invar_raggedness = invar_raggedness[0]
  for other_invar_raggedness in invar_raggedness[1:]:
    if other_invar_raggedness != first_invar_raggedness:
      raise ValueError(f'{other_invar_raggedness} != {first_invar_raggedness}')

  outvar_raggedness = [first_invar_raggedness] * len(outvars)
  return invar_raggedness, outvar_raggedness


def ragged_mask_assert_no_op_rule(invar_raggedness, outvars):
  if any(invar_raggedness):
    raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}')
  return invar_raggedness, [None] * len(outvars)


def ragged_mask_no_op_rule(invar_raggedness, outvars):
  return invar_raggedness, [None] * len(outvars)


def ragged_mask_transfer_identity(invar_raggedness, outvar_raggedness):
  assert len(invar_raggedness) == 1, invar_raggedness
  outvar_raggedness = invar_raggedness
  return invar_raggedness, outvar_raggedness


### tracer

# TODO(mattjj): use a special sentinel type rather than None
@@ -817,6 +817,7 @@ core.custom_typechecks[cond_p] = partial(_cond_typecheck, False)
core.axis_substitution_rules[cond_p] = _cond_axis_substitution
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule
batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule

def _cond_lowering(ctx, index, *args, branches):
  joined_effects = core.join_effects(*(branch.effects for branch in branches))
@@ -2330,6 +2330,8 @@ def _sin_lowering(ctx, x):
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
mlir.register_lowering(sin_p, _sin_lowering)
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule


def _cos_complex(x):
  # cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
@@ -2677,6 +2679,7 @@ add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add))
batching.ragged_prop_rules[add_p] = batching.ragged_mask_elementwise_rule

def _sub_jvp(primals, tangents):
  x, y = primals
@@ -2834,6 +2837,7 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True)
ad.defjvp_zero(eq_p)
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False))
batching.ragged_prop_rules[eq_p] = batching.ragged_mask_elementwise_rule

ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True)
ad.defjvp_zero(ne_p)
@@ -2970,6 +2974,9 @@ pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
pe.def_trivial_padding(convert_element_type_p)
core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule
batching.ragged_prop_rules[convert_element_type_p] = (
    batching.ragged_mask_elementwise_rule
)

def _real_dtype(dtype): return np.finfo(dtype).dtype

@@ -3459,9 +3466,42 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
      (list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch)))
  return core._pp_eqn(eqn.replace(params=printed_params), context, settings)


def _dot_general_ragged_prop_rule(invar_raggedness, outvars):
  assert len(invar_raggedness) == 2
  assert len(outvars) == 1
  invar_raggedness_lhs = invar_raggedness[0]
  invar_raggedness_rhs = invar_raggedness[1]

  stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
  stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs

  if stacked_axis_rhs != 0 or stacked_axis_lhs != 0:
    raise NotImplementedError(
        'Dot general ragged prop for non 0 stacked axis, NYI'
    )

  # We only support ragged k atm, that is, lhs is (m, ragged_k) and rhs is
  # (ragged_k, n), meaning the output is dense.
  if ragged_axis_dim_lhs != 2 or ragged_axis_dim_rhs != 1:
    raise NotImplementedError(
        'Dot general ragged prop for non contraction raggedness, NYI'
    )

  assert len(outvars) == 1

  # TODO(mvoz): A constant on batching.* ?
  dense_jumble_raggedness = None
  # Dense (m, n) - no jumble only atm
  return invar_raggedness, [dense_jumble_raggedness]


dot_general_p = standard_primitive(
    _dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general',
    sharding_rule=_dot_general_sharding_rule)
    _dot_general_shape_rule,
    _dot_general_dtype_rule,
    'dot_general',
    sharding_rule=_dot_general_sharding_rule,
)


def _dot_general_batch_unpack_args(batch_args):
@@ -3494,6 +3534,7 @@ _dot_general_batch_rule = functools.partial(
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule

def precision_attr(precision: Precision) -> ir.ArrayAttr:
  if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)):
@@ -4055,6 +4096,13 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
    # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
    return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type)


def _broadcast_in_dim_ragged_prop_rule(invar_raggedness, outvars):
  assert len(invar_raggedness) == 1
  assert not isinstance(invar_raggedness[0], core.Var)
  return invar_raggedness, [None] * len(outvars)


broadcast_in_dim_p = standard_primitive(
    _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
@@ -4067,6 +4115,9 @@ pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule
mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)
batching.ragged_prop_rules[broadcast_in_dim_p] = (
    _broadcast_in_dim_ragged_prop_rule
)


def _clamp_shape_rule(min, operand, max):
@@ -4337,6 +4388,7 @@ squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
pe.def_trivial_padding(squeeze_p)
batching.ragged_prop_rules[squeeze_p] = batching.ragged_mask_no_op_rule

def _squeeze_lower(ctx, operand, *, dimensions):
  del dimensions  # Implied by the output aval.
@@ -4958,6 +5010,7 @@ reduce_and_p = standard_primitive(
    _reduce_logical_shape_rule, _input_dtype, 'reduce_and',
    weak_type_rule=_strip_weak_type)
batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule


reduce_xor_p = standard_primitive(
@@ -1243,6 +1243,9 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice')
ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
# TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries
# or supporting nested jumbles. NYI.
batching.ragged_prop_rules[slice_p] = batching.ragged_mask_no_op_rule

# Override the standard impl to defer to dynamic_slice whenever possible.
# This lets us reuse the same program for many applications of slicing for as
@@ -279,6 +279,10 @@ def _pallas_call_impl_interpret(

    carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks])
    with pallas_core.grid_env(local_grid_env):
      for s in scalars:
        if isinstance(s.dtype, jax_core.bint):
          aval = jax_core.get_aval(s)
          s.aval = aval.update(dtype=jnp.int32)
      start_indices = [
          None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings]
@@ -293,10 +297,6 @@ def _pallas_call_impl_interpret(
        len(blocks),
        len(scratch_values),
    )
    for s in scalars:
      aval = jax_core.get_aval(s)
      if isinstance(aval, jax_core.DShapedArray):
        s.aval = aval.update(dtype=jnp.int32)

    blocks = jax_core.eval_jaxpr(
        discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch
@@ -437,10 +437,11 @@ ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
def _batch_block_mapping(
    grid_mapping: GridMapping,
    axis_size: int,
    for_ragged: bool,
    aval: jax_core.ShapedArray,
    dim: int | batching.NotMapped,
    block_mapping: BlockMapping,
    for_ragged: bool,
    ragged_axis_values,
) -> BlockMapping:
  def _block_map_function(new_idx, *args):
    if for_ragged:
@@ -466,14 +467,7 @@ def _batch_block_mapping(
  if for_ragged:
    if isinstance(dim, batching.RaggedAxis):
      assert for_ragged, "Ragged axis not supported for non-ragged batching."
      _, _, ragged_axis_length = _ragged_axis_parts(dim)
      aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
      if isinstance(aval, jax_core.DShapedArray):
        aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
      lengths_aval = pallas_core.AbstractMemoryRef(
          aval,
          pallas_core.MemorySpace.INDEX,
      )
      _, _, _, lengths_aval = ragged_axis_values
      idx_avals = [*idx_avals, lengths_aval]
  else:
    i32_aval_memref = pallas_core.AbstractMemoryRef(
@@ -538,6 +532,12 @@ def _broadcast_input_output_aliases(
  for input_index, _ in input_output_aliases:
    dim = dims_[input_index]
    dims_[input_index] = 0
    if isinstance(dim, batching.RaggedAxis):
      stacked_axis = dim.stacked_axis
      if stacked_axis != 0:
        raise NotImplementedError("Ragged aliasing on non 0 dim NYI")
      return tuple(args_), tuple(dims_)

    if dim is batching.not_mapped:
      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
    elif dim != 0:
@@ -643,16 +643,6 @@ def _batch_with_explicit_loop(
    return result, (0,) * len(result)


def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]:
  stacked_axis = dim.stacked_axis
  ragged_axes = dim.ragged_axes
  if len(ragged_axes) != 1:
    raise ValueError("Multiple ragged axes not yet implemented.")
  ragged_axis_dim = ragged_axes[0][0]
  ragged_axis_length = ragged_axes[0][1]
  return stacked_axis, ragged_axis_dim, ragged_axis_length


def _pallas_call_batching_rule(
    args,
    dims,
@@ -674,21 +664,10 @@ def _pallas_call_batching_rule(
      return x
    return jnp.squeeze(x, axis=bdim)

  all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)]
  if len(all_ragged_axes) > 1:
    raise ValueError("Multiple ragged dimensions not yet implemented.")

  if all_ragged_axes:
    stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts(
        all_ragged_axes[0]
    )
  else:
    stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None

  def get_size(i, x, d):
    if not isinstance(d, batching.RaggedAxis):
      return x.shape[d]
    return x.aval.shape[i]
    return x.aval.shape[d.stacked_axis]

  (axis_size,) = {
      get_size(i=i, x=x, d=d)
@@ -793,35 +772,49 @@ def _pallas_call_batching_rule(
      args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
  )

  # Each dim either has data about its ragged axis, or None
  ragged_axis_values = []
  for d in dims:
    if isinstance(d, batching.RaggedAxis):
      stacked_axis, ragged_axis_dim, ragged_axis_length = (
          batching._ragged_axis_parts(d)
      )
      aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
      if isinstance(aval, jax_core.DShapedArray):
        aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
      lengths_aval = pallas_core.AbstractMemoryRef(
          aval,
          pallas_core.MemorySpace.INDEX,
      )
      # TODO(mvoz): Give this its own type
      ragged_axis_values.append(
          (stacked_axis, ragged_axis_dim, ragged_axis_length, lengths_aval)
      )
    else:
      ragged_axis_values.append(None)  # type: ignore[arg-type]

  all_dims = list(dims) + [0] * grid_mapping.num_outputs
  ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs  # type: ignore[list-item]

  num_index_operands = grid_mapping.num_index_operands
  num_scratch_operands = grid_mapping.num_scratch_operands

  lengths_aval = None
  if ragged_axis_length is not None:
    aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
    if isinstance(aval, jax_core.DShapedArray):
      aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
    lengths_aval = pallas_core.AbstractMemoryRef(
        aval,
        pallas_core.MemorySpace.INDEX,
    )

  # Only add a batch dimension for the avals that actually have a grid mapping.
  # This excludes scalar prefetch inputs (the first in the list) and scratch
  # operands (the last in the list).
  avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]

  batched_block_mappings = map(
      partial(
          _batch_block_mapping,
          grid_mapping,
          axis_size,
          for_ragged=lengths_aval is not None,
          any(ragged_axis_values),
      ),
      avals_to_batch,
      all_dims[num_index_operands:],
      block_mappings,
      ragged_axis_values[num_index_operands:],
  )

  index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(
@@ -829,6 +822,17 @@ def _pallas_call_batching_rule(
  assert not index_map_tree_kwargs
  batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args

  lengths_aval = None  # type: ignore[assignment]

  # Check all the ragged axis values, ensure their raggedness pattern
  # is identical (consider moving this check up!)
  for rav in ragged_axis_values:
    if rav is not None:
      if lengths_aval is None:
        lengths_aval = rav[3]
      else:
        assert lengths_aval == rav[3], "NYI - different lengths in ragged batch"

  if lengths_aval:
    batched_index_map_args = batched_index_map_args + (lengths_aval,)
    num_index_operands += 1
@@ -868,53 +880,102 @@ def _pallas_call_batching_rule(
    # we write 0s. This is useful for debugging, as certain lowering paths
    # like mosaic will write the last data as passthrough, leading to
    # potentially confusing results.
    debug_zero_fill_counterfactual = debug

    first_block_mapping = batched_grid_mapping.block_mappings[0]
    block_mapped_dim_idxs = []
    for block_mapping in batched_grid_mapping.block_mappings:
      # This invariant may already be checked elsewhere, but lets reaffirm it
      assert block_mapping.block_shape == first_block_mapping.block_shape, (
          f"block_mapping.block_shape: {block_mapping.block_shape}, "
          f"first_block_mapping.block_shape: {first_block_mapping.block_shape}"
      )
      assert (
          block_mapping.array_shape_dtype
          == first_block_mapping.array_shape_dtype
      ), (
          f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype},"
          " first_block_mapping.array_shape_dtype:"
          f" {first_block_mapping.array_shape_dtype}"
      )
      mapped_dim_idxs = []
      for i, d in enumerate(block_mapping.block_shape):
        if d is pallas_core.mapped:
          mapped_dim_idxs.append(i)
        else:
          mapped_dim_idxs.append(None)  # type: ignore[arg-type]
      block_mapped_dim_idxs.append(mapped_dim_idxs)

    mapped_dim_idxs = [
        i
        for i, d in enumerate(first_block_mapping.block_shape)
        if d is pallas_core.mapped
    ]
    assert len(mapped_dim_idxs) == 1
    mapped_dim_idx = mapped_dim_idxs[0]
    if stacked_axis != mapped_dim_idx:
      raise ValueError(
          f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}"
      )
    mapped_dim_idx = None
    for rav, mapped_dim_idxs in zip(ragged_axis_values, block_mapped_dim_idxs):
      if rav is not None:
        stacked_axis = rav[0]
        if mapped_dim_idx is None:
          mapped_dim_idx = mapped_dim_idxs[stacked_axis]
          if mapped_dim_idxs[stacked_axis] is None:
            raise ValueError(
                f"Expected mapped dim to be {stacked_axis}, but got"
                f" {mapped_dim_idxs[stacked_axis]}"
            )
        else:
          assert mapped_dim_idx == mapped_dim_idxs[stacked_axis], (
              f"Different mapped dims - expected {mapped_dim_idx}, but got"
              f" {mapped_dim_idxs[stacked_axis]}"
          )

    assert ragged_axis_dim is not None, "Invariant violation"
    # This is the blockspec size of the dimension
    val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim]
    block_shapes = [b.block_shape for b in batched_grid_mapping.block_mappings]

    # Parse out the operations from the jaxpr to determine how to mask the output
    # NOTE! while this *could* be a default dict of None, and None is sound, as
    # it denotes that there is no raggedness for the given var, we explicitly
    # do not do this, so as to get better signal on implementation of rules
    # A misimplemented rule that does not account for new vars being introduced
    # will result in an error on the next op using the new var. The benefit of
    # of forcing implementers to account for all outputs and intermediaries is
    # a very nice one.

    var_to_raggedness = {}
    for invar, rav in zip(jaxpr.invars, ragged_axis_values):
      var_to_raggedness[invar] = rav

    for eqn in jaxpr.eqns:
      prim = eqn.primitive
      if prim not in batching.ragged_prop_rules:
        raise NotImplementedError(f"Not implemented - ragged prop for {prim}")
      rule = batching.ragged_prop_rules[prim]

      invar_raggedness = [
          (
              var_to_raggedness.get(invar, None)
              if isinstance(invar, jax_core.Var)
              else None
          )
          for invar in eqn.invars
      ]
      invar_raggedness, outvar_raggedness = rule(invar_raggedness, eqn.outvars)

      for invar, rav in zip(eqn.invars, invar_raggedness):  # type: ignore[assignment]
        if isinstance(invar, jax_core.Var):
          var_to_raggedness[invar] = rav
      for outvar, rav in zip(eqn.outvars, outvar_raggedness):
        if isinstance(outvar, jax_core.Var):
          var_to_raggedness[outvar] = rav

    for pos, invar in enumerate(jaxpr.invars):
      ragged_axis_values[pos] = var_to_raggedness[invar]

    per_input_ragged_axis_dim = []
    for rav in ragged_axis_values:
      if rav is not None:
        per_input_ragged_axis_dim.append(rav[1])
      else:
        per_input_ragged_axis_dim.append(None)

    def when_wrapped_kernel(lengths_ref, *args, **kwargs):
      b_idx = primitives.program_id(stacked_axis)
      i_idx = (
          primitives.program_id(ragged_axis_dim)
          * val_at_ragged_dim
      )
      b_idx = primitives.program_id(mapped_dim_idx)

      b_len = lengths_ref[b_idx]
      run_kernel = jnp.array(True)
      for i, _ in enumerate(args):
        ragged_axis_dim = per_input_ragged_axis_dim[i]
        if ragged_axis_dim is None:
          continue
        arg_i_idx = (
            primitives.program_id(ragged_axis_dim)
            * block_shapes[i][ragged_axis_dim]
        )
        run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len)

      # TODO(mvoz): Unimplemented primitive in pallas
      # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
      # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")

      @pallas_utils.when(i_idx < b_len)
      @pallas_utils.when(run_kernel)
      def f():
        # Important! This allows us to trace the inner kernel with the correct
        # grid to preserve user program_id semantics. Ex: program_id(0) will
@@ -922,20 +983,108 @@ def _pallas_call_batching_rule(
        with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
          jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs)

      if debug_zero_fill_counterfactual:

        @pallas_utils.when(i_idx >= b_len)
        def g():
          for arg_ref in args:
            arg_ref[...] = jnp.zeros_like(arg_ref)

    kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars]
    flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(
        list(kernel_avals)
    )

    def _rewrite_index_jaxpr(enumerate_batched_block_mapping):
      arg_pos, batched_block_mapping = enumerate_batched_block_mapping
      indexer_avals = [
          v.aval for v in batched_block_mapping.index_map_jaxpr.jaxpr.invars
      ]
      flat_indexer_avals, indexer_in_tree = tree_util.tree_flatten(
          list(indexer_avals)
      )

      def index_rewrite_kernel(*indexer_args):
        ragged_axis_dim = per_input_ragged_axis_dim[arg_pos]

        # the problem here seems to be that we are rnning this for all inputs, per input, because they each have an indexer - which means
        # that the indexer for output isnt getting written - before, it always was

        lengths_ref = indexer_args[-1]
        rest_indexer_args = indexer_args[:-1]
        # Lengths are always the last argument of the indexer.
        # lengths_ref = args[-1]
        # Invariant: Stacked axis is enforced to be the mapped axis above.
        b_idx = indexer_args[mapped_dim_idx]

        nargs = list(rest_indexer_args)

        if ragged_axis_dim is not None:
          val_at_ragged_dim = batched_block_mapping.block_shape[ragged_axis_dim]

          # The current index into the ragged dimension.
          # Invariant: There is only one ragged dimension, enforced above.
          i_idx = indexer_args[ragged_axis_dim]

          # grid space -> element space
          i_len = i_idx * val_at_ragged_dim

          # The length of the current batch.
          b_len = lengths_ref[b_idx]

          # Have we reached the end of the current batch?
          not_done = i_len < b_len

          am_last_batch = b_idx == axis_size - 1
          last_good_block = lax.div(b_len, val_at_ragged_dim) - 1

          # The logic below can be thought of as:
          # if index_oob_ragged:
          #   if not last_batch:
          #     batch_idx += 1
          #     ragged_idx = 0
          #   else:
          #     ragged_idx = last_good_block
          #
          # wherein we find the next good block by incrementing the batch index
          # and setting the ragged index to 0 if we are not in the last batch.
          # Otherwise, we set the ragged index to the last good block.
          b_next = jnp.where(
              not_done, b_idx, jnp.where(am_last_batch, b_idx, b_idx + 1)
          )
          i_next = jnp.where(
              not_done, i_idx, jnp.where(am_last_batch, last_good_block, 0)
          )
          nargs[ragged_axis_dim] = i_next
          nargs[mapped_dim_idx] = b_next

        nargs = nargs + [lengths_ref]
        return jax_core.eval_jaxpr(
            batched_block_mapping.index_map_jaxpr.jaxpr,
            batched_block_mapping.index_map_jaxpr.consts,
            *nargs,
        )

      index_jaxpr, _ = _trace_kernel_to_jaxpr(
          index_rewrite_kernel,
          "index_rewrite_kernel",
          batched_grid_mapping,
          tuple(flat_indexer_avals),
          indexer_in_tree,
          tuple(() for _ in flat_indexer_avals),
          interpret=interpret,
          indexer=True,
      )

      batched_block_mapping = batched_block_mapping.replace(
          index_map_jaxpr=pe.close_jaxpr(index_jaxpr)
      )
      return batched_block_mapping

    # Important! This allows us to trace the outer kernel with the correct grid
    # to enable accessing the batch program_id.
    with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
      batched_block_mappings = map(
          _rewrite_index_jaxpr, enumerate(batched_block_mappings)
      )

    batched_grid_mapping = batched_grid_mapping.replace(
        block_mappings=tuple(batched_block_mappings),
    )

    kernel_src_info: pallas_core.SrcInfoStr = "<Wrapped outer kernel>"

    jaxpr, consts = _trace_kernel_to_jaxpr(
@@ -950,7 +1099,15 @@ def _pallas_call_batching_rule(
    if consts:
      raise NotImplementedError("consts not supported in pallas_call")

    assert ragged_axis_length is not None
    # We need to rewrite the input_output_aliases here, the initial call
    # to broadcast is done, and we have inseted a new input (lengths), so
    # there's an off-by-one here now.
    new_input_output_aliases = []
    for k, v in input_output_aliases:
      new_input_output_aliases.append((k + 1, v))
    input_output_aliases = tuple(new_input_output_aliases)

    # assert ragged_axis_length is not None
    args = (ragged_axis_length, *args)
  assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals)
  batched_out_avals = tuple(
@@ -1225,7 +1382,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
  return new_error, results
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule


# All of those shenanigans are because we can't make TransformedRef a PyTree,
# because they should appear as atomic JAX values to the users.
@lu.transformation
@@ -1247,6 +1403,7 @@ def _trace_kernel_to_jaxpr(
    kernel_in_tree: tree_util.PyTreeDef,
    kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
    interpret: bool,
    indexer: bool = False,
) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
  if interpret:
    kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
@@ -1268,7 +1425,7 @@ def _trace_kernel_to_jaxpr(
                     "You should pass them as inputs")

  kernel_out_tree = out_tree_thunk()
  if kernel_out_tree != tree_util.tree_structure(None):
  if not indexer and kernel_out_tree != tree_util.tree_structure(None):
    raise ValueError(
        f"The kernel function in the pallas_call {name_and_src_info} "
        f"should return None. It returns a PyTree: {kernel_out_tree}")
@@ -1673,6 +1830,18 @@ def pallas_call(
                     f"[0, {len(flat_out_avals)})")
    in_aval = flat_in_avals[i_idx]
    out_aval = flat_out_avals[o_idx]
    if isinstance(in_aval, jax_core.DShapedArray):
      new_shape = []
      for d in in_aval.shape:
        if isinstance(d, int):
          new_shape.append(d)
        else:
          new_shape.append(d.dtype.bound)

      in_aval = jax_core.ShapedArray(
          tuple(new_shape), in_aval.dtype, in_aval.weak_type
      )

    if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype:
      raise ValueError(
          f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
@@ -51,6 +51,7 @@ map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

program_id_p = jax_core.Primitive("program_id")
batching.ragged_prop_rules[program_id_p] = batching.ragged_mask_no_op_rule

def program_id(axis: int) -> jax.Array:
  """Returns the kernel execution position along the given axis of the grid.
@@ -63,6 +63,7 @@ traceback_util.register_exclusion(__file__)
# a:f32[3] <- x[]
get_p = core.Primitive("get")
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
batching.ragged_prop_rules[get_p] = batching.ragged_mask_transfer_identity

Indexer = Union[int, slice, Array, types.EllipsisType]

@@ -122,6 +123,16 @@ swap_p = core.Primitive("swap")
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))


def swap_ragged_prop_rule(invar_raggedness, outvars):
  assert len(invar_raggedness) == 2
  invar_raggedness_lhs = invar_raggedness[0]
  invar_raggedness_rhs = invar_raggedness[1]

  return [invar_raggedness_rhs, invar_raggedness_lhs], [None]


batching.ragged_prop_rules[swap_p] = swap_ragged_prop_rule

def ref_swap(
    ref_or_view: AbstractRef | TransformedRef,
    idx: Indexer | tuple[Indexer, ...] | None,
@@ -111,8 +111,81 @@ class PallasCallRaggedVmapTest(PallasBaseTest):
    ragged_total = 0
    for dim in ragged_shape:
      ragged_total += row_count * dim * 128
    # See note - on zero filling counterfactuals
    self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total)

    def correct(v):
      return np.count_nonzero(v == jnp.sin(1.0))

    for b, batch in enumerate(res):
      ragged_val = ragged_shape[b]
      for r, row in enumerate(batch):
        row_total = ragged_val * 128
        self.assertEqual(correct(row), row_total, msg=f"row {r}, : {row}")

    self.assertEqual(correct(res), ragged_total)

  def test_vmap_jumble_over_add_kernel(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Only tested on TPU")

    row_count = 8
    col_grid_size = 5
    ragged_shape = [3, 1, 4]
    sizes = lax.convert_element_type(
        jnp.array([128 * x for x in ragged_shape]),
        core.bint(col_grid_size * 128),
    )
    x = jax.vmap(
        lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
    )(sizes)
    y = jax.vmap(
        lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis
    )(sizes)

    def kernel(x_ref, y_ref, o_ref):
      o_ref[...] = x_ref[...] + y_ref[...]

    def invoke_kernel(x, y):
      return pl.pallas_call(
          kernel,
          in_specs=[
              pl.BlockSpec((8, 128), lambda j, k: (j, k)),
              pl.BlockSpec((8, 128), lambda j, k: (j, k)),
          ],
          out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
          out_shape=jax.ShapeDtypeStruct(
              (8, col_grid_size * 128), dtype=jnp.float32
          ),
          grid=(1, col_grid_size),
          interpret=False,
      )(x, y)

    res = jax.vmap(
        invoke_kernel,
        out_axes=batching.jumble_axis,
        in_axes=batching.jumble_axis,
        axis_size=3,
    )(x, y)

    res = res.data
    total = len(ragged_shape) * row_count * col_grid_size * 128
    res_total = np.prod(res.shape)
    self.assertEqual(res_total, total)
    ragged_total = 0
    for dim in ragged_shape:
      ragged_total += row_count * dim * 128

    def correct(v):
      return np.count_nonzero(v == 2.0)

    for r, row in enumerate(res):
      ragged_val = ragged_shape[r]
      row_total = ragged_val * 128 * row_count
      self.assertEqual(correct(row), row_total)
      for col in row:
        col_total = ragged_val * 128
        self.assertEqual(correct(col), col_total)

    self.assertEqual(np.count_nonzero(res == 2.0), ragged_total)

  def test_vmap_jumble_over_sin_kernel_grid_remapping(self):
    if not jtu.test_device_matches(["tpu"]):
@@ -150,6 +223,97 @@ class PallasCallRaggedVmapTest(PallasBaseTest):
        axis_size=3,
    )(x)

  def test_vmap_jumble_over_matmul_kernel(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Only tested on TPU")

    m = 128
    k = 640
    n = 640

    def matmul_kernel(x_ref, y_ref, x_sentinel, z_ref):
      # weird little once-only reset
      @pl.when(x_sentinel[...][0][0] == 1.0)
      def _():
        z_ref[...] = jnp.zeros_like(z_ref)
        x_sentinel[...] = jnp.zeros_like(x_sentinel)

      z_ref[...] += x_ref[...] @ y_ref[...]

    def matmul(
        x: jax.Array,
        y: jax.Array,
        x_sentinel: jax.Array,
        *,
        bm: int = 128,
        bk: int = 128,
        bn: int = 640,
    ):
      # m, k = x.shape
      # _, n = y.shape
      # a (1, 5) grid
      # TODO(mvoz): parameterize this grid?
      grid = (n // bn, k // bk)
      return pl.pallas_call(
          matmul_kernel,
          out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
          in_specs=[
              pl.BlockSpec(
                  (bm, bk),
                  lambda j, k: (0, k),
              ),
              pl.BlockSpec(
                  (bk, bn),
                  lambda j, k: (k, j),
              ),
              pl.BlockSpec(
                  (bm, bn),
                  lambda j, k: (0, j),
              ),
          ],
          out_specs=pl.BlockSpec(
              (bm, bn),
              lambda j, k: (0, j),
          ),
          grid=grid,
          input_output_aliases={2: 0},
          interpret=False,
      )(x, y, x_sentinel)

    # TODO(mvoz): parameterize this shape?
    ragged_shape = [3, 1, 4]
    sizes = lax.convert_element_type(
        jnp.array([128 * x for x in ragged_shape]),
        core.bint(k),
    )
    x = jax.vmap(lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis)(
        sizes
    )
    x_sentinel = jax.vmap(
        lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis
    )(sizes)
    y = jax.vmap(lambda k_: jnp.ones((k_, n)), out_axes=batching.jumble_axis)(
        sizes
    )

    res = jax.vmap(
        matmul,
        out_axes=batching.jumble_axis,
        in_axes=batching.jumble_axis,
        axis_size=3,
    )(x, y, x_sentinel)

    ref = jax.vmap(
        jnp.dot,
        out_axes=batching.jumble_axis,
        in_axes=batching.jumble_axis,
        axis_size=3,
    )(x, y)

    ref = ref.data
    res = res.data
    np.testing.assert_allclose(ref, res)

  def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Only tested on TPU")