mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Change jax to lower the asin and atan primitives to their corresponding chlo
ops. PiperOrigin-RevId: 466766999
This commit is contained in:
parent
e81578a9fa
commit
bb92038b6f
@ -1759,8 +1759,11 @@ def asin_impl(x):
|
||||
|
||||
asin_p = standard_unop(_float | _complex, 'asin')
|
||||
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
|
||||
mlir.register_lowering(asin_p, mlir.lower_fun(asin_impl,
|
||||
multiple_results=False))
|
||||
if jax._src.lib.mlir_api_version < 31:
|
||||
mlir.register_lowering(asin_p, mlir.lower_fun(asin_impl,
|
||||
multiple_results=False))
|
||||
else:
|
||||
mlir.register_lowering(asin_p, partial(_nary_lower_mhlo, chlo.AsinOp))
|
||||
|
||||
def acos_impl(x):
|
||||
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
||||
@ -1789,8 +1792,11 @@ def atan_impl(x):
|
||||
|
||||
atan_p = standard_unop(_float | _complex, 'atan')
|
||||
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
|
||||
mlir.register_lowering(atan_p, mlir.lower_fun(atan_impl,
|
||||
multiple_results=False))
|
||||
if jax._src.lib.mlir_api_version < 31:
|
||||
mlir.register_lowering(atan_p, mlir.lower_fun(atan_impl,
|
||||
multiple_results=False))
|
||||
else:
|
||||
mlir.register_lowering(atan_p, partial(_nary_lower_mhlo, chlo.AtanOp))
|
||||
|
||||
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
|
||||
ad.defjvp(atan2_p,
|
||||
|
@ -211,8 +211,10 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
@classmethod
|
||||
def asin(cls, harness: primitive_harness.Harness):
|
||||
return [
|
||||
custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4),
|
||||
custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12),
|
||||
custom_numeric(dtypes=np.complex64, devices=("cpu", "gpu"), tol=1e-4,
|
||||
modes=("eager", "graph", "compiled")),
|
||||
custom_numeric(dtypes=np.complex128, devices=("cpu", "gpu"), tol=1e-12,
|
||||
modes=("eager", "graph", "compiled")),
|
||||
cls.helper_get_trig_custom_limitation(np.sin)
|
||||
]
|
||||
|
||||
|
@ -51,7 +51,7 @@ def main(_):
|
||||
print_ir(np.float32(0))(lax.acosh)
|
||||
|
||||
# CHECK-LABEL: TEST: asin float32[]
|
||||
# CHECK: mhlo.atan2
|
||||
# CHECK: chlo.asin
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1))(lax.asin)
|
||||
|
||||
@ -61,7 +61,7 @@ def main(_):
|
||||
print_ir(np.float32(0))(lax.asinh)
|
||||
|
||||
# CHECK-LABEL: TEST: atan float32[]
|
||||
# CHECK: mhlo.atan2
|
||||
# CHECK: chlo.atan
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.float32(1))(lax.atan)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user