mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Update stateful-computations.md
tree_map -> tree.map
This commit is contained in:
parent
b9fb69d1fc
commit
6a718b762f
@ -195,7 +195,7 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
|
||||
# and then use `updates` instead of `grad` to actually update the params.
|
||||
# (And we'd include `new_optimizer_state` in the output, naturally.)
|
||||
|
||||
new_params = jax.tree_map(
|
||||
new_params = jax.tree.map(
|
||||
lambda param, g: param - g * LEARNING_RATE, params, grad)
|
||||
|
||||
return new_params
|
||||
|
Loading…
x
Reference in New Issue
Block a user