1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

Fix shard_map docs build

PiperOrigin-RevId: 746033054
This commit is contained in:
Yash Katariya 2025-04-10 08:03:56 -07:00 committed by jax authors
parent 95f1207fbf
commit 160bbe12d3
2 changed files with 10 additions and 8 deletions

@ -677,11 +677,11 @@
"metadata": {},
"outputs": [],
"source": [
"@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=None)\n",
"@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None)\n",
"def f(x):\n",
" print(jax.typeof(x)) # f32[3]\n",
" print(jax.typeof(x)) # f32[6]\n",
" y = jax.lax.pvary(x, 'i')\n",
" print(jax.typeof(y)) # f32[3]{i}\n",
" print(jax.typeof(y)) # f32[6]{i}\n",
"\n",
"x = jnp.arange(6.)\n",
"f(x)"
@ -715,7 +715,8 @@
" return x * y\n",
"\n",
"x = jnp.arange(6.)\n",
"print(jax.make_jaxpr(f)(x))"
"y = jnp.arange(3.)\n",
"print(jax.make_jaxpr(f)(x, y))"
]
},
{

@ -445,11 +445,11 @@ Sometimes we want to treat a value that is unvarying over a mesh axis as
varying over that mesh axis. That's what `jax.lax.pvary` does:
```{code-cell}
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=None)
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None)
def f(x):
print(jax.typeof(x)) # f32[3]
print(jax.typeof(x)) # f32[6]
y = jax.lax.pvary(x, 'i')
print(jax.typeof(y)) # f32[3]{i}
print(jax.typeof(y)) # f32[6]{i}
x = jnp.arange(6.)
f(x)
@ -471,7 +471,8 @@ def f(x, y):
return x * y
x = jnp.arange(6.)
print(jax.make_jaxpr(f)(x))
y = jnp.arange(3.)
print(jax.make_jaxpr(f)(x, y))
```
In a jaxpr, the multiplication operation requires the VMA types of its