mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
- sign missing from loss function definition
This commit is contained in:
parent
1603a7393a
commit
1090e89a86
@ -101,7 +101,7 @@ if __name__ == "__main__":
|
||||
def loss(params, batch):
|
||||
inputs, targets = batch
|
||||
logits = predict_fun(params, inputs)
|
||||
return np.sum(logits * targets)
|
||||
return -np.sum(logits * targets)
|
||||
|
||||
def accuracy(params, batch):
|
||||
inputs, targets = batch
|
||||
|
Loading…
x
Reference in New Issue
Block a user