mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Fix shard_map docs build
PiperOrigin-RevId: 746033054
This commit is contained in:
parent
95f1207fbf
commit
160bbe12d3
docs/notebooks
@ -677,11 +677,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=None)\n",
|
||||
"@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None)\n",
|
||||
"def f(x):\n",
|
||||
" print(jax.typeof(x)) # f32[3]\n",
|
||||
" print(jax.typeof(x)) # f32[6]\n",
|
||||
" y = jax.lax.pvary(x, 'i')\n",
|
||||
" print(jax.typeof(y)) # f32[3]{i}\n",
|
||||
" print(jax.typeof(y)) # f32[6]{i}\n",
|
||||
"\n",
|
||||
"x = jnp.arange(6.)\n",
|
||||
"f(x)"
|
||||
@ -715,7 +715,8 @@
|
||||
" return x * y\n",
|
||||
"\n",
|
||||
"x = jnp.arange(6.)\n",
|
||||
"print(jax.make_jaxpr(f)(x))"
|
||||
"y = jnp.arange(3.)\n",
|
||||
"print(jax.make_jaxpr(f)(x, y))"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -445,11 +445,11 @@ Sometimes we want to treat a value that is unvarying over a mesh axis as
|
||||
varying over that mesh axis. That's what `jax.lax.pvary` does:
|
||||
|
||||
```{code-cell}
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=None)
|
||||
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=None)
|
||||
def f(x):
|
||||
print(jax.typeof(x)) # f32[3]
|
||||
print(jax.typeof(x)) # f32[6]
|
||||
y = jax.lax.pvary(x, 'i')
|
||||
print(jax.typeof(y)) # f32[3]{i}
|
||||
print(jax.typeof(y)) # f32[6]{i}
|
||||
|
||||
x = jnp.arange(6.)
|
||||
f(x)
|
||||
@ -471,7 +471,8 @@ def f(x, y):
|
||||
return x * y
|
||||
|
||||
x = jnp.arange(6.)
|
||||
print(jax.make_jaxpr(f)(x))
|
||||
y = jnp.arange(3.)
|
||||
print(jax.make_jaxpr(f)(x, y))
|
||||
```
|
||||
|
||||
In a jaxpr, the multiplication operation requires the VMA types of its
|
||||
|
Loading…
x
Reference in New Issue
Block a user