mirror of https://github.com/ROCm/jax.git
fixed cross-entropy losses in mnist examples (fixes #1023)
commit f8c5d98653
parent 7cbd58b6c6
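
The change is the same in each hunk below: the old losses averaged `preds * targets` over the whole batch × classes array, which silently scales the cross entropy (and its gradients) down by the number of classes; the fix sums over the class axis first and then averages over the batch. (In the first hunk the loss was already summing over axis 1; the positional `1` is just made explicit as `axis=1`.) A minimal NumPy sketch of the difference follows; the shapes and variable names are illustrative, not taken from the example files:

```python
import numpy as np

# Illustrative shapes, assuming one-hot targets and log-probability predictions.
batch_size, num_classes = 4, 10
rng = np.random.RandomState(0)
logits = rng.randn(batch_size, num_classes)
log_probs = logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))
targets = np.eye(num_classes)[rng.randint(num_classes, size=batch_size)]  # one-hot

buggy = -np.mean(log_probs * targets)                  # averages over batch * classes
fixed = -np.mean(np.sum(log_probs * targets, axis=1))  # mean per-example cross entropy

# The buggy form underestimates the loss by exactly a factor of num_classes.
print(buggy * num_classes, fixed)  # equal up to floating-point error
```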
@@ -127,7 +127,7 @@ def loss(params, batch):
   inputs, targets = batch
   logits = predict(params, inputs)
   logits = stax.logsoftmax(logits)  # log normalize
-  return -np.mean(np.sum(logits * targets, 1))  # cross entropy loss
+  return -np.mean(np.sum(logits * targets, axis=1))  # cross entropy loss
 
 
 def accuracy(params, batch):
@@ -40,7 +40,7 @@ from examples import datasets
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(preds * targets)
+  return -np.mean(np.sum(preds * targets, axis=1))
 
 def accuracy(params, batch):
   inputs, targets = batch
@@ -49,7 +49,7 @@ def predict(params, inputs):
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(preds * targets)
+  return -np.mean(np.sum(preds * targets, axis=1))
 
 def accuracy(params, batch):
   inputs, targets = batch
@@ -57,7 +57,7 @@ def predict(params, inputs):
 def loss(params, batch):
   inputs, targets = batch
   preds = predict(params, inputs)
-  return -np.mean(preds * targets)
+  return -np.mean(np.sum(preds * targets, axis=1))
 
 @jit
 def accuracy(params, batch):
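
Beyond the reported loss value, the missing sum also scales the gradients, which changes the effective SGD step size. A hedged sketch of that consequence under `jax.grad`; the toy `predict` and the shapes below are stand-ins of mine, not code from these files:

```python
import jax.numpy as jnp
from jax import grad

def predict(params, inputs):
  return jnp.dot(inputs, params)  # linear stand-in for the real network

def loss_buggy(params, batch):
  inputs, targets = batch
  return -jnp.mean(predict(params, inputs) * targets)

def loss_fixed(params, batch):
  inputs, targets = batch
  return -jnp.mean(jnp.sum(predict(params, inputs) * targets, axis=1))

num_classes = 10
params = jnp.zeros((784, num_classes))
batch = (jnp.ones((4, 784)), jnp.eye(num_classes)[jnp.array([0, 1, 2, 3])])

g_buggy = grad(loss_buggy)(params, batch)
g_fixed = grad(loss_fixed)(params, batch)
# The buggy gradients are num_classes times smaller, shrinking the effective
# learning rate; the fix restores the intended scale.
print(jnp.allclose(g_buggy * num_classes, g_fixed))  # True
```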