mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
add Softmax layer to stax (closes #182)
This commit is contained in:
parent
8ada14e96b
commit
ecaae6bdd0
@@ -146,6 +146,7 @@ def _elemwise_no_params(fun, **kwargs):
   return init_fun, apply_fun

 Tanh = _elemwise_no_params(np.tanh)
 Relu = _elemwise_no_params(relu)
 Exp = _elemwise_no_params(np.exp)
 LogSoftmax = _elemwise_no_params(logsoftmax, axis=-1)
 Softplus = _elemwise_no_params(softplus)
@@ -314,3 +315,8 @@ def shape_dependent(make_layer):
   def apply_fun(params, inputs, rng=None):
     return make_layer(inputs.shape)[1](params, inputs, rng)
   return init_fun, apply_fun


 # Simple compositions

 Softmax = serial(LogSoftmax, Exp)
@@ -159,6 +159,17 @@ class StaxTest(jtu.JaxTestCase):
     init_fun, apply_fun = stax.FanInConcat(axis)
     _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

   def testIsuse182(self):
     init_fun, apply_fun = stax.Softmax
     input_shape = (10, 3)
     inputs = onp.arange(30.).astype("float32").reshape(input_shape)

     out_shape, params = init_fun(input_shape)
     out = apply_fun(params, inputs)

     assert out_shape == out.shape
     assert onp.allclose(onp.sum(onp.asarray(out), -1), 1.)


 if __name__ == "__main__":
   absltest.main()
Loading…
x
Reference in New Issue
Block a user