mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code. PiperOrigin-RevId: 469789339
This commit is contained in:
parent
91b1e01f60
commit
6276194e1c
@ -103,6 +103,7 @@ Operators
|
||||
lgamma
|
||||
log
|
||||
log1p
|
||||
logistic
|
||||
max
|
||||
min
|
||||
mul
|
||||
|
@ -300,6 +300,10 @@ def tanh(x: Array) -> Array:
|
||||
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`."""
|
||||
return tanh_p.bind(x)
|
||||
|
||||
def logistic(x: Array) -> Array:
|
||||
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
|
||||
return logistic_p.bind(x)
|
||||
|
||||
def sin(x: Array) -> Array:
|
||||
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
|
||||
return sin_p.bind(x)
|
||||
@ -1730,6 +1734,10 @@ ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
|
||||
sub(_one(x), ans)))
|
||||
mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp))
|
||||
|
||||
logistic_p = standard_unop(_float | _complex, 'logistic')
|
||||
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
|
||||
mlir.register_lowering(logistic_p, partial(_nary_lower_mhlo, mhlo.LogisticOp))
|
||||
|
||||
sin_p = standard_unop(_float | _complex, 'sin')
|
||||
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
||||
if mlir_api_version < 27:
|
||||
|
@ -69,6 +69,7 @@ asinh = np.arcsinh
|
||||
acosh = np.arccosh
|
||||
atanh = np.arctanh
|
||||
|
||||
def logistic(x): return 1 / (1 + np.exp(-x))
|
||||
def betainc(a, b, x): return scipy.special.betainc(a, b, x).astype(x.dtype)
|
||||
def lgamma(x): return scipy.special.gammaln(x).astype(x.dtype)
|
||||
def digamma(x): return scipy.special.digamma(x).astype(x.dtype)
|
||||
|
@ -91,7 +91,7 @@ def sigmoid(x: Array) -> Array:
|
||||
Args:
|
||||
x : input array
|
||||
"""
|
||||
return expit(x)
|
||||
return lax.logistic(x)
|
||||
|
||||
@jax.jit
|
||||
def silu(x: Array) -> Array:
|
||||
|
@ -96,13 +96,10 @@ logit.defjvps(
|
||||
lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x))))
|
||||
|
||||
|
||||
@api.custom_jvp
|
||||
@_wraps(osp_special.expit, module='scipy.special', update_doc=False)
|
||||
def expit(x):
|
||||
x, = _promote_args_inexact("expit", x)
|
||||
one = _lax_const(x, 1)
|
||||
return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
|
||||
expit.defjvps(lambda g, ans, x: g * ans * (_lax_const(ans, 1) - ans))
|
||||
return lax.logistic(x)
|
||||
|
||||
|
||||
@_wraps(osp_special.logsumexp, module='scipy.special')
|
||||
|
@ -1339,6 +1339,8 @@ tf_impl_with_avals[lax.asin_p] = _convert_jax_impl(
|
||||
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
|
||||
lax_internal.atan_impl, multiple_results=False)
|
||||
|
||||
tf_impl[lax.logistic_p] = tf.math.sigmoid
|
||||
|
||||
def _atan2(y, x, **kwargs):
|
||||
if x.dtype.is_complex or y.dtype.is_complex:
|
||||
complex_component_dtype = {
|
||||
|
@ -130,15 +130,15 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
"cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
|
||||
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
|
||||
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
|
||||
"lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
|
||||
"random_categorical", "random_uniform", "random_randint",
|
||||
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
|
||||
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
|
||||
"reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
|
||||
"select_and_scatter_add", "shift_left", "shift_right_logical",
|
||||
"shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt",
|
||||
"squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor",
|
||||
"zeros_like"
|
||||
"logistic", "lt", "log", "mul", "ne", "neg", "not", "or", "pad",
|
||||
"population_count", "random_categorical", "random_uniform",
|
||||
"random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or",
|
||||
"reduce_sum", "reduce_window_mul", "reduce_window_min",
|
||||
"reduce_window_max", "real", "reshape", "rev", "rsqrt", "scatter_max",
|
||||
"scatter_min", "select_n", "select_and_scatter_add", "shift_left",
|
||||
"shift_right_logical", "shift_right_arithmetic", "sign", "sin", "sinh",
|
||||
"slice", "sqrt", "squeeze", "stop_gradient", "sub", "tie_in", "transpose",
|
||||
"xor", "zeros_like"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
@ -432,6 +432,7 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
|
||||
_make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype)
|
||||
_make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype)
|
||||
|
||||
for dtype in jtu.dtypes.all_floating:
|
||||
_make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype)
|
||||
|
@ -498,11 +498,11 @@ def _integer_pow_taylor(primals_in, series_in, *, y):
|
||||
jet_rules[lax.integer_pow_p] = _integer_pow_taylor
|
||||
|
||||
|
||||
def _expit_taylor(primals_in, series_in):
|
||||
def _logistic_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [x] + series
|
||||
v = [jax.scipy.special.expit(x)] + [None] * len(series)
|
||||
v = [lax.logistic(x)] + [None] * len(series)
|
||||
e = [v[0] * (1 - v[0])] + [None] * len(series) # terms for sigmoid' = sigmoid * (1 - sigmoid)
|
||||
for k in range(1, len(v)):
|
||||
v[k] = fact(k-1) * sum(_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1))
|
||||
@ -511,12 +511,15 @@ def _expit_taylor(primals_in, series_in):
|
||||
primal_out, *series_out = v
|
||||
return primal_out, series_out
|
||||
|
||||
jet_rules[lax.logistic_p] = _logistic_taylor
|
||||
|
||||
|
||||
def _tanh_taylor(primals_in, series_in):
|
||||
x, = primals_in
|
||||
series, = series_in
|
||||
u = [2*x] + [2 * series_ for series_ in series]
|
||||
primals_in, *series_in = u
|
||||
primal_out, series_out = _expit_taylor((primals_in, ), (series_in, ))
|
||||
primal_out, series_out = _logistic_taylor((primals_in, ), (series_in, ))
|
||||
series_out = [2 * series_ for series_ in series_out]
|
||||
return 2 * primal_out - 1, series_out
|
||||
jet_rules[lax.tanh_p] = _tanh_taylor
|
||||
|
@ -139,6 +139,8 @@ from jax._src.lax.lax import (
|
||||
log1p as log1p,
|
||||
log1p_p as log1p_p,
|
||||
log_p as log_p,
|
||||
logistic as logistic,
|
||||
logistic_p as logistic_p,
|
||||
lt as lt,
|
||||
lt_p as lt_p,
|
||||
make_bint as make_bint,
|
||||
|
@ -188,7 +188,7 @@ class JetTest(jtu.JaxTestCase):
|
||||
primals = (primal_in, )
|
||||
series = (terms_in, )
|
||||
|
||||
y, terms = jax.experimental.jet._expit_taylor(primals, series)
|
||||
y, terms = jax.experimental.jet._logistic_taylor(primals, series)
|
||||
expected_y, expected_terms = jvp_taylor(jax.scipy.special.expit, primals, series)
|
||||
|
||||
atol = 1e-4
|
||||
@ -283,7 +283,7 @@ class JetTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expit(self): self.unary_check(jax.scipy.special.expit, lims=[-100, 100], order=5)
|
||||
def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_expit2(self): self.expit_check(lims=[-500, 500], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
|
@ -134,6 +134,9 @@ LAX_GRAD_OPS = [
|
||||
dtypes=grad_complex_dtypes),
|
||||
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_default,
|
||||
dtypes=grad_float_dtypes, tol={np.float64: 3e-5}),
|
||||
grad_test_spec(lax.logistic, nargs=1, order=2,
|
||||
rng_factory=jtu.rand_default,
|
||||
dtypes=grad_inexact_dtypes),
|
||||
|
||||
grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default,
|
||||
dtypes=grad_inexact_dtypes),
|
||||
|
@ -114,6 +114,7 @@ LAX_OPS = [
|
||||
# TODO(b/143135720): on GPU, tanh has only ~float32 precision.
|
||||
op_record("tanh", 1, float_dtypes + complex_dtypes, jtu.rand_small,
|
||||
{np.float64: 1e-9, np.complex128: 1e-7}),
|
||||
op_record("logistic", 1, float_dtypes + complex_dtypes, jtu.rand_default),
|
||||
op_record("sin", 1, float_dtypes + complex_dtypes, jtu.rand_default),
|
||||
op_record("cos", 1, float_dtypes + complex_dtypes, jtu.rand_default),
|
||||
op_record("atan2", 2, float_dtypes, jtu.rand_default),
|
||||
@ -2958,6 +2959,8 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
for op, dtypes in unary_op_types.items()))
|
||||
def testUnaryWeakTypes(self, op_name, rec_dtypes):
|
||||
"""Test that all lax unary ops propagate weak_type information appropriately."""
|
||||
if op_name == "bitwise_not":
|
||||
raise unittest.SkipTest("https://github.com/google/jax/issues/12066")
|
||||
# Find a valid dtype for the function.
|
||||
for dtype in [np.float_, np.int_, np.complex_, np.bool_]:
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
|
@ -351,7 +351,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
tol=3e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@ -397,7 +397,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
return list(map(rng, shapes, dtypes))
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
tol=2e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user