- sign missing from loss function definition

This commit is contained in:
Karthik Kumara 2019-10-17 16:00:35 -07:00
parent 1603a7393a
commit 1090e89a86

View File

@ -101,7 +101,7 @@ if __name__ == "__main__":
def loss(params, batch):
inputs, targets = batch
logits = predict_fun(params, inputs)
return np.sum(logits * targets)
return -np.sum(logits * targets)
def accuracy(params, batch):
inputs, targets = batch