mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering.
This CL builds out a simple sketch of constant prop by construction in mosaic - we walk the graph up from cond, collecting the values and either const propping or failing out of const prop. Failure out of const prop is not a bug, but hitting an unimplemented const prop func is for now, in order to drive better coverage. This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation. And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic specific one that avoids gather. PiperOrigin-RevId: 708752566
This commit is contained in:
parent
d28b0daccd
commit
1719986aaa
@ -38,6 +38,7 @@ from jax._src.lax.control_flow.conditionals import (
|
||||
cond_p as cond_p,
|
||||
switch as switch,
|
||||
platform_dependent as platform_dependent,
|
||||
platform_index_p as platform_index_p,
|
||||
)
|
||||
from jax._src.lax.control_flow.solves import (
|
||||
custom_linear_solve as custom_linear_solve,
|
||||
|
@ -945,6 +945,7 @@ def platform_dependent(*args: Any,
|
||||
platform_index = platform_index_p.bind(
|
||||
platforms=tuple(tuple(ps) for ps in platforms_lists),
|
||||
has_default=(default is not None))
|
||||
|
||||
if default is not None:
|
||||
branches = branches + (default,)
|
||||
# Use a switch, to get the proper transformation rules for free. Since
|
||||
@ -957,6 +958,8 @@ def platform_dependent(*args: Any,
|
||||
# recognized on the compilation platform. Detect eager mode and keep only the
|
||||
# needed branch.
|
||||
try:
|
||||
# Note/TODO(mvoz): This actually rarely seems to concretize - we could look into
|
||||
# core.ensure_compile_time_eval to get better single-branch selection.
|
||||
platform_index_concrete = core.concrete_or_error(operator.index, platform_index)
|
||||
except core.ConcretizationTypeError:
|
||||
return switch(platform_index, branches, *args)
|
||||
|
@ -8341,18 +8341,41 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
|
||||
Array([4, 8], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("diagonal", a)
|
||||
a_shape = shape(a)
|
||||
|
||||
if ndim(a) < 2:
|
||||
raise ValueError("diagonal requires an array of at least two dimensions.")
|
||||
offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")
|
||||
|
||||
a = moveaxis(a, (axis1, axis2), (-2, -1))
|
||||
def _default_diag(a):
|
||||
a_shape = shape(a)
|
||||
|
||||
diag_size = max(0, min(a_shape[axis1] + min(offset, 0),
|
||||
a_shape[axis2] - max(offset, 0)))
|
||||
i = arange(diag_size)
|
||||
j = arange(abs(offset), abs(offset) + diag_size)
|
||||
return a[..., i, j] if offset >= 0 else a[..., j, i]
|
||||
a = moveaxis(a, (axis1, axis2), (-2, -1))
|
||||
|
||||
diag_size = max(
|
||||
0, min(a_shape[axis1] + min(offset, 0), a_shape[axis2] - max(offset, 0))
|
||||
)
|
||||
i = arange(diag_size)
|
||||
j = arange(abs(offset), abs(offset) + diag_size)
|
||||
return a[..., i, j] if offset >= 0 else a[..., j, i]
|
||||
|
||||
|
||||
# The mosaic lowering rule for diag is only defined for square arrays.
|
||||
# TODO(mvoz): Add support for offsets.
|
||||
if shape(a)[0] != shape(a)[1] or ndim(a) != 2 or offset != 0 or _dtype(a) == bool_:
|
||||
return _default_diag(a)
|
||||
else:
|
||||
a_shape_eye = eye(shape(a)[0], dtype=_dtype(a))
|
||||
|
||||
def _mosaic_diag(a):
|
||||
def _sum(x, axis):
|
||||
return lax.reduce(
|
||||
x,
|
||||
np.array(0, _dtype(x)),
|
||||
lax.add if _dtype(x) != bool_ else lax.bitwise_or,
|
||||
(axis,),
|
||||
)
|
||||
return _sum(lax.mul(a_shape_eye, a), axis=0)
|
||||
return lax.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag)
|
||||
|
||||
|
||||
@export
|
||||
|
@ -547,9 +547,13 @@ def lower_jaxpr_to_module(
|
||||
module_name = name_and_src_info.name
|
||||
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
||||
sym_tab = ir.SymbolTable(m.operation)
|
||||
|
||||
func_op = lower_jaxpr_to_func(
|
||||
ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
|
||||
name="main", for_verification=for_verification,
|
||||
ctx,
|
||||
jaxpr,
|
||||
mosaic_grid_mapping=mosaic_grid_mapping,
|
||||
name="main",
|
||||
for_verification=for_verification,
|
||||
)
|
||||
m.body.append(func_op)
|
||||
sym_tab.insert(func_op)
|
||||
@ -568,6 +572,7 @@ def lower_jaxpr_to_module(
|
||||
# We checked above that the block does not require windowing.
|
||||
window_params.append(ir.DictAttr.get())
|
||||
continue
|
||||
|
||||
mlir_func = lower_jaxpr_to_transform_func(
|
||||
ctx,
|
||||
bm.index_map_jaxpr.jaxpr,
|
||||
@ -1990,6 +1995,36 @@ lowering_rules[ad_util.add_any_p] = _add_lowering_rule
|
||||
skip_mlir_conversions.add(ad_util.add_any_p)
|
||||
|
||||
|
||||
class FoldingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _fold_and_get_constant_value(x):
|
||||
def _fold(x, fuel):
|
||||
if fuel <= 0:
|
||||
raise FoldingError("Folding depth exceeded")
|
||||
op_name = getattr(x.owner, "name", None)
|
||||
binop_folds = {
|
||||
"arith.maxsi": max,
|
||||
"arith.minsi": min,
|
||||
}
|
||||
if op_name == "arith.constant":
|
||||
if ir.IntegerType.isinstance(x.type):
|
||||
return ir.IntegerAttr(x.owner.attributes["value"]).value
|
||||
elif ir.FloatType.isinstance(x.type):
|
||||
return ir.FloatAttr(x.owner.attributes["value"]).value
|
||||
else:
|
||||
raise ValueError(f"Unsupported constant type: {x.type}")
|
||||
if op_name in binop_folds:
|
||||
return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands)
|
||||
raise FoldingError(f"Folding not supported for {x.owner}")
|
||||
|
||||
try:
|
||||
return _fold(x, 10)
|
||||
except FoldingError:
|
||||
return None
|
||||
|
||||
|
||||
def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
||||
(aval_out,) = ctx.avals_out
|
||||
@ -2708,6 +2743,12 @@ lowering_rules[lax.while_p] = _while_lowering_rule
|
||||
|
||||
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
|
||||
index, *args = args
|
||||
constant_index = _fold_and_get_constant_value(index)
|
||||
|
||||
if constant_index is not None:
|
||||
return jaxpr_subcomp(
|
||||
ctx.lowering_context.replace(block_shapes=ctx.block_shapes[1:]), branches[constant_index].jaxpr, *args
|
||||
)
|
||||
out_types = map(aval_to_ir_type, ctx.avals_out)
|
||||
pred = arith.cmpi(
|
||||
arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
|
||||
@ -3375,3 +3416,25 @@ def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
|
||||
|
||||
|
||||
lowering_rules[lax.pad_p] = _pad_lowering_rule
|
||||
|
||||
|
||||
def _platform_index_lowering(
|
||||
ctx: mlir.LoweringRuleContext,
|
||||
*,
|
||||
platforms: Sequence[Sequence[str]],
|
||||
has_default: bool,
|
||||
):
|
||||
for i, ps in enumerate(platforms):
|
||||
# note - slightly odd structure here, as platforms is a seq[seq[str]]
|
||||
if "mosaic" in ps:
|
||||
return ir_constant(i)
|
||||
|
||||
if has_default:
|
||||
return ir_constant(len(platforms))
|
||||
|
||||
raise NotImplementedError(
|
||||
"No mosaic or default platform indexing rule found."
|
||||
)
|
||||
|
||||
|
||||
lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering
|
||||
|
@ -2127,6 +2127,21 @@ class OpsTest(PallasBaseTest):
|
||||
)
|
||||
self.assertTrue(acceptable_errors, "Failed with error: " + str(e))
|
||||
|
||||
@parameterized.parameters((128, 128), (256, 256))
|
||||
def test_jnp_diagonal_pallas(self, n, m):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
# TODO(mvoz): platform_index_p on GPU
|
||||
self.skipTest("Not implemented on GPU")
|
||||
x = jnp.arange(n * m, dtype=jnp.float32).reshape((n, m))
|
||||
|
||||
def kernel(x_ref, out_ref):
|
||||
out_ref[...] = jnp.diagonal(x_ref[...])
|
||||
|
||||
out = self.pallas_call(
|
||||
kernel, out_shape=jax.ShapeDtypeStruct((n,), jnp.float32)
|
||||
)(x)
|
||||
np.testing.assert_array_equal(out, np.diagonal(x))
|
||||
|
||||
|
||||
class OpsInterpretTest(OpsTest):
|
||||
INTERPRET = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user