mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #5236 from bileschi:patch-1
PiperOrigin-RevId: 350975202
This commit is contained in:
commit
783e1b7c89
@ -220,10 +220,11 @@ function:
|
||||
```python
|
||||
def predict(params, input_vec):
|
||||
assert input_vec.ndim == 1
|
||||
activations = inputs
|
||||
for W, b in params:
|
||||
output_vec = jnp.dot(W, input_vec) + b # `input_vec` on the right-hand side!
|
||||
input_vec = jnp.tanh(output_vec)
|
||||
return output_vec
|
||||
outputs = jnp.dot(W, activations) + b # `input_vec` on the right-hand side!
|
||||
activations = jnp.tanh(outputs)
|
||||
return outputs
|
||||
```
|
||||
|
||||
We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the
|
||||
|
Loading…
x
Reference in New Issue
Block a user