fixed cross-entropy losses in mnist examples (fixes #1023)

Jonas Rauber 2019-10-27 09:43:24 +01:00
parent 7cbd58b6c6
commit f8c5d98653
4 changed files with 4 additions and 4 deletions
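
In three of the four examples the loss was computed as -np.mean(preds * targets), which averages over the class dimension as well as the batch dimension and therefore divides the per-example cross entropy by the number of classes. The fix sums over the class axis first and only then averages over the batch, giving the usual mean cross entropy; the first file already summed over the class axis and only spells out the axis=1 keyword. The rescaling does not move the minimum of the loss, but it does scale the gradients and hence the effective step size of the optimizer.

A minimal sketch of the difference, using plain NumPy instead of jax.numpy; batch_size, num_classes, log_probs, and targets are illustrative names and shapes, not taken from the changed files:

import numpy as np

batch_size, num_classes = 4, 10
rng = np.random.RandomState(0)

logits = rng.randn(batch_size, num_classes)
# log-softmax, so each row holds log-probabilities
log_probs = logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))
# one-hot labels of shape (batch_size, num_classes)
targets = np.eye(num_classes)[rng.randint(num_classes, size=batch_size)]

old_loss = -np.mean(log_probs * targets)                  # divides by batch_size * num_classes
new_loss = -np.mean(np.sum(log_probs * targets, axis=1))  # divides by batch_size only

# the old expression is exactly num_classes times too small
assert np.isclose(new_loss, old_loss * num_classes)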

@@ -127,7 +127,7 @@ def loss(params, batch):
   inputs, targets = batch
   logits = predict(params, inputs)
   logits = stax.logsoftmax(logits)  # log normalize
-  return -np.mean(np.sum(logits * targets, 1))  # cross entropy loss
+  return -np.mean(np.sum(logits * targets, axis=1))  # cross entropy loss
 def accuracy(params, batch):

@@ -40,7 +40,7 @@ from examples import datasets
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(preds * targets)
+  return -np.mean(np.sum(preds * targets, axis=1))
 def accuracy(params, batch):
   inputs, targets = batch

@@ -49,7 +49,7 @@ def predict(params, inputs):
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(preds * targets)
+  return -np.mean(np.sum(preds * targets, axis=1))
 def accuracy(params, batch):
   inputs, targets = batch

@@ -57,7 +57,7 @@ def predict(params, inputs):
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(preds * targets)
+  return -np.mean(np.sum(preds * targets, axis=1))
 @jit
 def accuracy(params, batch):