Merge pull request #5236 from bileschi:patch-1

PiperOrigin-RevId: 350975202
This commit is contained in:
jax authors 2021-01-09 20:43:00 -08:00
commit 783e1b7c89

View File

@ -220,10 +220,11 @@ function:
```python
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = inputs
for W, b in params:
output_vec = jnp.dot(W, input_vec) + b # `input_vec` on the right-hand side!
input_vec = jnp.tanh(output_vec)
return output_vec
outputs = jnp.dot(W, activations) + b # `input_vec` on the right-hand side!
activations = jnp.tanh(outputs)
return outputs
```
We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the