Upgrade the logistic (sigmoid) function to a lax primitive.

This lets us lower it to `mhlo.logistic`, which enables XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
This commit is contained in:
Peter Hawkins 2022-08-26 11:57:54 -07:00 committed by jax authors
parent 45764ea9a8
commit f68f1c0cd0
14 changed files with 42 additions and 20 deletions

View File

@ -103,6 +103,7 @@ Operators
lgamma
log
log1p
logistic
max
min
mul

View File

@ -300,6 +300,10 @@ def tanh(x: Array) -> Array:
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`."""
return tanh_p.bind(x)
def logistic(x: Array) -> Array:
  r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.

  Implemented by binding the ``logistic_p`` primitive to ``x``.
  """
  return logistic_p.bind(x)
def sin(x: Array) -> Array:
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
return sin_p.bind(x)
@ -1730,6 +1734,10 @@ ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
sub(_one(x), ans)))
mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp))
# Unary primitive named 'logistic', registered for the `_float | _complex`
# dtype set.
logistic_p = standard_unop(_float | _complex, 'logistic')
# JVP expressed through the primal output `ans`: g * ans * (1 - ans).
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
# Lower the primitive to the MHLO LogisticOp.
mlir.register_lowering(logistic_p, partial(_nary_lower_mhlo, mhlo.LogisticOp))
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
if mlir_api_version < 27:

View File

@ -69,6 +69,7 @@ asinh = np.arcsinh
acosh = np.arccosh
atanh = np.arctanh
def logistic(x):
  """NumPy reference for the logistic (sigmoid) function, 1 / (1 + exp(-x))."""
  denominator = 1 + np.exp(-x)
  return 1 / denominator
def betainc(a, b, x): return scipy.special.betainc(a, b, x).astype(x.dtype)
def lgamma(x): return scipy.special.gammaln(x).astype(x.dtype)
def digamma(x): return scipy.special.digamma(x).astype(x.dtype)

View File

@ -91,7 +91,7 @@ def sigmoid(x: Array) -> Array:
Args:
x : input array
"""
return expit(x)
return lax.logistic(x)
@jax.jit
def silu(x: Array) -> Array:

View File

@ -96,13 +96,10 @@ logit.defjvps(
lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x))))
@api.custom_jvp
@_wraps(osp_special.expit, module='scipy.special', update_doc=False)
def expit(x):
x, = _promote_args_inexact("expit", x)
one = _lax_const(x, 1)
return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
expit.defjvps(lambda g, ans, x: g * ans * (_lax_const(ans, 1) - ans))
return lax.logistic(x)
@_wraps(osp_special.logsumexp, module='scipy.special')

View File

@ -1352,6 +1352,8 @@ tf_impl_with_avals[lax.asin_p] = _convert_jax_impl(
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
lax_internal.atan_impl, multiple_results=False)
# The logistic primitive is implemented by TensorFlow's sigmoid function.
tf_impl[lax.logistic_p] = tf.math.sigmoid
def _atan2(y, x, **kwargs):
if x.dtype.is_complex or y.dtype.is_complex:
complex_component_dtype = {

View File

@ -130,14 +130,15 @@ class Jax2TfLimitation(primitive_harness.Limitation):
"cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
"lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
"random_categorical", "random_uniform", "random_randint",
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
"reshape", "rev", "rsqrt", "select_n", "select_and_scatter_add",
"shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
"sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
"tie_in", "transpose", "xor", "zeros_like"
"logistic", "lt", "log", "mul", "ne", "neg", "not", "or", "pad",
"population_count", "random_categorical", "random_uniform",
"random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or",
"reduce_sum", "reduce_window_mul", "reduce_window_min",
"reduce_window_max", "real", "reshape", "rev", "rsqrt", "select_n",
"select_and_scatter_add", "shift_left", "shift_right_logical",
"shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt",
"squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor",
"zeros_like"
}
@classmethod

View File

@ -432,6 +432,7 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
_make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)
_make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype)
for dtype in jtu.dtypes.all_floating:
_make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)

View File

@ -498,11 +498,11 @@ def _integer_pow_taylor(primals_in, series_in, *, y):
jet_rules[lax.integer_pow_p] = _integer_pow_taylor
def _expit_taylor(primals_in, series_in):
def _logistic_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [x] + series
v = [jax.scipy.special.expit(x)] + [None] * len(series)
v = [lax.logistic(x)] + [None] * len(series)
e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid)
for k in range(1, len(v)):
v[k] = fact(k-1) * sum(_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1))
@ -511,12 +511,15 @@ def _expit_taylor(primals_in, series_in):
primal_out, *series_out = v
return primal_out, series_out
jet_rules[lax.logistic_p] = _logistic_taylor
def _tanh_taylor(primals_in, series_in):
x, = primals_in
series, = series_in
u = [2*x] + [2 * series_ for series_ in series]
primals_in, *series_in = u
primal_out, series_out = _expit_taylor((primals_in, ), (series_in, ))
primal_out, series_out = _logistic_taylor((primals_in, ), (series_in, ))
series_out = [2 * series_ for series_ in series_out]
return 2 * primal_out - 1, series_out
jet_rules[lax.tanh_p] = _tanh_taylor

View File

@ -139,6 +139,8 @@ from jax._src.lax.lax import (
log1p as log1p,
log1p_p as log1p_p,
log_p as log_p,
logistic as logistic,
logistic_p as logistic_p,
lt as lt,
lt_p as lt_p,
make_bint as make_bint,

View File

@ -188,7 +188,7 @@ class JetTest(jtu.JaxTestCase):
primals = (primal_in, )
series = (terms_in, )
y, terms = jax.experimental.jet._expit_taylor(primals, series)
y, terms = jax.experimental.jet._logistic_taylor(primals, series)
expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series)
atol = 1e-4
@ -283,7 +283,7 @@ class JetTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-100, 100], order=5)
def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5)
@jtu.skip_on_devices("tpu")
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
@jtu.skip_on_devices("tpu")

View File

@ -134,6 +134,9 @@ LAX_GRAD_OPS = [
dtypes=grad_complex_dtypes),
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes, tol={np.float64: 3e-5}),
grad_test_spec(lax.logistic, nargs=1, order=2,
rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),

View File

@ -114,6 +114,7 @@ LAX_OPS = [
# TODO(b/143135720): on GPU, tanh has only ~float32 precision.
op_record("tanh", 1, float_dtypes + complex_dtypes, jtu.rand_small,
{np.float64: 1e-9, np.complex128: 1e-7}),
op_record("logistic", 1, float_dtypes + complex_dtypes, jtu.rand_default),
op_record("sin", 1, float_dtypes + complex_dtypes, jtu.rand_default),
op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default),
op_record("atan2", 2, float_dtypes, jtu.rand_default),
@ -2958,6 +2959,8 @@ class LazyConstantTest(jtu.JaxTestCase):
for op, dtypes in unary_op_types.items())
def testUnaryWeakTypes(self, op_name, rec_dtypes):
"""Test that all lax unary ops propagate weak_type information appropriately."""
if op_name == "bitwise_not":
raise unittest.SkipTest("https://github.com/google/jax/issues/12066")
# Find a valid dtype for the function.
for dtype in [np.float_, np.int_, np.complex_, np.bool_]:
dtype = dtypes.canonicalize_dtype(dtype)

View File

@ -351,7 +351,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
tol=3e-5)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1)
@ -397,7 +397,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
return list(map(rng, shapes, dtypes))
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
tol=2e-5)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)