mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[MHLO] Switch tan to use CHLO lowering
Currently, it's desugared to sin(x)/cos(x) with upcast because CHLO_TanOp legalization doesn't support complex numbers. tan implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1175-1177;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e PiperOrigin-RevId: 443649394
This commit is contained in:
parent
636345fd67
commit
5f16873aad
@ -1643,7 +1643,11 @@ def _tan_impl(x):
|
||||
|
||||
tan_p = standard_unop(_float | _complex, 'tan')
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
|
||||
mlir.register_lowering(tan_p, mlir.lower_fun(_tan_impl, multiple_results=False))
|
||||
if jax._src.lib.mlir_api_version >= 11:
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp))
|
||||
else:
|
||||
mlir.register_lowering(tan_p,
|
||||
mlir.lower_fun(_tan_impl, multiple_results=False))
|
||||
|
||||
def asin_impl(x):
|
||||
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
||||
|
@ -458,10 +458,8 @@ def main(_):
|
||||
print_ir(jnp.bfloat16(0))(lax.sqrt)
|
||||
|
||||
# CHECK-LABEL: TEST: tan float16[]
|
||||
# CHECK: mhlo.sine
|
||||
# CHECK-SAME: tensor<f32>
|
||||
# CHECK: mhlo.cosine
|
||||
# CHECK-SAME: tensor<f32>
|
||||
# CHECK: chlo.tan
|
||||
# CHECK-SAME: tensor<f16>
|
||||
print_ir(np.float16(0))(lax.tan)
|
||||
|
||||
# CHECK-LABEL: TEST: tanh float32[]
|
||||
|
Loading…
x
Reference in New Issue
Block a user