Fix formatting in the docs for transposing pytrees

This commit is contained in:
mikcl 2024-12-26 01:12:05 +00:00
parent 42a0d55503
commit 008c25a369

View File

@ -490,7 +490,7 @@ This section covers some of the most common patterns with JAX pytrees.
### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
**Option 1:** Use {func}`jax.tree.map`. Here's an example: