mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
fixing. Renaming per reviewer feedback.
This commit is contained in:
parent
80173aad62
commit
c1904ea7a7
@ -217,10 +217,11 @@ function:
|
||||
```python
|
||||
def predict(params, input_vec):
|
||||
assert input_vec.ndim == 1
|
||||
activations = inputs
|
||||
for W, b in params:
|
||||
output_vec = jnp.dot(W, input_vec) + b # `input_vec` on the right-hand side!
|
||||
output_vec = jnp.tanh(output_vec)
|
||||
return output_vec
|
||||
outputs = jnp.dot(W, activations) + b # `input_vec` on the right-hand side!
|
||||
activations = jnp.tanh(outputs)
|
||||
return outputs
|
||||
```
|
||||
|
||||
We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the
|
||||
|
Loading…
x
Reference in New Issue
Block a user