Don't use an out-of-line lowering for integer_pow for small powers.

This yields a smaller stablehlo output.

Add a fast path for y == 1 and y == -1, which turn out to be reasonably common.
This commit is contained in:
Peter Hawkins 2024-11-14 08:17:10 -08:00
parent aefe6215ca
commit 081eaeaacc

View File

@ -2633,24 +2633,24 @@ def _integer_pow(x, *, y):
def _integer_pow_lowering(ctx, x, *, y):
# These cases are subsumed by the general case, but it's faster to emit these
# common cases directly.
if y == 2:
if y == 1:
out = x
elif y == 2:
out = hlo.multiply(x, x)
elif y == 3:
out = hlo.multiply(hlo.multiply(x, x), x)
elif y == -1:
out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x)
else:
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
# TODO(b/217551391): emitting an out-of-line call leads to a large
# expansion when the MLIR is lowered to HLO, because the HLO lowering
# clones the callee. Consider unconditionally caching when the MLIR->HLO
# lowering doesn't expand the program.
lowering = mlir.cache_lowering(lowering)
out = lowering(ctx, x, y=y)
if builtins.abs(y) >= 3:
lowering = mlir.cache_lowering(lowering)
out, = lowering(ctx, x, y=y)
if config.sharding_in_types.value:
aval_out, = ctx.avals_out
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
out = out[0] if isinstance(out, list) else out
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return out if isinstance(out, list) else [out]
return [out]
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)