Merge pull request #24933 from hawkinsp:pow

PiperOrigin-RevId: 697622037
This commit is contained in:
jax authors 2024-11-18 07:38:56 -08:00
commit afdc79271c

View File

@ -2574,15 +2574,12 @@ ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
def _pow_lower(ctx, x, y):
x_aval, y_aval = ctx.avals_in
out_aval, = ctx.avals_out
convert = mlir.lower_fun(
partial(convert_element_type, new_dtype=out_aval.dtype), False)
x_aval_ = x_aval.update(dtype=out_aval.dtype)
y_aval_ = y_aval.update(dtype=out_aval.dtype)
[x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
[y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
return _nary_lower_hlo(hlo.power, ctx_, x_, y_)
if x_aval.dtype != y_aval.dtype:
out_aval, = ctx.avals_out
y_aval = y_aval.update(dtype=out_aval.dtype)
y = hlo.convert(mlir.aval_to_ir_type(y_aval), y)
ctx = ctx.replace(avals_in=[x_aval, y_aval])
return _nary_lower_hlo(hlo.power, ctx, x, y)
mlir.register_lowering(pow_p, _pow_lower)
def _integer_pow_dtype_rule(x, *, y):