[pallas:gpu] Simplify broadcast_to, min, max lowering.

PiperOrigin-RevId: 574204406
This commit is contained in:
Chris Jones 2023-10-17 11:00:13 -07:00 committed by jax authors
parent 2c9ea51ef2
commit c16b893600

View File

@ -398,12 +398,15 @@ _TRITON_FN_MAPPING = {
lax.ge_p: tl.semantic.greater_equal,
lax.lt_p: tl.semantic.less_than,
lax.le_p: tl.semantic.less_equal,
lax.max_p: tl.math.max,
lax.min_p: tl.math.min,
lax.shift_left_p: tl.semantic.shl,
lax.shift_right_arithmetic_p: tl.semantic.ashr,
lax.shift_right_logical_p: tl.semantic.lshr,
lax.nextafter_p: tl.math.nextafter,
ad_util.add_any_p: tl.semantic.add,
# Other ops.
indexing.broadcast_to_p: tl.broadcast_to,
primitives.atomic_cas_p: tl.atomic_cas,
primitives.max_contiguous_p: tl.max_contiguous,
primitives.multiple_of_p: tl.multiple_of,
@ -424,7 +427,8 @@ for primitive, fn in _TRITON_FN_MAPPING.items():
def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
return _min_lowering_rule(ctx, max_lowering_rule(ctx, min, operand), max)
operand = tl.math.max(operand, min, _builder=ctx.builder)
return tl.math.min(operand, max, _builder=ctx.builder)
triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
@ -476,14 +480,6 @@ def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
def _min_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
pred = a.__lt__(b, _builder=ctx.builder)
return tl.semantic.where(pred, a, b, ctx.builder)
triton_lowering_rules[lax.min_p] = _min_lowering_rule
def _convert_element_type_lowering_rule(
ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
):
@ -497,14 +493,6 @@ triton_lowering_rules[lax.convert_element_type_p] = (
)
def max_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
pred = a.__gt__(b, _builder=ctx.builder)
return tl.semantic.where(pred, a, b, ctx.builder)
triton_lowering_rules[lax.max_p] = max_lowering_rule
def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b):
return tl.semantic.where(pred, b, a, ctx.builder)
@ -529,18 +517,6 @@ triton_lowering_rules[jax.lax.broadcast_in_dim_p] = (
)
def _broadcast_to_lowering_rule(
ctx: TritonLoweringRuleContext, a, *, shape
):
shape = map(tl.constexpr, shape)
return tl.broadcast_to(a, shape, _builder=ctx.builder)
triton_lowering_rules[indexing.broadcast_to_p] = (
_broadcast_to_lowering_rule
)
def _squeeze_lowering_rule(ctx: TritonLoweringRuleContext, a, *, dimensions):
del dimensions
return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None)