Change jax to lower the asin and atan primitives to their corresponding chlo

ops.

PiperOrigin-RevId: 466766999
This commit is contained in:
Bixia Zheng 2022-08-10 13:04:46 -07:00 committed by jax authors
parent e81578a9fa
commit bb92038b6f
3 changed files with 16 additions and 8 deletions

View File

@ -1759,8 +1759,11 @@ def asin_impl(x):
asin_p = standard_unop(_float | _complex, 'asin')
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
mlir.register_lowering(asin_p, mlir.lower_fun(asin_impl,
multiple_results=False))
if jax._src.lib.mlir_api_version < 31:
mlir.register_lowering(asin_p, mlir.lower_fun(asin_impl,
multiple_results=False))
else:
mlir.register_lowering(asin_p, partial(_nary_lower_mhlo, chlo.AsinOp))
def acos_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
@ -1789,8 +1792,11 @@ def atan_impl(x):
atan_p = standard_unop(_float | _complex, 'atan')
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
mlir.register_lowering(atan_p, mlir.lower_fun(atan_impl,
multiple_results=False))
if jax._src.lib.mlir_api_version < 31:
mlir.register_lowering(atan_p, mlir.lower_fun(atan_impl,
multiple_results=False))
else:
mlir.register_lowering(atan_p, partial(_nary_lower_mhlo, chlo.AtanOp))
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(atan2_p,

View File

@ -211,8 +211,10 @@ class Jax2TfLimitation(primitive_harness.Limitation):
@classmethod
def asin(cls, harness: primitive_harness.Harness):
return [
custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4),
custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4,
modes=("eager", "graph", "compiled")),
custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12,
modes=("eager", "graph", "compiled")),
cls.helper_get_trig_custom_limitation(np.sin)
]

View File

@ -51,7 +51,7 @@ def main(_):
print_ir(np.float32(0))(lax.acosh)
# CHECK-LABEL: TEST: asin float32[]
# CHECK: mhlo.atan2
# CHECK: chlo.asin
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1))(lax.asin)
@ -61,7 +61,7 @@ def main(_):
print_ir(np.float32(0))(lax.asinh)
# CHECK-LABEL: TEST: atan float32[]
# CHECK: mhlo.atan2
# CHECK: chlo.atan
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1))(lax.atan)