[shard_map docs]: Fix doc typos

PiperOrigin-RevId: 664960613
This commit is contained in:
jax authors 2024-08-19 13:45:06 -07:00 committed by jax authors
parent f1974b6471
commit 3e764f617a
2 changed files with 2 additions and 2 deletions

View File

@ -531,7 +531,7 @@
"\n",
"```python\n",
"def f_shmapped_ref(x):\n",
" x_blocks = jnp.array_split(x, mesh.shape[0])\n",
" x_blocks = jnp.array_split(x, mesh.shape['i'])\n",
" y_blocks = [f(x_blk) for x_blk in x_blocks]\n",
" return jnp.concatenate(y_blocks)\n",
"```\n",

View File

@ -378,7 +378,7 @@ values, as this reference function:
```python
def f_shmapped_ref(x):
x_blocks = jnp.array_split(x, mesh.shape[0])
x_blocks = jnp.array_split(x, mesh.shape['i'])
y_blocks = [f(x_blk) for x_blk in x_blocks]
return jnp.concatenate(y_blocks)
```