mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Don't use an out-of-line lowering for integer_pow for small powers.
This yields a smaller stablehlo output. Add a fast path for y == 1 and y == -1, which turn out to be reasonably common.
This commit is contained in:
parent
aefe6215ca
commit
081eaeaacc
@@ -2633,24 +2633,24 @@ def _integer_pow(x, *, y):
|
||||
def _integer_pow_lowering(ctx, x, *, y):
|
||||
# These cases are subsumed by the general case, but it's faster to emit these
|
||||
# common cases directly.
|
||||
- if y == 2:
|
||||
+ if y == 1:
|
||||
+ out = x
|
||||
+ elif y == 2:
|
||||
out = hlo.multiply(x, x)
|
||||
elif y == 3:
|
||||
out = hlo.multiply(hlo.multiply(x, x), x)
|
||||
+ elif y == -1:
|
||||
+ out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x)
|
||||
else:
|
||||
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
|
||||
- # TODO(b/217551391): emitting an out-of-line call leads to a large
|
||||
- # expansion when the MLIR is lowered to HLO, because the HLO lowering
|
||||
- # clones the callee. Consider unconditionally caching when the MLIR->HLO
|
||||
- # lowering doesn't expand the program.
|
||||
- lowering = mlir.cache_lowering(lowering)
|
||||
- out = lowering(ctx, x, y=y)
|
||||
+ if builtins.abs(y) >= 3:
|
||||
+ lowering = mlir.cache_lowering(lowering)
|
||||
+ out, = lowering(ctx, x, y=y)
|
||||
if config.sharding_in_types.value:
|
||||
aval_out, = ctx.avals_out
|
||||
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
|
||||
- out = out[0] if isinstance(out, list) else out
|
||||
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
|
||||
- return out if isinstance(out, list) else [out]
|
||||
+ return [out]
|
||||
|
||||
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user