mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24933 from hawkinsp:pow
PiperOrigin-RevId: 697622037
This commit is contained in:
commit
afdc79271c
@ -2574,15 +2574,12 @@ ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
||||
|
||||
def _pow_lower(ctx, x, y):
|
||||
x_aval, y_aval = ctx.avals_in
|
||||
out_aval, = ctx.avals_out
|
||||
convert = mlir.lower_fun(
|
||||
partial(convert_element_type, new_dtype=out_aval.dtype), False)
|
||||
x_aval_ = x_aval.update(dtype=out_aval.dtype)
|
||||
y_aval_ = y_aval.update(dtype=out_aval.dtype)
|
||||
[x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
|
||||
[y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
|
||||
ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
|
||||
return _nary_lower_hlo(hlo.power, ctx_, x_, y_)
|
||||
if x_aval.dtype != y_aval.dtype:
|
||||
out_aval, = ctx.avals_out
|
||||
y_aval = y_aval.update(dtype=out_aval.dtype)
|
||||
y = hlo.convert(mlir.aval_to_ir_type(y_aval), y)
|
||||
ctx = ctx.replace(avals_in=[x_aval, y_aval])
|
||||
return _nary_lower_hlo(hlo.power, ctx, x, y)
|
||||
mlir.register_lowering(pow_p, _pow_lower)
|
||||
|
||||
def _integer_pow_dtype_rule(x, *, y):
|
||||
|
Loading…
x
Reference in New Issue
Block a user