Merge pull request #10872 from fehiepsi:xmap-docs

PiperOrigin-RevId: 452803943
This commit is contained in:
jax authors 2022-06-03 11:14:25 -07:00
commit 2765293746
3 changed files with 19 additions and 19 deletions

View File

@ -169,13 +169,13 @@
"source": [
"import jax\n",
"import numpy as np\n",
"from jax.experimental.maps import mesh\n",
"from jax.experimental.maps import Mesh\n",
"\n",
"loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],\n",
" axis_resources={'batch': 'x'})\n",
"\n",
"devices = np.array(jax.local_devices())\n",
"with mesh(devices, ('x',)):\n",
"with Mesh(devices, ('x',)):\n",
" print(loss(w1, w2, images, labels))"
]
},
@ -200,7 +200,7 @@
" axis_resources={'hidden': 'x'})\n",
"\n",
"devices = np.array(jax.local_devices())\n",
"with mesh(devices, ('x',)):\n",
"with Mesh(devices, ('x',)):\n",
" print(loss(w1, w2, images, labels))"
]
},
@ -225,7 +225,7 @@
" axis_resources={'batch': 'x', 'hidden': 'y'})\n",
"\n",
"devices = np.array(jax.local_devices()).reshape((4, 2))\n",
"with mesh(devices, ('x', 'y')):\n",
"with Mesh(devices, ('x', 'y')):\n",
" print(loss(w1, w2, images, labels))"
]
},
@ -779,7 +779,7 @@
"id": "KHbRwYl0BOr1"
},
"source": [
"To introduce the resources in a scope, use the `with mesh` context manager:"
"To introduce the resources in a scope, use the `with Mesh` context manager:"
]
},
{
@ -790,10 +790,10 @@
},
"outputs": [],
"source": [
"from jax.experimental.maps import mesh\n",
"from jax.experimental.maps import Mesh\n",
"\n",
"local = local_matmul(x, x) # The local function doesn't require the mesh definition\n",
"with mesh(*mesh_def): # Makes the mesh axis names available as resources\n",
"with Mesh(*mesh_def): # Makes the mesh axis names available as resources\n",
" distr = distr_matmul(x, x)\n",
"np.testing.assert_allclose(local, distr)"
]
@ -859,7 +859,7 @@
"\n",
"q = jnp.ones((4,), dtype=np.float32)\n",
"u = jnp.ones((12,), dtype=np.float32)\n",
"with mesh(np.array(jax.devices()[:4]), ('x',)):\n",
"with Mesh(np.array(jax.devices()[:4]), ('x',)):\n",
" v = xmap(sum_two_args,\n",
" in_axes=(['a', ...], ['b', ...]),\n",
" out_axes=[...],\n",

View File

@ -120,13 +120,13 @@ But on a whim we can decide to parallelize over the batch axis:
import jax
import numpy as np
from jax.experimental.maps import mesh
from jax.experimental.maps import Mesh
loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
axis_resources={'batch': 'x'})
devices = np.array(jax.local_devices())
with mesh(devices, ('x',)):
with Mesh(devices, ('x',)):
print(loss(w1, w2, images, labels))
```
@ -141,7 +141,7 @@ loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
axis_resources={'hidden': 'x'})
devices = np.array(jax.local_devices())
with mesh(devices, ('x',)):
with Mesh(devices, ('x',)):
print(loss(w1, w2, images, labels))
```
@ -156,7 +156,7 @@ loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
axis_resources={'batch': 'x', 'hidden': 'y'})
devices = np.array(jax.local_devices()).reshape((4, 2))
with mesh(devices, ('x', 'y')):
with Mesh(devices, ('x', 'y')):
print(loss(w1, w2, images, labels))
```
@ -531,15 +531,15 @@ except Exception as e:
+++ {"id": "KHbRwYl0BOr1"}
To introduce the resources in a scope, use the `with mesh` context manager:
To introduce the resources in a scope, use the `with Mesh` context manager:
```{code-cell} ipython3
:id: kYdoeaSS9m9f
from jax.experimental.maps import mesh
from jax.experimental.maps import Mesh
local = local_matmul(x, x) # The local function doesn't require the mesh definition
with mesh(*mesh_def): # Makes the mesh axis names available as resources
with Mesh(*mesh_def): # Makes the mesh axis names available as resources
distr = distr_matmul(x, x)
np.testing.assert_allclose(local, distr)
```
@ -580,7 +580,7 @@ def sum_two_args(x: f32[(), {'a': 4}], y: f32[(), {'b': 12}]) -> f32[()]:
q = jnp.ones((4,), dtype=np.float32)
u = jnp.ones((12,), dtype=np.float32)
with mesh(np.array(jax.devices()[:4]), ('x',)):
with Mesh(np.array(jax.devices()[:4]), ('x',)):
v = xmap(sum_two_args,
in_axes=(['a', ...], ['b', ...]),
out_axes=[...],

View File

@ -154,7 +154,7 @@ class SerialLoop:
def serial_loop(name: ResourceAxisName, length: int):
"""Define a serial loop resource to be available in scope of this context manager.
This is similar to :py:func:`mesh` in that it extends the resource
This is similar to :py:class:`Mesh` in that it extends the resource
environment with a resource called ``name``. But, any use of this resource
axis in ``axis_resources`` argument of :py:func:`xmap` will cause the
body of :py:func:`xmap` to get executed ``length`` times with each execution
@ -330,7 +330,7 @@ def xmap(fun: Callable,
:py:func:`vmap`. However, this behavior can be further customized by the
``axis_resources`` argument. When specified, each axis introduced by
:py:func:`xmap` can be assigned to one or more *resource axes*. Those include
the axes of the hardware mesh, as defined by the :py:func:`mesh` context
the axes of the hardware mesh, as defined by the :py:class:`Mesh` context
manager. Each value that has a named axis in its ``named_shape`` will be
partitioned over all mesh axes that axis is assigned to. Hence,
:py:func:`xmap` can be seen as an alternative to :py:func:`pmap` that also
@ -423,7 +423,7 @@ def xmap(fun: Callable,
to implement a distributed matrix-multiplication in just a few lines of code::
devices = np.array(jax.devices())[:4].reshape((2, 2))
with mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y'
with Mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y'
distributed_out = xmap(
jnp.vdot,
in_axes=({0: 'left'}, {1: 'right'}),