[Mosaic] Add sin and clamp lowering rules and support multiple branches in cond. Add a pallas_call test using scan/cond. Improve the error message for lowering exceptions and add a LoweringException type.

PiperOrigin-RevId: 568945255
This commit is contained in:
Emily Fertig 2023-09-27 13:33:04 -07:00 committed by jax authors
parent 87af945cbe
commit c62c6fc1ab
3 changed files with 69 additions and 23 deletions

View File

@ -35,6 +35,7 @@ py_library_providing_imports_info(
deps = [
":core",
":kernel_regeneration_util",
":lowering",
":pallas_call_registration",
":primitives",
],

View File

@ -21,6 +21,7 @@ from jax._src.pallas.mosaic.core import SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import device_id

View File

@ -232,6 +232,7 @@ def lower_jaxpr_to_transform_func(
body.func_op.verify()
return body.func_op
def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
@ -319,6 +320,10 @@ def lower_jaxpr_to_func(
return body.func_op
class LoweringException(Exception):
  """Raised when lowering a Pallas jaxpr equation to Mosaic MLIR fails.

  Used as a wrapper so that the innermost failing equation's context is
  attached exactly once (outer frames re-raise it unchanged).
  """
def jaxpr_subcomp(
ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
) -> Sequence[ir.Value]:
@ -362,7 +367,21 @@ def jaxpr_subcomp(
[v.aval for v in eqn.outvars],
block_shapes,
)
ans = lowering_rules[eqn.primitive](rule_context, *invals, **eqn.params)
try:
ans = lowering_rules[eqn.primitive](
rule_context, *invals, **eqn.params
)
except LoweringException:
raise # We only add the extra info to the innermost exception.
except Exception as e:
raise LoweringException(
f"Exception while lowering eqn:\n {eqn}\nWith context:\n "
f" {rule_context}\nWith inval"
f" shapes={map(lambda t: getattr(t, 'shape', None), invals)}\nWith"
" inval"
f" types={map(lambda t: getattr(t, 'type', None), invals)}\nIn"
f" jaxpr:\n{jaxpr}"
) from e
else:
raise NotImplementedError(
"Unimplemented primitive in Pallas TPU lowering: "
@ -829,6 +848,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions):
raise NotImplementedError
if any(d is None for d in new_sizes):
raise NotImplementedError
if not ctx.avals_in[0].shape:
return vector.BroadcastOp(aval_to_ir_type(ctx.avals_out[0]), x).result
return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result
@ -875,13 +896,13 @@ lowering_rules[lax.transpose_p] = _transpose_lowering_rule
def _bcast(x, y, x_aval, y_aval, out_aval):
if isinstance(x, (np.ndarray, np.uint32, int, float)):
if isinstance(x, (np.ndarray, np.number, int, float)):
if hasattr(y, "type") and y.type == ir.IndexType.get():
mlir_type = y.type
else:
mlir_type = mlir.dtype_to_ir_type(x_aval.dtype)
x = ir_constant(x, mlir_type)
if isinstance(y, (np.ndarray, np.uint32, int, float)):
if isinstance(y, (np.ndarray, np.number, int, float)):
if hasattr(x, "type") and x.type == ir.IndexType.get():
mlir_type = x.type
else:
@ -1045,7 +1066,8 @@ lowering_rules[lax.exp_p] = _exp_lowering_rule
def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowering rule for lax.pow_p.

  Fast path: when the base is the Python/NumPy constant 2.0, emit the MLIR
  `math.exp2` op directly. Otherwise broadcast the operands to a common
  shape/type and emit a general `math.powf`.
  """
  # `x` is a compile-time constant (not an ir.Value) only when it was a
  # literal in the jaxpr; 2**y has a dedicated hardware-friendly op.
  if not isinstance(x, ir.Value) and x == 2.:
    return math.Exp2Op(y).result
  # General case: broadcast scalars/arrays to matching MLIR values first.
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  return math.PowFOp(x, y).result


lowering_rules[lax.pow_p] = _pow_lowering_rule
@ -1075,10 +1097,11 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
neg_x = arith.NegFOp(x).result
exp_neg_x = math.ExpOp(neg_x).result
aval_out = ctx.avals_out[0]
out_type = ir.VectorType.get(
aval_out.shape, mlir.dtype_to_ir_type(aval_out.dtype)
)
one = vector.BroadcastOp(out_type, ir_constant(1.0))
out_type = aval_to_ir_type(aval_out)
if aval_out.shape == ():
one = ir_constant(1.0, mlir_type=out_type)
else:
one = vector.BroadcastOp(out_type, ir_constant(1.0))
denom = arith.AddFOp(one, exp_neg_x).result
return arith.DivFOp(one, denom).result
@ -1086,6 +1109,13 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
lowering_rules[lax.logistic_p] = _logistic_lowering_rule
def _sin_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowering rule for lax.sin_p: emits the MLIR `math.sin` op on `x`.

  `ctx` is unused; the op's result type is inferred from `x`.
  """
  return math.SinOp(x).result


lowering_rules[lax.sin_p] = _sin_lowering_rule
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowering rule for lax.tanh_p: emits the MLIR `math.tanh` op on `x`."""
  return math.TanhOp(x).result
@ -1179,6 +1209,20 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args):
lowering_rules[lax.select_n_p] = _select_n_lowering_rule
def _clamp(min, operand, max):
res = jnp.maximum(operand, min)
return jnp.minimum(res, max)
def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max):
  """Compute minimum_p(maximum_p(min, operand), max)."""
  # Reuses the generic `lower_fun` machinery: `_clamp` is traced as a jaxpr
  # built from maximum/minimum primitives, which already have lowering rules.
  return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max)


lowering_rules[lax.clamp_p] = _clamp_lowering_rule
def _for_lowering_rule(
ctx: LoweringRuleContext,
*args,
@ -1211,7 +1255,6 @@ def _for_lowering_rule(
lowering_rules[for_loop.for_p] = _for_lowering_rule
skip_mlir_conversions.add(for_loop.for_p)
def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext,
@ -1277,35 +1320,36 @@ skip_mlir_conversions.add(lax.scan_p)
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):
  """Lowers lax.cond_p to a cascade of `scf.if` ops.

  The first positional argument is the integer branch index. Branch 0 is
  taken when the index equals 0; otherwise the rule recurses on the
  remaining branches with a decremented index, producing nested if/else.
  """
  # NOTE(review): `linear` must NOT be `del`eted at the top of this function:
  # it is forwarded in the recursive call below, and deleting it first would
  # raise UnboundLocalError whenever len(branches) > 2.
  index, *args = args
  out_types = map(aval_to_ir_type, ctx.avals_out)
  # pred is true iff we should *not* take branch 0 (i.e. index != 0).
  pred = arith.CmpIOp(
      arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
  ).result
  # Specialize to singleton `if`s
  singleton = len(out_types) == 1
  if singleton:
    out_types = out_types[0]
  if_op = scf.IfOp(pred, out_types, hasElse=True)
  # Drop the index's block shape: the branch jaxprs only see `args`.
  lowering_context = ctx.lowering_context.replace(
      block_shapes=ctx.block_shapes[1:],
  )
  with ir.InsertionPoint(if_op.then_block):
    # TODO(b/300272065): Use `scf.IndexSwitchOp` instead of a cascade of
    # if/else.
    if len(branches) > 2:
      out = _cond_lowering_rule(
          ctx,
          arith.SubIOp(index, ir_constant(1, index.type)).result,
          *args,
          branches=branches[1:],
          linear=linear,
      )
    else:
      out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
    scf.YieldOp(out)
  with ir.InsertionPoint(if_op.else_block):
    out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args)
    scf.YieldOp(out)
  if singleton:
    return if_op.result
  return if_op.results


lowering_rules[lax.cond_p] = _cond_lowering_rule
skip_mlir_conversions.add(lax.cond_p)
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):