From 3e3542b0d696155751e77100d6c8a97e858fd8dd Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 24 Aug 2022 15:39:02 -0700 Subject: [PATCH] Upgrade logistic (sigmoid) function into a lax primitive. This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code. PiperOrigin-RevId: 469841487 --- docs/jax.lax.rst | 1 - jax/_src/lax/lax.py | 8 -------- jax/_src/lax_reference.py | 1 - jax/_src/nn/functions.py | 2 +- jax/_src/scipy/special.py | 5 ++++- jax/experimental/jax2tf/jax2tf.py | 2 -- .../jax2tf/tests/jax2tf_limitations.py | 18 +++++++++--------- .../jax2tf/tests/primitive_harness.py | 1 - jax/experimental/jet.py | 9 +++------ jax/lax/__init__.py | 2 -- tests/jet_test.py | 4 ++-- tests/lax_autodiff_test.py | 3 --- tests/lax_test.py | 3 --- tests/scipy_stats_test.py | 4 ++-- 14 files changed, 21 insertions(+), 42 deletions(-) diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 8c77633ba..4cbc040c2 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -103,7 +103,6 @@ Operators lgamma log log1p - logistic max min mul diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a3a1fd316..fcadb6477 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -300,10 +300,6 @@ def tanh(x: Array) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.""" return tanh_p.bind(x) -def logistic(x: Array) -> Array: - r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.""" - return logistic_p.bind(x) - def sin(x: Array) -> Array: r"""Elementwise sine: :math:`\mathrm{sin}(x)`.""" return sin_p.bind(x) @@ -1734,10 +1730,6 @@ ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), sub(_one(x), ans))) mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp)) -logistic_p = standard_unop(_float | _complex, 'logistic') -ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans)))) -mlir.register_lowering(logistic_p, partial(_nary_lower_mhlo, mhlo.LogisticOp)) - sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) if mlir_api_version < 27: diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py index 30b347181..5c4a597d6 100644 --- a/jax/_src/lax_reference.py +++ b/jax/_src/lax_reference.py @@ -69,7 +69,6 @@ asinh = np.arcsinh acosh = np.arccosh atanh = np.arctanh -def logistic(x): return 1 / (1 + np.exp(-x)) def betainc(a, b, x): return scipy.special.betainc(a, b, x).astype(x.dtype) def lgamma(x): return scipy.special.gammaln(x).astype(x.dtype) def digamma(x): return scipy.special.digamma(x).astype(x.dtype) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 7fba8f2a7..f8958e9f8 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -91,7 +91,7 @@ def sigmoid(x: Array) -> Array: Args: x : input array """ - return lax.logistic(x) + return expit(x) @jax.jit def silu(x: Array) -> Array: diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 6d1576e64..7524721d5 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -96,10 +96,13 @@ logit.defjvps( lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x)))) +@api.custom_jvp @_wraps(osp_special.expit, module='scipy.special', update_doc=False) def expit(x): x, = _promote_args_inexact("expit", x) - return lax.logistic(x) + one = _lax_const(x, 1) + return lax.div(one, lax.add(one, lax.exp(lax.neg(x)))) +expit.defjvps(lambda g, ans, x: g * ans * (_lax_const(ans, 1) - ans)) @_wraps(osp_special.logsumexp, module='scipy.special') diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 684947132..940c677cb 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1339,8 +1339,6 @@ tf_impl_with_avals[lax.asin_p] = _convert_jax_impl( tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( lax_internal.atan_impl, multiple_results=False) -tf_impl[lax.logistic_p] = tf.math.sigmoid - def _atan2(y, x, **kwargs): if x.dtype.is_complex or y.dtype.is_complex: complex_component_dtype = { diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index ae2ddbe35..4e94c146b 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -130,15 +130,15 @@ class Jax2TfLimitation(primitive_harness.Limitation): "cos", "cosh", "complex", "conj", "convert_element_type", "cummax", "cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le", - "logistic", "lt", "log", "mul", "ne", "neg", "not", "or", "pad", - "population_count", "random_categorical", "random_uniform", - "random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or", - "reduce_sum", "reduce_window_mul", "reduce_window_min", - "reduce_window_max", "real", "reshape", "rev", "rsqrt", "scatter_max", - "scatter_min", "select_n", "select_and_scatter_add", "shift_left", - "shift_right_logical", "shift_right_arithmetic", "sign", "sin", "sinh", - "slice", "sqrt", "squeeze", "stop_gradient", "sub", "tie_in", "transpose", - "xor", "zeros_like" + "lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count", + "random_categorical", "random_uniform", "random_randint", + "reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum", + "reduce_window_mul", "reduce_window_min", "reduce_window_max", "real", + "reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n", + "select_and_scatter_add", "shift_left", "shift_right_logical", + "shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt", + "squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor", + "zeros_like" } @classmethod diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index 6e8626dad..b646d8cb0 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -432,7 +432,6 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex: _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype) for dtype in jtu.dtypes.all_floating: _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index f3f36f031..7ad54c6ae 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -498,11 +498,11 @@ def _integer_pow_taylor(primals_in, series_in, *, y): jet_rules[lax.integer_pow_p] = _integer_pow_taylor -def _logistic_taylor(primals_in, series_in): +def _expit_taylor(primals_in, series_in): x, = primals_in series, = series_in u = [x] + series - v = [lax.logistic(x)] + [None] * len(series) + v = [jax.scipy.special.expit(x)] + [None] * len(series) e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid) for k in range(1, len(v)): v[k] = fact(k-1) * sum(_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)) @@ -511,15 +511,12 @@ def _logistic_taylor(primals_in, series_in): primal_out, *series_out = v return primal_out, series_out -jet_rules[lax.logistic_p] = _logistic_taylor - - def _tanh_taylor(primals_in, series_in): x, = primals_in series, = series_in u = [2*x] + [2 * series_ for series_ in series] primals_in, *series_in = u - primal_out, series_out = _logistic_taylor((primals_in, ), (series_in, )) + primal_out, series_out = _expit_taylor((primals_in, ), (series_in, )) series_out = [2 * series_ for series_ in series_out] return 2 * primal_out - 1, series_out jet_rules[lax.tanh_p] = _tanh_taylor diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 53480a66f..c10f4e6ab 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -139,8 +139,6 @@ from jax._src.lax.lax import ( log1p as log1p, log1p_p as log1p_p, log_p as log_p, - logistic as logistic, - logistic_p as logistic_p, lt as lt, lt_p as lt_p, make_bint as make_bint, diff --git a/tests/jet_test.py b/tests/jet_test.py index 71b4ef12a..d1d8d0a5d 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -188,7 +188,7 @@ class JetTest(jtu.JaxTestCase): primals = (primal_in, ) series = (terms_in, ) - y, terms = jax.experimental.jet._logistic_taylor(primals, series) + y, terms = jax.experimental.jet._expit_taylor(primals, series) expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series) atol = 1e-4 @@ -283,7 +283,7 @@ class JetTest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5) @jtu.skip_on_devices("tpu") - def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5) + def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-100, 100], order=5) @jtu.skip_on_devices("tpu") def test_expit2(self): self.expit_check(lims=[-500, 500], order=5) @jtu.skip_on_devices("tpu") diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 67f9c8db4..f62e6ae6f 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -134,9 +134,6 @@ LAX_GRAD_OPS = [ dtypes=grad_complex_dtypes), grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default, dtypes=grad_float_dtypes, tol={np.float64: 3e-5}), - grad_test_spec(lax.logistic, nargs=1, order=2, - rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes), grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default, dtypes=grad_inexact_dtypes), diff --git a/tests/lax_test.py b/tests/lax_test.py index c6bc389e9..5cdd3c425 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -114,7 +114,6 @@ LAX_OPS = [ # TODO(b/143135720): on GPU, tanh has only ~float32 precision. op_record("tanh", 1, float_dtypes + complex_dtypes, jtu.rand_small, {np.float64: 1e-9, np.complex128: 1e-7}), - op_record("logistic", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("sin", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("atan2", 2, float_dtypes, jtu.rand_default), @@ -2959,8 +2958,6 @@ class LazyConstantTest(jtu.JaxTestCase): for op, dtypes in unary_op_types.items())) def testUnaryWeakTypes(self, op_name, rec_dtypes): """Test that all lax unary ops propagate weak_type information appropriately.""" - if op_name == "bitwise_not": - raise unittest.SkipTest("https://github.com/google/jax/issues/12066") # Find a valid dtype for the function. for dtype in [np.float_, np.int_, np.complex_, np.bool_]: dtype = dtypes.canonicalize_dtype(dtype) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index d4d8cefd3..841c27780 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -351,7 +351,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=3e-5) + tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) @@ -397,7 +397,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=2e-5) + tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3)