diff --git a/jax/experimental/stax.py b/jax/experimental/stax.py
index e9988f128..ffc2dec4d 100644
--- a/jax/experimental/stax.py
+++ b/jax/experimental/stax.py
@@ -146,6 +146,7 @@ def _elemwise_no_params(fun, **kwargs):
   return init_fun, apply_fun
 Tanh = _elemwise_no_params(np.tanh)
 Relu = _elemwise_no_params(relu)
+Exp = _elemwise_no_params(np.exp)
 LogSoftmax = _elemwise_no_params(logsoftmax, axis=-1)
 Softplus = _elemwise_no_params(softplus)
 
@@ -314,3 +315,8 @@ def shape_dependent(make_layer):
   def apply_fun(params, inputs, rng=None):
     return make_layer(inputs.shape)[1](params, inputs, rng)
   return init_fun, apply_fun
+
+
+# Simple compositions
+
+Softmax = serial(LogSoftmax, Exp)
diff --git a/tests/stax_test.py b/tests/stax_test.py
index cb3502502..245c9b24e 100644
--- a/tests/stax_test.py
+++ b/tests/stax_test.py
@@ -159,6 +159,19 @@ class StaxTest(jtu.JaxTestCase):
     init_fun, apply_fun = stax.FanInConcat(axis)
     _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)
 
+  def testIssue182(self):
+    # stax.Softmax is serial(LogSoftmax, Exp): check that the shape reported
+    # by init_fun matches the real output and that the last axis sums to one.
+    init_fun, apply_fun = stax.Softmax
+    input_shape = (10, 3)
+    inputs = onp.arange(30.).astype("float32").reshape(input_shape)
+
+    out_shape, params = init_fun(input_shape)
+    out = apply_fun(params, inputs)
+
+    self.assertEqual(out_shape, out.shape)
+    self.assertTrue(onp.allclose(onp.sum(onp.asarray(out), -1), 1.))
+
 
 if __name__ == "__main__":
   absltest.main()