[Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering.

This CL builds out a simple sketch of constant prop by construction in mosaic - we walk the graph up from cond, collecting the values and either const propping or failing out of const prop. Failure out of const prop is not a bug, but hitting an unimplemented const prop func is for now, in order to drive better coverage.

This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation.

And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic specific one that avoids gather.

PiperOrigin-RevId: 708752566
This commit is contained in:
jax authors 2024-12-22 00:50:12 -08:00
parent d28b0daccd
commit 1719986aaa
5 changed files with 114 additions and 9 deletions

View File

@ -38,6 +38,7 @@ from jax._src.lax.control_flow.conditionals import (
cond_p as cond_p,
switch as switch,
platform_dependent as platform_dependent,
platform_index_p as platform_index_p,
)
from jax._src.lax.control_flow.solves import (
custom_linear_solve as custom_linear_solve,

View File

@ -945,6 +945,7 @@ def platform_dependent(*args: Any,
platform_index = platform_index_p.bind(
platforms=tuple(tuple(ps) for ps in platforms_lists),
has_default=(default is not None))
if default is not None:
branches = branches + (default,)
# Use a switch, to get the proper transformation rules for free. Since
@ -957,6 +958,8 @@ def platform_dependent(*args: Any,
# recognized on the compilation platform. Detect eager mode and keep only the
# needed branch.
try:
# Note/TODO(mvoz): This actually rarely seems to concretize - we could look into
# core.ensure_compile_time_eval to get better single-branch selection.
platform_index_concrete = core.concrete_or_error(operator.index, platform_index)
except core.ConcretizationTypeError:
return switch(platform_index, branches, *args)

View File

@ -8341,18 +8341,41 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
Array([4, 8], dtype=int32)
"""
util.check_arraylike("diagonal", a)
a_shape = shape(a)
if ndim(a) < 2:
raise ValueError("diagonal requires an array of at least two dimensions.")
offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")
a = moveaxis(a, (axis1, axis2), (-2, -1))
def _default_diag(a):
a_shape = shape(a)
diag_size = max(0, min(a_shape[axis1] + min(offset, 0),
a_shape[axis2] - max(offset, 0)))
i = arange(diag_size)
j = arange(abs(offset), abs(offset) + diag_size)
return a[..., i, j] if offset >= 0 else a[..., j, i]
a = moveaxis(a, (axis1, axis2), (-2, -1))
diag_size = max(
0, min(a_shape[axis1] + min(offset, 0), a_shape[axis2] - max(offset, 0))
)
i = arange(diag_size)
j = arange(abs(offset), abs(offset) + diag_size)
return a[..., i, j] if offset >= 0 else a[..., j, i]
# The mosaic lowering rule for diag is only defined for square arrays.
# TODO(mvoz): Add support for offsets.
if shape(a)[0] != shape(a)[1] or ndim(a) != 2 or offset != 0 or _dtype(a) == bool_:
return _default_diag(a)
else:
a_shape_eye = eye(shape(a)[0], dtype=_dtype(a))
def _mosaic_diag(a):
def _sum(x, axis):
return lax.reduce(
x,
np.array(0, _dtype(x)),
lax.add if _dtype(x) != bool_ else lax.bitwise_or,
(axis,),
)
return _sum(lax.mul(a_shape_eye, a), axis=0)
return lax.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag)
@export

View File

@ -547,9 +547,13 @@ def lower_jaxpr_to_module(
module_name = name_and_src_info.name
attrs["sym_name"] = ir.StringAttr.get(module_name)
sym_tab = ir.SymbolTable(m.operation)
func_op = lower_jaxpr_to_func(
ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
name="main", for_verification=for_verification,
ctx,
jaxpr,
mosaic_grid_mapping=mosaic_grid_mapping,
name="main",
for_verification=for_verification,
)
m.body.append(func_op)
sym_tab.insert(func_op)
@ -568,6 +572,7 @@ def lower_jaxpr_to_module(
# We checked above that the block does not require windowing.
window_params.append(ir.DictAttr.get())
continue
mlir_func = lower_jaxpr_to_transform_func(
ctx,
bm.index_map_jaxpr.jaxpr,
@ -1990,6 +1995,36 @@ lowering_rules[ad_util.add_any_p] = _add_lowering_rule
skip_mlir_conversions.add(ad_util.add_any_p)
class FoldingError(Exception):
pass
def _fold_and_get_constant_value(x):
def _fold(x, fuel):
if fuel <= 0:
raise FoldingError("Folding depth exceeded")
op_name = getattr(x.owner, "name", None)
binop_folds = {
"arith.maxsi": max,
"arith.minsi": min,
}
if op_name == "arith.constant":
if ir.IntegerType.isinstance(x.type):
return ir.IntegerAttr(x.owner.attributes["value"]).value
elif ir.FloatType.isinstance(x.type):
return ir.FloatAttr(x.owner.attributes["value"]).value
else:
raise ValueError(f"Unsupported constant type: {x.type}")
if op_name in binop_folds:
return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands)
raise FoldingError(f"Folding not supported for {x.owner}")
try:
return _fold(x, 10)
except FoldingError:
return None
def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
(aval_out,) = ctx.avals_out
@ -2708,6 +2743,12 @@ lowering_rules[lax.while_p] = _while_lowering_rule
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
index, *args = args
constant_index = _fold_and_get_constant_value(index)
if constant_index is not None:
return jaxpr_subcomp(
ctx.lowering_context.replace(block_shapes=ctx.block_shapes[1:]), branches[constant_index].jaxpr, *args
)
out_types = map(aval_to_ir_type, ctx.avals_out)
pred = arith.cmpi(
arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
@ -3375,3 +3416,25 @@ def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
lowering_rules[lax.pad_p] = _pad_lowering_rule
def _platform_index_lowering(
ctx: mlir.LoweringRuleContext,
*,
platforms: Sequence[Sequence[str]],
has_default: bool,
):
for i, ps in enumerate(platforms):
# note - slightly odd structure here, as platforms is a seq[seq[str]]
if "mosaic" in ps:
return ir_constant(i)
if has_default:
return ir_constant(len(platforms))
raise NotImplementedError(
"No mosaic or default platform indexing rule found."
)
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering

View File

@ -2127,6 +2127,21 @@ class OpsTest(PallasBaseTest):
)
self.assertTrue(acceptable_errors, "Failed with error: " + str(e))
@parameterized.parameters((128, 128), (256, 256))
def test_jnp_diagonal_pallas(self, n, m):
if jtu.test_device_matches(["gpu"]):
# TODO(mvoz): platform_index_p on GPU
self.skipTest("Not implemented on GPU")
x = jnp.arange(n * m, dtype=jnp.float32).reshape((n, m))
def kernel(x_ref, out_ref):
out_ref[...] = jnp.diagonal(x_ref[...])
out = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct((n,), jnp.float32)
)(x)
np.testing.assert_array_equal(out, np.diagonal(x))
class OpsInterpretTest(OpsTest):
INTERPRET = True