mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[pallas:gpu] Simplify broadcast_to
, min
, max
lowering.
PiperOrigin-RevId: 574204406
This commit is contained in:
parent
2c9ea51ef2
commit
c16b893600
@ -398,12 +398,15 @@ _TRITON_FN_MAPPING = {
|
||||
lax.ge_p: tl.semantic.greater_equal,
|
||||
lax.lt_p: tl.semantic.less_than,
|
||||
lax.le_p: tl.semantic.less_equal,
|
||||
lax.max_p: tl.math.max,
|
||||
lax.min_p: tl.math.min,
|
||||
lax.shift_left_p: tl.semantic.shl,
|
||||
lax.shift_right_arithmetic_p: tl.semantic.ashr,
|
||||
lax.shift_right_logical_p: tl.semantic.lshr,
|
||||
lax.nextafter_p: tl.math.nextafter,
|
||||
ad_util.add_any_p: tl.semantic.add,
|
||||
# Other ops.
|
||||
indexing.broadcast_to_p: tl.broadcast_to,
|
||||
primitives.atomic_cas_p: tl.atomic_cas,
|
||||
primitives.max_contiguous_p: tl.max_contiguous,
|
||||
primitives.multiple_of_p: tl.multiple_of,
|
||||
@ -424,7 +427,8 @@ for primitive, fn in _TRITON_FN_MAPPING.items():
|
||||
|
||||
|
||||
def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
|
||||
return _min_lowering_rule(ctx, max_lowering_rule(ctx, min, operand), max)
|
||||
operand = tl.math.max(operand, min, _builder=ctx.builder)
|
||||
return tl.math.min(operand, max, _builder=ctx.builder)
|
||||
|
||||
|
||||
triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
|
||||
@ -476,14 +480,6 @@ def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
|
||||
triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
|
||||
|
||||
|
||||
def _min_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
|
||||
pred = a.__lt__(b, _builder=ctx.builder)
|
||||
return tl.semantic.where(pred, a, b, ctx.builder)
|
||||
|
||||
|
||||
triton_lowering_rules[lax.min_p] = _min_lowering_rule
|
||||
|
||||
|
||||
def _convert_element_type_lowering_rule(
|
||||
ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
|
||||
):
|
||||
@ -497,14 +493,6 @@ triton_lowering_rules[lax.convert_element_type_p] = (
|
||||
)
|
||||
|
||||
|
||||
def max_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
|
||||
pred = a.__gt__(b, _builder=ctx.builder)
|
||||
return tl.semantic.where(pred, a, b, ctx.builder)
|
||||
|
||||
|
||||
triton_lowering_rules[lax.max_p] = max_lowering_rule
|
||||
|
||||
|
||||
def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b):
|
||||
return tl.semantic.where(pred, b, a, ctx.builder)
|
||||
|
||||
@ -529,18 +517,6 @@ triton_lowering_rules[jax.lax.broadcast_in_dim_p] = (
|
||||
)
|
||||
|
||||
|
||||
def _broadcast_to_lowering_rule(
|
||||
ctx: TritonLoweringRuleContext, a, *, shape
|
||||
):
|
||||
shape = map(tl.constexpr, shape)
|
||||
return tl.broadcast_to(a, shape, _builder=ctx.builder)
|
||||
|
||||
|
||||
triton_lowering_rules[indexing.broadcast_to_p] = (
|
||||
_broadcast_to_lowering_rule
|
||||
)
|
||||
|
||||
|
||||
def _squeeze_lowering_rule(ctx: TritonLoweringRuleContext, a, *, dimensions):
|
||||
del dimensions
|
||||
return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None)
|
||||
|
Loading…
x
Reference in New Issue
Block a user