From 6a718b762f8ae59ee81c3793273d2689fe0b553c Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Sun, 9 Mar 2025 21:35:46 -0700 Subject: [PATCH] Update stateful-computations.md tree_map -> tree.map --- docs/stateful-computations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index fe84fc0d7..30c626bec 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -195,7 +195,7 @@ def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params: # and then use `updates` instead of `grad` to actually update the params. # (And we'd include `new_optimizer_state` in the output, naturally.) - new_params = jax.tree_map( + new_params = jax.tree.map( lambda param, g: param - g * LEARNING_RATE, params, grad) return new_params