diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index bb80f7630..9878ddc3c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3723,10 +3723,11 @@ def _sin_complex(x): # 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2 a, b = real(x), imag(x) a_is_zero = eq(a, _const(a, 0)) + two = _const(a, 2) sn, cs = sin(a), cos(a) - e1m, e2m = expm1(b), expm1(-b) - snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2 - re, im = sn * csh, cs * snh + e1m, e2m = expm1(b), expm1(neg(b)) + snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two) + re, im = mul(sn, csh), mul(cs, snh) # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) @@ -3752,10 +3753,11 @@ def _cos_complex(x): # see also _sin_complex a, b = real(x), imag(x) a_is_zero = eq(a, _const(a, 0)) + two = _const(a, 2) sn, cs = sin(a), cos(a) - e1m, e2m = expm1(b), expm1(-b) - snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2 - re, im = cs * csh, -sn * snh + e1m, e2m = expm1(b), expm1(neg(b)) + snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two) + re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) def _cos_lowering(ctx, x): @@ -3769,28 +3771,28 @@ ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) mlir.register_lowering(cos_p, _cos_lowering) tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) +ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) asin_p = standard_unop(_float | _complex, 'asin') -ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x)))) +ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x))))) mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin)) acos_p = standard_unop(_float | _complex, 'acos') -ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x)))) +ad.defjvp(acos_p, lambda g, x: mul(g, neg(rsqrt(sub(_const(x, 1), square(x)))))) mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos)) def atan_impl(x): return atan2(x, _const(x, 1)) atan_p = standard_unop(_float | _complex, 'atan') -ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x))) +ad.defjvp(atan_p, lambda g, x: div(g, add(_const(x, 1), square(x)))) mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan)) atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2') ad.defjvp(atan2_p, - lambda g, x, y: g * (y / (square(x) + square(y))), - lambda g, x, y: g * -x / (square(x) + square(y))) + lambda g, x, y: mul(g, div(y, add(square(x), square(y)))), + lambda g, x, y: mul(g, div(neg(x), add(square(x), square(y))))) mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2)) sinh_p = standard_unop(_float | _complex, 'sinh') @@ -3802,17 +3804,17 @@ ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x))) mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh)) asinh_p = standard_unop(_float | _complex, 'asinh') -ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x)))) +ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(add(square(x), _one(x))))) mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh)) acosh_p = standard_unop(_float | _complex, 'acosh') ad.defjvp(acosh_p, - lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x))))) + lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x)))))) mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh)) atanh_p = standard_unop(_float | _complex, 'atanh') ad.defjvp(atanh_p, - lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x)))) + lambda g, x: mul(reciprocal(add(_one(x), x)), div(g, sub(_one(x), x)))) mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh)) real_p = unop(_complex_basetype, _complex, 'real') @@ -3906,11 +3908,11 @@ def _square_complex(x): a, b = real(x), imag(x) # zero square(x).real is handled explicitly for abs(a)==abs(b) cases # where for finite a, 2 * a is non-finite: - zero_re = is_finite(a) & (eq(a, b) | eq(a, -b)) + zero_re = is_finite(a) & (eq(a, b) | eq(a, neg(b))) # equivalent to a**2 - b**2 but avoids overflow errors for large a # and large b cases: - re = (a - b) * (a + b) - im = a * b * 2 + re = mul(sub(a, b), add(a, b)) + im = mul(mul(a, b), _const(a, 2)) return select(zero_re, complex(_const(a, 0), im), complex(re, im)) def _square_lower_hlo(ctx, x): @@ -5276,7 +5278,7 @@ def _ragged_dot_jvp_rule( if type(dy) is not ad_util.Zero else _zeros(primal_out) ) - tangent_out = dx_out + dy_out + tangent_out = add(dx_out, dy_out) return primal_out, tangent_out