Merge pull request #27032 from dfm:lax-dtype

PiperOrigin-RevId: 735424674
This commit is contained in:
jax authors 2025-03-10 10:18:58 -07:00
commit 14b215fe76

View File

@ -3723,10 +3723,11 @@ def _sin_complex(x):
# 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
a, b = real(x), imag(x)
a_is_zero = eq(a, _const(a, 0))
two = _const(a, 2)
sn, cs = sin(a), cos(a)
e1m, e2m = expm1(b), expm1(-b)
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
re, im = sn * csh, cs * snh
e1m, e2m = expm1(b), expm1(neg(b))
snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two)
re, im = mul(sn, csh), mul(cs, snh)
# avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
@ -3752,10 +3753,11 @@ def _cos_complex(x):
# see also _sin_complex
a, b = real(x), imag(x)
a_is_zero = eq(a, _const(a, 0))
two = _const(a, 2)
sn, cs = sin(a), cos(a)
e1m, e2m = expm1(b), expm1(-b)
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
re, im = cs * csh, -sn * snh
e1m, e2m = expm1(b), expm1(neg(b))
snh, csh = div(sub(e1m, e2m), two), div(add(add(e1m, e2m), two), two)
re, im = mul(cs, csh), mul(neg(sn), snh)
return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
def _cos_lowering(ctx, x):
@ -3769,28 +3771,28 @@ ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
mlir.register_lowering(cos_p, _cos_lowering)
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans))))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan))
asin_p = standard_unop(_float | _complex, 'asin')
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x)))))
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin))
acos_p = standard_unop(_float | _complex, 'acos')
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
ad.defjvp(acos_p, lambda g, x: mul(g, neg(rsqrt(sub(_const(x, 1), square(x))))))
mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos))
def atan_impl(x):
return atan2(x, _const(x, 1))
atan_p = standard_unop(_float | _complex, 'atan')
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
ad.defjvp(atan_p, lambda g, x: div(g, add(_const(x, 1), square(x))))
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan))
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(atan2_p,
lambda g, x, y: g * (y / (square(x) + square(y))),
lambda g, x, y: g * -x / (square(x) + square(y)))
lambda g, x, y: mul(g, div(y, add(square(x), square(y)))),
lambda g, x, y: mul(g, div(neg(x), add(square(x), square(y)))))
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2))
sinh_p = standard_unop(_float | _complex, 'sinh')
@ -3802,17 +3804,17 @@ ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh))
asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(add(square(x), _one(x)))))
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh))
acosh_p = standard_unop(_float | _complex, 'acosh')
ad.defjvp(acosh_p,
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x))))))
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh))
atanh_p = standard_unop(_float | _complex, 'atanh')
ad.defjvp(atanh_p,
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
lambda g, x: mul(reciprocal(add(_one(x), x)), div(g, sub(_one(x), x))))
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh))
real_p = unop(_complex_basetype, _complex, 'real')
@ -3906,11 +3908,11 @@ def _square_complex(x):
a, b = real(x), imag(x)
# zero square(x).real is handled explicitly for abs(a)==abs(b) cases
# where for finite a, 2 * a is non-finite:
zero_re = is_finite(a) & (eq(a, b) | eq(a, -b))
zero_re = is_finite(a) & (eq(a, b) | eq(a, neg(b)))
# equivalent to a**2 - b**2 but avoids overflow errors for large a
# and large b cases:
re = (a - b) * (a + b)
im = a * b * 2
re = mul(sub(a, b), add(a, b))
im = mul(mul(a, b), _const(a, 2))
return select(zero_re, complex(_const(a, 0), im), complex(re, im))
def _square_lower_hlo(ctx, x):
@ -5276,7 +5278,7 @@ def _ragged_dot_jvp_rule(
if type(dy) is not ad_util.Zero
else _zeros(primal_out)
)
tangent_out = dx_out + dy_out
tangent_out = add(dx_out, dy_out)
return primal_out, tangent_out