Merge pull request #10872 from fehiepsi:xmap-docs
PiperOrigin-RevId: 452803943
commit 2765293746
@@ -169,13 +169,13 @@
    "source": [
     "import jax\n",
     "import numpy as np\n",
-    "from jax.experimental.maps import mesh\n",
+    "from jax.experimental.maps import Mesh\n",
     "\n",
     "loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],\n",
     "            axis_resources={'batch': 'x'})\n",
     "\n",
     "devices = np.array(jax.local_devices())\n",
-    "with mesh(devices, ('x',)):\n",
+    "with Mesh(devices, ('x',)):\n",
     "  print(loss(w1, w2, images, labels))"
    ]
   },
@@ -200,7 +200,7 @@
     "            axis_resources={'hidden': 'x'})\n",
     "\n",
     "devices = np.array(jax.local_devices())\n",
-    "with mesh(devices, ('x',)):\n",
+    "with Mesh(devices, ('x',)):\n",
     "  print(loss(w1, w2, images, labels))"
    ]
   },
@@ -225,7 +225,7 @@
     "            axis_resources={'batch': 'x', 'hidden': 'y'})\n",
     "\n",
     "devices = np.array(jax.local_devices()).reshape((4, 2))\n",
-    "with mesh(devices, ('x', 'y')):\n",
+    "with Mesh(devices, ('x', 'y')):\n",
     "  print(loss(w1, w2, images, labels))"
    ]
   },
@@ -779,7 +779,7 @@
     "id": "KHbRwYl0BOr1"
    },
    "source": [
-    "To introduce the resources in a scope, use the `with mesh` context manager:"
+    "To introduce the resources in a scope, use the `with Mesh` context manager:"
    ]
   },
   {
@@ -790,10 +790,10 @@
   },
   "outputs": [],
   "source": [
-    "from jax.experimental.maps import mesh\n",
+    "from jax.experimental.maps import Mesh\n",
     "\n",
     "local = local_matmul(x, x) # The local function doesn't require the mesh definition\n",
-    "with mesh(*mesh_def): # Makes the mesh axis names available as resources\n",
+    "with Mesh(*mesh_def): # Makes the mesh axis names available as resources\n",
     "  distr = distr_matmul(x, x)\n",
     "np.testing.assert_allclose(local, distr)"
    ]
@@ -859,7 +859,7 @@
     "\n",
     "q = jnp.ones((4,), dtype=np.float32)\n",
     "u = jnp.ones((12,), dtype=np.float32)\n",
-    "with mesh(np.array(jax.devices()[:4]), ('x',)):\n",
+    "with Mesh(np.array(jax.devices()[:4]), ('x',)):\n",
     "  v = xmap(sum_two_args,\n",
     "           in_axes=(['a', ...], ['b', ...]),\n",
     "           out_axes=[...],\n",
@@ -120,13 +120,13 @@ But on a whim we can decide to parallelize over the batch axis:
 
 import jax
 import numpy as np
-from jax.experimental.maps import mesh
+from jax.experimental.maps import Mesh
 
 loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
             axis_resources={'batch': 'x'})
 
 devices = np.array(jax.local_devices())
-with mesh(devices, ('x',)):
+with Mesh(devices, ('x',)):
   print(loss(w1, w2, images, labels))
 ```
 
@@ -141,7 +141,7 @@ loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
             axis_resources={'hidden': 'x'})
 
 devices = np.array(jax.local_devices())
-with mesh(devices, ('x',)):
+with Mesh(devices, ('x',)):
   print(loss(w1, w2, images, labels))
 ```
 
@@ -156,7 +156,7 @@ loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
             axis_resources={'batch': 'x', 'hidden': 'y'})
 
 devices = np.array(jax.local_devices()).reshape((4, 2))
-with mesh(devices, ('x', 'y')):
+with Mesh(devices, ('x', 'y')):
   print(loss(w1, w2, images, labels))
 ```
 
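
As a point of reference for the pattern these hunks converge on, here is a minimal, self-contained sketch that uses the capitalized `Mesh` class with `xmap` and `axis_resources`. The toy `predict` function, its shapes, and the axis name `'batch'` are illustrative assumptions, not part of this change; a one-device mesh is used so the sketch stays runnable anywhere.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap

# Toy model: the leading axis of `x` gets the logical name 'batch'.
def predict(w, x):
  return jnp.dot(x, w)

w = jnp.ones((3, 2))
x = jnp.ones((8, 3))

f = xmap(predict,
         in_axes=([...], ['batch', ...]),  # w: no named axes; x: axis 0 is 'batch'
         out_axes=['batch', ...],
         axis_resources={'batch': 'x'})    # partition 'batch' over the mesh axis 'x'

# Mesh is a class: it takes a device array and a tuple of mesh-axis names.
devices = np.array(jax.local_devices()[:1])
with Mesh(devices, ('x',)):
  print(f(w, x))  # shape (8, 2)
```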
@@ -531,15 +531,15 @@ except Exception as e:
 
 +++ {"id": "KHbRwYl0BOr1"}
 
-To introduce the resources in a scope, use the `with mesh` context manager:
+To introduce the resources in a scope, use the `with Mesh` context manager:
 
 ```{code-cell} ipython3
 :id: kYdoeaSS9m9f
 
-from jax.experimental.maps import mesh
+from jax.experimental.maps import Mesh
 
 local = local_matmul(x, x) # The local function doesn't require the mesh definition
-with mesh(*mesh_def): # Makes the mesh axis names available as resources
+with Mesh(*mesh_def): # Makes the mesh axis names available as resources
   distr = distr_matmul(x, x)
 np.testing.assert_allclose(local, distr)
 ```
@@ -580,7 +580,7 @@ def sum_two_args(x: f32[(), {'a': 4}], y: f32[(), {'b': 12}]) -> f32[()]:
 
 q = jnp.ones((4,), dtype=np.float32)
 u = jnp.ones((12,), dtype=np.float32)
-with mesh(np.array(jax.devices()[:4]), ('x',)):
+with Mesh(np.array(jax.devices()[:4]), ('x',)):
   v = xmap(sum_two_args,
            in_axes=(['a', ...], ['b', ...]),
            out_axes=[...],
@@ -154,7 +154,7 @@ class SerialLoop:
 def serial_loop(name: ResourceAxisName, length: int):
   """Define a serial loop resource to be available in scope of this context manager.
 
-  This is similar to :py:func:`mesh` in that it extends the resource
+  This is similar to :py:class:`Mesh` in that it extends the resource
   environment with a resource called ``name``. But, any use of this resource
   axis in ``axis_resources`` argument of :py:func:`xmap` will cause the
   body of :py:func:`xmap` to get executed ``length`` times with each execution
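
To make the docstring's point concrete, here is a hedged sketch of how a serial loop resource is typically paired with `xmap`. The axis name `'i'`, the loop name `'l'`, and the toy computation are assumptions for illustration; the only API names taken from this diff are `serial_loop` and `xmap` from `jax.experimental.maps`.

```python
import jax.numpy as jnp
from jax.experimental.maps import serial_loop, xmap

x = jnp.arange(16.)

# 'l' is a resource of length 4: the named axis 'i' (size 16) assigned to it
# is evaluated in 4 sequential chunks rather than all at once.
with serial_loop('l', 4):
  y = xmap(lambda v: jnp.sin(v) * 5,
           in_axes=['i', ...], out_axes=['i', ...],
           axis_resources={'i': 'l'})(x)
```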
@@ -330,7 +330,7 @@ def xmap(fun: Callable,
   :py:func:`vmap`. However, this behavior can be further customized by the
   ``axis_resources`` argument. When specified, each axis introduced by
   :py:func:`xmap` can be assigned to one or more *resource axes*. Those include
-  the axes of the hardware mesh, as defined by the :py:func:`mesh` context
+  the axes of the hardware mesh, as defined by the :py:class:`Mesh` context
   manager. Each value that has a named axis in its ``named_shape`` will be
   partitioned over all mesh axes that axis is assigned to. Hence,
   :py:func:`xmap` can be seen as an alternative to :py:func:`pmap` that also
@@ -423,7 +423,7 @@ def xmap(fun: Callable,
   to implement a distributed matrix-multiplication in just a few lines of code::
 
     devices = np.array(jax.devices())[:4].reshape((2, 2))
-    with mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y'
+    with Mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y'
      distributed_out = xmap(
        jnp.vdot,
        in_axes=({0: 'left'}, {1: 'right'}),
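
The final hunk is cut off mid-example. Purely as a hedged illustration of how this kind of distributed matrix multiply usually finishes (the `out_axes`, `axis_resources`, and the operand `x` below are assumptions, not text from this commit), and assuming at least four devices are available:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap

x = jnp.ones((4, 4), dtype=np.float32)

devices = np.array(jax.devices())[:4].reshape((2, 2))
with Mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
  distributed_out = xmap(
      jnp.vdot,
      in_axes=({0: 'left'}, {1: 'right'}),   # name axis 0 of x 'left', axis 1 of x.T 'right'
      out_axes=['left', 'right', ...],       # the result keeps both named axes
      axis_resources={'left': 'x', 'right': 'y'})(x, x.T)

# The sharded result matches an ordinary matmul computed on one device.
np.testing.assert_allclose(distributed_out, jnp.matmul(x, x.T))
```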