Removed unused _tan_impl

Also removed the legacy lowering for `tan_p`.

PiperOrigin-RevId: 691195720
This commit is contained in:
Sergei Lebedev 2024-10-29 16:08:23 -07:00 committed by jax authors
parent 5ad066eeaa
commit 539c940946

View File

@ -60,7 +60,6 @@ from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype, dtype_to_string, standard_abstract_eval,
standard_multi_result_abstract_eval, standard_primitive)
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@ -2388,21 +2387,9 @@ cos_p = standard_unop(_float | _complex, 'cos')
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
mlir.register_lowering(cos_p, _cos_lowering)
@_upcast_fp16_for_computation
def _tan_impl(x):
return div(sin(x), cos(x))
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this
# lowering is mostly supported, but it fails on export or with the PJRT plugin
# because those modes target an older StableHLO version, and the
# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't
# included in the 0.4.33 release.
if jaxlib_version <= (0, 4, 33):
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
else:
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):