Update stateful-computations.md

tree_map -> tree.map
This commit is contained in:
Gary Miguel 2025-03-09 21:35:46 -07:00 committed by GitHub
parent b9fb69d1fc
commit 6a718b762f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -195,7 +195,7 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)
new_params = jax.tree_map(
new_params = jax.tree.map(
lambda param, g: param - g * LEARNING_RATE, params, grad)
return new_params