mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[shard_map docs]: Fix doc typos
PiperOrigin-RevId: 664960613
This commit is contained in:
parent
f1974b6471
commit
3e764f617a
@ -531,7 +531,7 @@
|
||||
"\n",
|
||||
"```python\n",
|
||||
"def f_shmapped_ref(x):\n",
|
||||
" x_blocks = jnp.array_split(x, mesh.shape[0])\n",
|
||||
" x_blocks = jnp.array_split(x, mesh.shape['i'])\n",
|
||||
" y_blocks = [f(x_blk) for x_blk in x_blocks]\n",
|
||||
" return jnp.concatenate(y_blocks)\n",
|
||||
"```\n",
|
||||
|
@ -378,7 +378,7 @@ values, as this reference function:
|
||||
|
||||
```python
|
||||
def f_shmapped_ref(x):
|
||||
x_blocks = jnp.array_split(x, mesh.shape[0])
|
||||
x_blocks = jnp.array_split(x, mesh.shape['i'])
|
||||
y_blocks = [f(x_blk) for x_blk in x_blocks]
|
||||
return jnp.concatenate(y_blocks)
|
||||
```
|
||||
|
Loading…
x
Reference in New Issue
Block a user