From 57b5acf1b60470ff06368b806a0e117b2ac7435d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 7 Sep 2022 06:06:22 -0700 Subject: [PATCH] Roll forward: Upgrade logistic into a primitive. Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change. PiperOrigin-RevId: 472705623 --- docs/jax.lax.rst | 1 + jax/_src/lax/lax.py | 16 ++++++++++++++++ jax/_src/lax_reference.py | 1 + jax/_src/nn/functions.py | 2 +- jax/_src/scipy/special.py | 5 +---- jax/experimental/jax2tf/jax2tf.py | 4 ++++ .../jax2tf/tests/jax2tf_limitations.py | 17 +++++++++-------- .../jax2tf/tests/primitive_harness.py | 1 + jax/experimental/jet.py | 9 ++++++--- jax/lax/__init__.py | 2 ++ tests/jet_test.py | 4 ++-- tests/lax_autodiff_test.py | 3 +++ tests/lax_test.py | 3 +++ tests/scipy_stats_test.py | 4 ++-- 14 files changed, 52 insertions(+), 20 deletions(-) diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 4cbc040c2..8c77633ba 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -103,6 +103,7 @@ Operators lgamma log log1p + logistic max min mul diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index bb2dca4aa..c965a3c37 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -300,6 +300,10 @@ def tanh(x: Array) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.""" return tanh_p.bind(x) +def logistic(x: Array) -> Array: + r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.""" + return logistic_p.bind(x) + def sin(x: Array) -> Array: r"""Elementwise sine: :math:`\mathrm{sin}(x)`.""" return sin_p.bind(x) @@ -1736,6 +1740,18 @@ ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), sub(_one(x), ans))) mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp)) +logistic_p = standard_unop(_float | _complex, 'logistic') +ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans)))) +# TODO(phawkins): switch to mhlo.logistic lowering; debug numerical problems. +# mlir.register_lowering(logistic_p, partial(_nary_lower_mhlo, mhlo.LogisticOp)) + +def logistic_impl(x): + one = _const(x, 1) + return div(one, add(one, exp(neg(x)))) + +mlir.register_lowering(logistic_p, + mlir.lower_fun(logistic_impl, multiple_results=False)) + sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) if mlir_api_version < 27: diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py index 5c4a597d6..30b347181 100644 --- a/jax/_src/lax_reference.py +++ b/jax/_src/lax_reference.py @@ -69,6 +69,7 @@ asinh = np.arcsinh acosh = np.arccosh atanh = np.arctanh +def logistic(x): return 1 / (1 + np.exp(-x)) def betainc(a, b, x): return scipy.special.betainc(a, b, x).astype(x.dtype) def lgamma(x): return scipy.special.gammaln(x).astype(x.dtype) def digamma(x): return scipy.special.digamma(x).astype(x.dtype) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index f8958e9f8..7fba8f2a7 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -91,7 +91,7 @@ def sigmoid(x: Array) -> Array: Args: x : input array """ - return expit(x) + return lax.logistic(x) @jax.jit def silu(x: Array) -> Array: diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 7524721d5..6d1576e64 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -96,13 +96,10 @@ logit.defjvps( lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x)))) -@api.custom_jvp @_wraps(osp_special.expit, module='scipy.special', update_doc=False) def expit(x): x, = _promote_args_inexact("expit", x) - one = _lax_const(x, 1) - return lax.div(one, lax.add(one, lax.exp(lax.neg(x)))) -expit.defjvps(lambda g, ans, x: g * ans * (_lax_const(ans, 1) - ans)) + return lax.logistic(x) @_wraps(osp_special.logsumexp, module='scipy.special') diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 8212a3133..2d8426eb7 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1366,6 +1366,10 @@ tf_impl_with_avals[lax.asin_p] = _convert_jax_impl( tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( lax_internal.atan_impl, multiple_results=False) +# TODO(phawkins): use tf.math.sigmoid here instead. +tf_impl_with_avals[lax.logistic_p] = _convert_jax_impl( + lax_internal.logistic_impl, multiple_results=False) + def _atan2(y, x, **kwargs): if x.dtype.is_complex or y.dtype.is_complex: complex_component_dtype = { diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 0b0bed067..207b1b9ee 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -130,14 +130,15 @@ class Jax2TfLimitation(primitive_harness.Limitation): "cos", "cosh", "complex", "conj", "convert_element_type", "cummax", "cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le", - "lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count", - "random_categorical", "random_uniform", "random_randint", - "reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum", - "reduce_window_mul", "reduce_window_min", "reduce_window_max", "real", - "reshape", "rev", "rsqrt", "select_n", "select_and_scatter_add", - "shift_left", "shift_right_logical", "shift_right_arithmetic", "sign", - "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub", - "tie_in", "transpose", "xor", "zeros_like" + "logistic", "lt", "log", "mul", "ne", "neg", "not", "or", "pad", + "population_count", "random_categorical", "random_uniform", + "random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or", + "reduce_sum", "reduce_window_mul", "reduce_window_min", + "reduce_window_max", "real", "reshape", "rev", "rsqrt", "select_n", + "select_and_scatter_add", "shift_left", "shift_right_logical", + "shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt", + "squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor", + "zeros_like" } @classmethod diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index a748c8c2b..a0e4456f6 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -432,6 +432,7 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex: _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype) for dtype in jtu.dtypes.all_floating: _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 7ad54c6ae..f3f36f031 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -498,11 +498,11 @@ def _integer_pow_taylor(primals_in, series_in, *, y): jet_rules[lax.integer_pow_p] = _integer_pow_taylor -def _expit_taylor(primals_in, series_in): +def _logistic_taylor(primals_in, series_in): x, = primals_in series, = series_in u = [x] + series - v = [jax.scipy.special.expit(x)] + [None] * len(series) + v = [lax.logistic(x)] + [None] * len(series) e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid) for k in range(1, len(v)): v[k] = fact(k-1) * sum(_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1)) @@ -511,12 +511,15 @@ def _expit_taylor(primals_in, series_in): primal_out, *series_out = v return primal_out, series_out +jet_rules[lax.logistic_p] = _logistic_taylor + + def _tanh_taylor(primals_in, series_in): x, = primals_in series, = series_in u = [2*x] + [2 * series_ for series_ in series] primals_in, *series_in = u - primal_out, series_out = _expit_taylor((primals_in, ), (series_in, )) + primal_out, series_out = _logistic_taylor((primals_in, ), (series_in, )) series_out = [2 * series_ for series_ in series_out] return 2 * primal_out - 1, series_out jet_rules[lax.tanh_p] = _tanh_taylor diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index c10f4e6ab..53480a66f 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -139,6 +139,8 @@ from jax._src.lax.lax import ( log1p as log1p, log1p_p as log1p_p, log_p as log_p, + logistic as logistic, + logistic_p as logistic_p, lt as lt, lt_p as lt_p, make_bint as make_bint, diff --git a/tests/jet_test.py b/tests/jet_test.py index d1d8d0a5d..71b4ef12a 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -188,7 +188,7 @@ class JetTest(jtu.JaxTestCase): primals = (primal_in, ) series = (terms_in, ) - y, terms = jax.experimental.jet._expit_taylor(primals, series) + y, terms = jax.experimental.jet._logistic_taylor(primals, series) expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series) atol = 1e-4 @@ -283,7 +283,7 @@ class JetTest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5) @jtu.skip_on_devices("tpu") - def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-100, 100], order=5) + def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5) @jtu.skip_on_devices("tpu") def test_expit2(self): self.expit_check(lims=[-500, 500], order=5) @jtu.skip_on_devices("tpu") diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index f62e6ae6f..67f9c8db4 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -134,6 +134,9 @@ LAX_GRAD_OPS = [ dtypes=grad_complex_dtypes), grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default, dtypes=grad_float_dtypes, tol={np.float64: 3e-5}), + grad_test_spec(lax.logistic, nargs=1, order=2, + rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes), grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default, dtypes=grad_inexact_dtypes), diff --git a/tests/lax_test.py b/tests/lax_test.py index cb04ed5e7..d53de32b9 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -115,6 +115,7 @@ LAX_OPS = [ # TODO(b/143135720): on GPU, tanh has only ~float32 precision. op_record("tanh", 1, float_dtypes + complex_dtypes, jtu.rand_small, {np.float64: 1e-9, np.complex128: 1e-7}), + op_record("logistic", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("sin", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("atan2", 2, float_dtypes, jtu.rand_default), @@ -2959,6 +2960,8 @@ class LazyConstantTest(jtu.JaxTestCase): for op, dtypes in unary_op_types.items()) def testUnaryWeakTypes(self, op_name, rec_dtypes): """Test that all lax unary ops propagate weak_type information appropriately.""" + if op_name == "bitwise_not": + raise unittest.SkipTest("https://github.com/google/jax/issues/12066") # Find a valid dtype for the function. for dtype in [np.float_, np.int_, np.complex_, np.bool_]: dtype = dtypes.canonicalize_dtype(dtype) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 841c27780..d4d8cefd3 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -351,7 +351,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-6) + tol=3e-5) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) @@ -397,7 +397,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-6) + tol=2e-5) self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3)