mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic] Add sin
and clamp
lowering rules and support multiple branches in cond
. Add a pallas_call test using scan/cond. Improve the error message for lowering exceptions and add a LoweringException
type.
PiperOrigin-RevId: 568945255
This commit is contained in:
parent
87af945cbe
commit
c62c6fc1ab
@ -35,6 +35,7 @@ py_library_providing_imports_info(
|
||||
deps = [
|
||||
":core",
|
||||
":kernel_regeneration_util",
|
||||
":lowering",
|
||||
":pallas_call_registration",
|
||||
":primitives",
|
||||
],
|
||||
|
@ -21,6 +21,7 @@ from jax._src.pallas.mosaic.core import SemaphoreType
|
||||
from jax._src.pallas.mosaic.core import TPUMemorySpace
|
||||
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic.lowering import LoweringException
|
||||
from jax._src.pallas.mosaic.primitives import async_copy
|
||||
from jax._src.pallas.mosaic.primitives import async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import device_id
|
||||
|
@ -232,6 +232,7 @@ def lower_jaxpr_to_transform_func(
|
||||
body.func_op.verify()
|
||||
return body.func_op
|
||||
|
||||
|
||||
def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
|
||||
def f_lowered(ctx: LoweringRuleContext, *args, **params):
|
||||
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
||||
@ -319,6 +320,10 @@ def lower_jaxpr_to_func(
|
||||
return body.func_op
|
||||
|
||||
|
||||
class LoweringException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def jaxpr_subcomp(
|
||||
ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
|
||||
) -> Sequence[ir.Value]:
|
||||
@ -362,7 +367,21 @@ def jaxpr_subcomp(
|
||||
[v.aval for v in eqn.outvars],
|
||||
block_shapes,
|
||||
)
|
||||
ans = lowering_rules[eqn.primitive](rule_context, *invals, **eqn.params)
|
||||
try:
|
||||
ans = lowering_rules[eqn.primitive](
|
||||
rule_context, *invals, **eqn.params
|
||||
)
|
||||
except LoweringException:
|
||||
raise # We only add the extra info to the innermost exception.
|
||||
except Exception as e:
|
||||
raise LoweringException(
|
||||
f"Exception while lowering eqn:\n {eqn}\nWith context:\n "
|
||||
f" {rule_context}\nWith inval"
|
||||
f" shapes={map(lambda t: getattr(t, 'shape', None), invals)}\nWith"
|
||||
" inval"
|
||||
f" types={map(lambda t: getattr(t, 'type', None), invals)}\nIn"
|
||||
f" jaxpr:\n{jaxpr}"
|
||||
) from e
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented primitive in Pallas TPU lowering: "
|
||||
@ -829,6 +848,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions):
|
||||
raise NotImplementedError
|
||||
if any(d is None for d in new_sizes):
|
||||
raise NotImplementedError
|
||||
if not ctx.avals_in[0].shape:
|
||||
return vector.BroadcastOp(aval_to_ir_type(ctx.avals_out[0]), x).result
|
||||
return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result
|
||||
|
||||
|
||||
@ -875,13 +896,13 @@ lowering_rules[lax.transpose_p] = _transpose_lowering_rule
|
||||
|
||||
|
||||
def _bcast(x, y, x_aval, y_aval, out_aval):
|
||||
if isinstance(x, (np.ndarray, np.uint32, int, float)):
|
||||
if isinstance(x, (np.ndarray, np.number, int, float)):
|
||||
if hasattr(y, "type") and y.type == ir.IndexType.get():
|
||||
mlir_type = y.type
|
||||
else:
|
||||
mlir_type = mlir.dtype_to_ir_type(x_aval.dtype)
|
||||
x = ir_constant(x, mlir_type)
|
||||
if isinstance(y, (np.ndarray, np.uint32, int, float)):
|
||||
if isinstance(y, (np.ndarray, np.number, int, float)):
|
||||
if hasattr(x, "type") and x.type == ir.IndexType.get():
|
||||
mlir_type = x.type
|
||||
else:
|
||||
@ -1045,7 +1066,8 @@ lowering_rules[lax.exp_p] = _exp_lowering_rule
|
||||
def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
if not isinstance(x, ir.Value) and x == 2.:
|
||||
return math.Exp2Op(y).result
|
||||
raise NotImplementedError("Only 2^x supported")
|
||||
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
||||
return math.PowFOp(x, y).result
|
||||
|
||||
|
||||
lowering_rules[lax.pow_p] = _pow_lowering_rule
|
||||
@ -1075,10 +1097,11 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
neg_x = arith.NegFOp(x).result
|
||||
exp_neg_x = math.ExpOp(neg_x).result
|
||||
aval_out = ctx.avals_out[0]
|
||||
out_type = ir.VectorType.get(
|
||||
aval_out.shape, mlir.dtype_to_ir_type(aval_out.dtype)
|
||||
)
|
||||
one = vector.BroadcastOp(out_type, ir_constant(1.0))
|
||||
out_type = aval_to_ir_type(aval_out)
|
||||
if aval_out.shape == ():
|
||||
one = ir_constant(1.0, mlir_type=out_type)
|
||||
else:
|
||||
one = vector.BroadcastOp(out_type, ir_constant(1.0))
|
||||
denom = arith.AddFOp(one, exp_neg_x).result
|
||||
return arith.DivFOp(one, denom).result
|
||||
|
||||
@ -1086,6 +1109,13 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
lowering_rules[lax.logistic_p] = _logistic_lowering_rule
|
||||
|
||||
|
||||
def _sin_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
return math.SinOp(x).result
|
||||
|
||||
|
||||
lowering_rules[lax.sin_p] = _sin_lowering_rule
|
||||
|
||||
|
||||
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
return math.TanhOp(x).result
|
||||
|
||||
@ -1179,6 +1209,20 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args):
|
||||
|
||||
lowering_rules[lax.select_n_p] = _select_n_lowering_rule
|
||||
|
||||
|
||||
def _clamp(min, operand, max):
|
||||
res = jnp.maximum(operand, min)
|
||||
return jnp.minimum(res, max)
|
||||
|
||||
|
||||
def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max):
|
||||
"""Compute minimum_p(maximum_p(min, operand), max)."""
|
||||
return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max)
|
||||
|
||||
|
||||
lowering_rules[lax.clamp_p] = _clamp_lowering_rule
|
||||
|
||||
|
||||
def _for_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
@ -1211,7 +1255,6 @@ def _for_lowering_rule(
|
||||
|
||||
|
||||
lowering_rules[for_loop.for_p] = _for_lowering_rule
|
||||
skip_mlir_conversions.add(for_loop.for_p)
|
||||
|
||||
|
||||
def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext,
|
||||
@ -1277,35 +1320,36 @@ skip_mlir_conversions.add(lax.scan_p)
|
||||
|
||||
|
||||
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):
|
||||
del linear
|
||||
if len(branches) > 2:
|
||||
raise NotImplementedError
|
||||
pred, *args = args
|
||||
index, *args = args
|
||||
out_types = map(aval_to_ir_type, ctx.avals_out)
|
||||
pred = arith.TruncIOp(
|
||||
aval_to_ir_type(jax_core.ShapedArray((), jnp.bool_)), pred
|
||||
pred = arith.CmpIOp(
|
||||
arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
|
||||
).result
|
||||
# Specialize to singleton `if`s
|
||||
singleton = len(out_types) == 1
|
||||
if singleton:
|
||||
out_types = out_types[0]
|
||||
if_op = scf.IfOp(pred, out_types, hasElse=True)
|
||||
lowering_context = ctx.lowering_context.replace(
|
||||
block_shapes=ctx.block_shapes[1:],
|
||||
)
|
||||
with ir.InsertionPoint(if_op.then_block):
|
||||
out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
|
||||
# TODO(b/300272065): Use `scf.IndexSwitchOp` instead of a cascade of
|
||||
# if/else.
|
||||
if len(branches) > 2:
|
||||
out = _cond_lowering_rule(
|
||||
ctx,
|
||||
arith.SubIOp(index, ir_constant(1, index.type)).result,
|
||||
*args,
|
||||
branches=branches[1:],
|
||||
linear=linear,
|
||||
)
|
||||
else:
|
||||
out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
|
||||
scf.YieldOp(out)
|
||||
with ir.InsertionPoint(if_op.else_block):
|
||||
out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args)
|
||||
scf.YieldOp(out)
|
||||
if singleton:
|
||||
return if_op.result
|
||||
return if_op.results
|
||||
|
||||
|
||||
lowering_rules[lax.cond_p] = _cond_lowering_rule
|
||||
skip_mlir_conversions.add(lax.cond_p)
|
||||
|
||||
|
||||
def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
|
||||
|
Loading…
x
Reference in New Issue
Block a user