add Softmax layer to stax (closes #182)

This commit is contained in:
Matthew Johnson 2019-01-05 10:06:31 -08:00
parent 8ada14e96b
commit ecaae6bdd0
2 changed files with 17 additions and 0 deletions

View File

@ -146,6 +146,7 @@ def _elemwise_no_params(fun, **kwargs):
return init_fun, apply_fun
# Parameter-free elementwise activation layers. Each is an
# (init_fun, apply_fun) pair produced by `_elemwise_no_params`, which maps
# the given function over the inputs and introduces no trainable parameters.
Tanh = _elemwise_no_params(np.tanh)
Relu = _elemwise_no_params(relu)
Exp = _elemwise_no_params(np.exp)
# axis=-1: normalize over the trailing (feature) axis.
LogSoftmax = _elemwise_no_params(logsoftmax, axis=-1)
Softplus = _elemwise_no_params(softplus)
@ -314,3 +315,8 @@ def shape_dependent(make_layer):
def apply_fun(params, inputs, rng=None):
return make_layer(inputs.shape)[1](params, inputs, rng)
return init_fun, apply_fun
# Simple compositions
# Softmax expressed as exp(log_softmax(x)) by chaining the two existing
# layers, rather than as a new primitive layer. NOTE(review): presumably
# this reuses logsoftmax's max-subtraction for numerical stability with
# large logits — confirm against the logsoftmax implementation.
Softmax = serial(LogSoftmax, Exp)

View File

@ -159,6 +159,17 @@ class StaxTest(jtu.JaxTestCase):
init_fun, apply_fun = stax.FanInConcat(axis)
_CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)
def testIsuse182(self):
  """Regression test for issue #182: stax.Softmax rows must sum to 1.

  Builds the Softmax layer, applies it to a small float32 batch, and
  checks that the reported output shape matches the actual output and
  that each row along the last axis is a probability distribution.

  NOTE(review): method name has a typo ("Isuse" for "Issue"); kept
  unchanged so test-selection by name keeps working.
  """
  init_fun, apply_fun = stax.Softmax
  input_shape = (10, 3)
  inputs = onp.arange(30.).astype("float32").reshape(input_shape)

  out_shape, params = init_fun(input_shape)
  out = apply_fun(params, inputs)

  # Use unittest assertions instead of bare `assert`: bare asserts are
  # stripped under `python -O` and produce no diagnostic on failure.
  self.assertEqual(out_shape, out.shape)
  # Each row of a softmax output sums to 1 (up to float32 tolerance).
  self.assertTrue(onp.allclose(onp.sum(onp.asarray(out), -1), 1.))
# Standard absl test entry point: run the test suite when this file is
# executed directly (no effect when imported).
if __name__ == "__main__":
  absltest.main()