mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Removed unused _tan_impl
Also removed the legacy lowering for `tan_p`. PiperOrigin-RevId: 691195720
This commit is contained in:
parent
5ad066eeaa
commit
539c940946
@ -60,7 +60,6 @@ from jax._src.lax import slicing
|
||||
from jax._src.lax.utils import (
|
||||
_input_dtype, dtype_to_string, standard_abstract_eval,
|
||||
standard_multi_result_abstract_eval, standard_primitive)
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -2388,21 +2387,9 @@ cos_p = standard_unop(_float | _complex, 'cos')
|
||||
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
||||
mlir.register_lowering(cos_p, _cos_lowering)
|
||||
|
||||
@_upcast_fp16_for_computation
|
||||
def _tan_impl(x):
|
||||
return div(sin(x), cos(x))
|
||||
|
||||
tan_p = standard_unop(_float | _complex, 'tan')
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
|
||||
# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this
|
||||
# lowering is mostly supported, but it fails on export or with the PJRT plugin
|
||||
# because those modes target an older StableHLO version, and the
|
||||
# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't
|
||||
# included in the 0.4.33 release.
|
||||
if jaxlib_version <= (0, 4, 33):
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
|
||||
else:
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
|
||||
|
||||
def asin_impl(x):
|
||||
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
||||
|
Loading…
x
Reference in New Issue
Block a user