Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Commit c6eb632f57
@@ -6,7 +6,7 @@ It provides composable transformations of Python+NumPy programs: differentiate,
parallelize, Just-In-Time compile to GPU/TPU, and more.

.. note::
-   JAX 0.4.0 introduces new parallelism APIs, including breaking changes to :func:`jax.experimental.pjit` and a new unified ``jax.Array`` type.
+   JAX 0.4.1 introduces new parallelism APIs, including breaking changes to :func:`jax.experimental.pjit` and a new unified ``jax.Array`` type.
   Please see `Distributed arrays and automatic parallelization <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_ tutorial and the :ref:`jax-array-migration`
   guide for more information.
File diff suppressed because it is too large
@@ -17,13 +17,15 @@ kernelspec:

+++ {"id": "pFtQjv4SzHRj"}

-**This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.0 and newer.**
+**This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.**

-See [`jax-array-migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide for migrating existing pre-v0.4.0 codebases to `jax.Array`.
+See [`jax-array-migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide for migrating existing pre-v0.4.1 codebases to `jax.Array`.

**The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Cloud TPU.**

```{code-cell}
+:id: FNxScTfq3vGF

import os

import functools

@@ -33,8 +35,6 @@ import numpy as np

import jax
import jax.numpy as jnp
-
-jax.config.update('jax_array', True)
```

+++ {"id": "eyHMwyEfQJcz"}

@@ -42,6 +42,8 @@ jax.config.update('jax_array', True)
⚠️ WARNING: notebook requires 8 devices to run.

```{code-cell}
+:id: IZMLqOUV3vGG

if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")
```
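
If you don't have 8 accelerators handy, one common way to experiment locally is to ask XLA to expose several host (CPU) devices. This is a minimal sketch under that assumption, not something the notebook itself sets up; the `XLA_FLAGS` environment variable must be set before JAX is imported.

```python
# Minimal sketch: simulate 8 devices on CPU by setting this XLA flag *before*
# importing jax. Illustrative setup for local experimentation only.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.local_devices())  # should now list 8 CPU devices
```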
@@ -59,16 +61,23 @@ Before we think step by step, here's a quick example.
First, we'll create a `jax.Array` sharded across multiple devices:

```{code-cell}
+:id: Gf2lO4ii3vGG

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
```

```{code-cell}
+:id: q-XBTEoy3vGG

# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
```

```{code-cell}
+:id: vI39znW93vGH
+:outputId: 3b518df8-5c29-4848-acc3-e41df939f30b

# Create an array of random values:
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:

@@ -82,6 +91,9 @@ Next, we'll apply a computation to it and visualize how the result values are
stored across multiple devices too:

```{code-cell}
+:id: -qCnHZl83vGI
+:outputId: 9da9c29e-ce88-4425-e1ec-e93e5bcf3106

z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
```
@@ -92,11 +104,17 @@ The evaluation of the `jnp.sin` application was automatically parallelized
across the devices on which the input values (and output values) are stored:

```{code-cell}
+:id: _VTzN0r03vGI
+:outputId: c9208010-984b-442b-d105-c8c6a3a010e6

# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
```

```{code-cell}
+:id: QuzhU1g63vGI
+:outputId: d48fc76e-79a7-47b9-d392-b18a1c33c798

# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
```
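
Timing aside, the placement of a result can also be checked directly on the array itself. A minimal sketch, assuming the 8-way-sharded `y` from the quick example above (the variable names here are just for illustration):

```python
# Minimal sketch (assumes the sharded `y` created above): the result of an
# elementwise op keeps the input's sharding, spread over the same devices.
z = jnp.sin(y)
print(z.sharding == y.sharding)     # output sharding matches the input
print(len(z.sharding.device_set))   # number of devices holding shards: 8
```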
@@ -121,11 +139,16 @@ In JAX, `Sharding` objects describe distributed memory layouts. They can be used
For example, here's a value with a single-device `Sharding`:

```{code-cell}
+:id: VmoX4SUp3vGJ

import jax
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
```

```{code-cell}
+:id: vNRabO2J3vGJ
+:outputId: 73db7b6e-c2e7-467d-a0ef-c35e29e582dd

jax.debug.visualize_array_sharding(x)
```
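
The same fact is visible programmatically: a freshly created array lives on one (default) device and carries a single-device sharding. A minimal sketch, assuming the `x` just created:

```python
# Minimal sketch (assumes the freshly created `x` above): a plain array is
# placed on the default device and has a SingleDeviceSharding.
from jax.sharding import SingleDeviceSharding

print(isinstance(x.sharding, SingleDeviceSharding))
print(x.devices())  # a set containing a single device
```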
@@ -136,6 +159,8 @@ Here, we're using the `jax.debug.visualize_array_sharding` function to show wher
But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:

```{code-cell}
+:id: VUIEIzRp3vGK

from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))
```

@@ -145,6 +170,9 @@ devices = mesh_utils.create_device_mesh((8,))
Then, we create a `PositionalSharding` and use it with `device_put`:

```{code-cell}
+:id: jwrWfZeB3vGK
+:outputId: e6f126bd-f6bd-48c7-c130-6f02757e3342

from jax.sharding import PositionalSharding

sharding = PositionalSharding(devices)

@@ -158,6 +186,9 @@ jax.debug.visualize_array_sharding(x)
Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements:

```{code-cell}
+:id: zxWB82Kz3vGK
+:outputId: 11384a6b-fabc-4c4c-bcad-a3be51eb0465

sharding
```
@@ -166,10 +197,16 @@ sharding
By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:

```{code-cell}
+:id: PLsnpSzc3vGL
+:outputId: 9f4db733-cafe-46ae-c057-dc31046a6f66

sharding.reshape(8, 1)
```

```{code-cell}
+:id: iqKdI4LO3vGL
+:outputId: 6aa10fc2-cec4-4401-a0df-343e71646e0a

sharding.reshape(4, 2)
```
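
A reshaped sharding can only be applied to an array whose shape it divides evenly; the notebook formalizes this below as `is_congruent`. An illustrative sketch of that check and of the per-device shard shape it implies (the `shard_shape` helper here is hypothetical, not from the notebook):

```python
# Illustrative sketch: a sharding shape must evenly divide the array shape;
# each device then holds one shard of shape (dim // sharding_dim, ...).
x_shape = (8192, 8192)   # the array used throughout this notebook
sharding_shape = (4, 2)  # e.g. sharding.reshape(4, 2).shape

def shard_shape(x_shape, sharding_shape):
  assert all(d % s == 0 for d, s in zip(x_shape, sharding_shape)), "not congruent"
  return tuple(d // s for d, s in zip(x_shape, sharding_shape))

print(shard_shape(x_shape, sharding_shape))  # (2048, 4096)
```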
@@ -185,11 +222,17 @@ def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:
For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`:

```{code-cell}
+:id: SELr4xNi3vGL
+:outputId: b2f4acec-0cd3-4829-ca16-cae2e0e8ca60

sharding = sharding.reshape(4, 2)
print(sharding)
```

```{code-cell}
+:id: 8IVIsqfX3vGL
+:outputId: 033d0e02-a643-4f4c-9d24-9cd8465bc69a

y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
```

@@ -201,11 +244,17 @@ Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are st
Different `PositionalSharding` shapes result in different distributed layouts (i.e. shardings) of the result:

```{code-cell}
+:id: cCjt6QCz3vGM
+:outputId: 4ad8a611-596d-424f-b6c5-fc00f1adc306

sharding = sharding.reshape(1, 8)
print(sharding)
```

```{code-cell}
+:id: yTK4Nz3u3vGM
+:outputId: e445c6bc-4fe3-4e9d-cc9e-d82858f58312

y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
```

@@ -217,11 +266,17 @@ In some cases, we don't just want to store each slice of `x` in a single device'
With `PositionalSharding`, we can express replication by calling the reducer method `replicate`:

```{code-cell}
+:id: _jr6XYKx3vGM
+:outputId: 59c8b9a4-b8af-493a-ba8d-da5931e88f93

sharding = sharding.reshape(4, 2)
print(sharding.replicate(axis=0, keepdims=True))
```

```{code-cell}
+:id: S5vzjFuH3vGN
+:outputId: b6ce2675-7261-4e57-fa8c-b4e87abf7e52

y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y)
```
@@ -233,11 +288,17 @@ Here the visualization shows that `x` is sharded two ways along its second dimen
The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed:

```{code-cell}
+:id: DR7VV-6e3vGN
+:outputId: f879fc2c-5723-4199-b306-295bc1b3681e

print(sharding.replicate(0).shape)
print(sharding.replicate(1).shape)
```

```{code-cell}
+:id: agUtVUVx3vGN
+:outputId: 0e9789ef-ce52-4ed6-8bd5-c876b95f66e6

y = jax.device_put(x, sharding.replicate(1))
jax.debug.visualize_array_sharding(y)
```
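
To see the replication concretely, each device-local shard of a replicated array can be listed together with its replica id. A minimal sketch, assuming the replicated `y` from the cell above:

```python
# Minimal sketch (assumes the replicated `y` from the cell above): several
# devices hold a copy of the same slice, distinguished only by replica_id.
for shard in y.addressable_shards:
  print(shard.device, shard.index, shard.replica_id)
```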
@@ -253,14 +314,18 @@ So far we've worked with `PositionalSharding`, but there are alternative ways to
Another convenient way to express sharding is with the `NamedSharding`:

```{code-cell}
-from jax.experimental import maps
-from jax.experimental import PartitionSpec
-from jax.experimental import mesh_utils
+:id: zpB1JxyK3vGN
+:outputId: 46d5da37-840c-49d8-8380-a162811bae8a
+
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec
+from jax.sharding import NamedSharding
+from jax.experimental import mesh_utils

P = PartitionSpec

devices = mesh_utils.create_device_mesh((4, 2))
-mesh = maps.Mesh(devices, axis_names=('a', 'b'))
+mesh = Mesh(devices, axis_names=('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
```

@@ -270,11 +335,13 @@ jax.debug.visualize_array_sharding(y)
We can define a helper function to make things simpler:

```{code-cell}
+:id: 8g0Md2Gd3vGO

devices = mesh_utils.create_device_mesh((4, 2))
-default_mesh = maps.Mesh(devices, axis_names=('a', 'b'))
+default_mesh = Mesh(devices, axis_names=('a', 'b'))

def mesh_sharding(
-    pspec: PartitionSpec, mesh: Optional[maps.Mesh] = None,
+    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
@@ -282,6 +349,9 @@ def mesh_sharding(
```

```{code-cell}
+:id: zp3MfS4Y3vGO
+:outputId: 2c2f7201-c2c1-49e5-f8a5-0730c124d89a

y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)
```
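
With `NamedSharding`, replication falls out of the `PartitionSpec`: a mesh axis that isn't mentioned is replicated over, and an empty spec replicates over the whole mesh. A minimal sketch, assuming the `mesh_sharding` helper above (the `y_replicated` name is just for illustration):

```python
# Minimal sketch (assumes `x` and the `mesh_sharding` helper defined above).
# P() mentions no mesh axes, so the value is replicated across all devices.
y_replicated = jax.device_put(x, mesh_sharding(P()))
jax.debug.visualize_array_sharding(y_replicated)
```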
@@ -291,11 +361,17 @@ jax.debug.visualize_array_sharding(y)
Here, we use `P('a', 'b')` to express that the first and second axes of `x` should be sharded over the device mesh axes `'a'` and `'b'`, respectively. We can easily switch to `P('b', 'a')` to shard the axes of `x` over different devices:

```{code-cell}
+:id: FigK5Zsa3vGO
+:outputId: eca784e8-33fe-4e9b-a41d-21e9ee781a35

y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
```

```{code-cell}
+:id: hI-HD0xN3vGO
+:outputId: c3e7dc3c-4048-448a-ef0b-50683532fcdc

# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.

@@ -310,11 +386,17 @@ Here, because `P('a', None)` doesn't mention the `Mesh` axis name `'b'`, we get
To shard only over the second axis of `x`, we can use a `None` placeholder in the `PartitionSpec`:

```{code-cell}
+:id: EXBExMQC3vGP
+:outputId: fe1c8d7e-3345-4438-b9d2-780e7854b4eb

y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
```

```{code-cell}
+:id: PjUpG8uz3vGP
+:outputId: 64d8224d-15d9-4ad4-d613-f7f85b1dc1af

y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)
```

@@ -324,6 +406,9 @@ jax.debug.visualize_array_sharding(y)
For a fixed mesh, we can even partition one logical axis of `x` over multiple device mesh axes:

```{code-cell}
+:id: fVcPbDUA3vGP
+:outputId: 7f524ba5-a6d8-4490-cda9-685ad11416f9

y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
```
@@ -343,12 +428,17 @@ With sharded input data, the compiler can give us parallel computation. In parti
For example, the simplest computation is an elementwise one:

```{code-cell}
+:id: _EmQwggc3vGQ

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
```

```{code-cell}
+:id: LnT0vWjc3vGQ
+:outputId: 8089effc-aa4c-49e3-dd19-7064881dbad0

x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding:')
jax.debug.visualize_array_sharding(x)

@@ -367,6 +457,9 @@ In other words, even though we wrote the `jnp.sin` computation as if a single ma
We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:

```{code-cell}
+:id: Dq043GkP3vGQ
+:outputId: 350219a8-1e4a-4404-fe14-50f97ea3e7ba

y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
print('lhs sharding:')
@@ -386,20 +479,32 @@ Here the compiler chose the output sharding so that it could maximally paralleli
How can we be sure it's actually running in parallel? We can do a simple timing experiment:

```{code-cell}
+:id: QjQ5u8qh3vGQ
+:outputId: bd29edcd-b87c-486e-c568-906f06ae16be

x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
```

```{code-cell}
+:id: 8tn8lOj73vGR
+:outputId: 5809b3c8-7333-4cd3-db97-a7aede943dce

np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))
```

```{code-cell}
+:id: D7PpZwhR3vGR
+:outputId: 4f0bd43d-0b32-4089-d3da-c8f1449e3526

%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
```

```{code-cell}
+:id: rgo_yVHF3vGR
+:outputId: 97f19052-f1c9-4d30-f453-07b3a7208aa9

%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
```
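
The output sharding chosen by the compiler can also be read off the result itself. A minimal sketch, assuming the sharded operands `y` and `z` above (the `w` name here is just for illustration):

```python
# Minimal sketch (assumes the sharded operands `y` and `z` from above):
# the result of the sharded matmul carries the Sharding chosen by the compiler.
w = jnp.dot(y, z)
print(w.sharding)
jax.debug.visualize_array_sharding(w)
```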
@@ -408,6 +513,9 @@ np.allclose(jnp.dot(x_single, x_single),
Even copying a sharded `Array` produces a result with the sharding of the input:

```{code-cell}
+:id: f1Zw-2lH3vGR
+:outputId: a796bed4-07b0-497d-8fd8-31a22ab9762e

w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
```

@@ -424,6 +532,8 @@ But what if two arguments to a computation are explicitly placed on different se
In these ambiguous cases, an error is raised:

```{code-cell}
+:id: 1vAkZAOY3vGR

import textwrap
from termcolor import colored

@@ -433,6 +543,9 @@ def print_exception(e):
```

```{code-cell}
+:id: DHh0N3vn3vGS
+:outputId: e7741882-0ebf-4237-e5d1-e48c9b9c178f

sharding1 = PositionalSharding(jax.devices()[:4])
sharding2 = PositionalSharding(jax.devices()[4:])

@@ -443,6 +556,9 @@ except ValueError as e: print_exception(e)
```

```{code-cell}
+:id: Im7DkoOl3vGS
+:outputId: 3adfe1cb-db52-4a9d-e98e-62c6455c3100

devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

@@ -465,6 +581,9 @@ Unlike committed arrays, uncommitted arrays can be moved and resharded automatic
For example, the output of `jnp.zeros`, `jnp.arange`, and `jnp.array` are uncommitted:

```{code-cell}
+:id: _QvtKL8r3vGS
+:outputId: e0078805-bdfd-436e-f94f-7cd256d2574f

y = jax.device_put(x, sharding1.reshape(4, 2))
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
@@ -480,20 +599,21 @@ print('no error!')
While the compiler will attempt to decide how a function's intermediate values and outputs should be sharded, we can also give it hints using `jax.lax.with_sharding_constraint`. Using `jax.lax.with_sharding_constraint` is much like `jax.device_put`, except we use it inside staged-out (i.e. `jit`-decorated) functions:

```{code-cell}
-# TODO(mattjj,yashkatariya): remove cell when with_sharding_constraint is in jax.lax
-jax.lax.with_sharding_constraint = jax.experimental.pjit.with_sharding_constraint
-```
-
-```{code-cell}
+:id: jniSFm5V3vGT

sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
```

```{code-cell}
+:id: Q1wuDp-L3vGT

x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))
```

```{code-cell}
+:id: rqEDj0wB3vGT

@jax.jit
def f(x):
  x = x + 1

@@ -502,12 +622,17 @@ def f(x):
```

```{code-cell}
+:id: zYFS-n4r3vGT
+:outputId: d23a7938-cb7d-44b4-b9c7-83edf1d1145e

jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
```
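
The same hint can be expressed with a `NamedSharding` rather than a `PositionalSharding`. An illustrative sketch, assuming the `mesh` and `P` from the `NamedSharding` section above and a JAX version where `with_sharding_constraint` is available under `jax.lax`; the function `g` is hypothetical:

```python
# Illustrative sketch, not a cell from the notebook. Assumes `mesh` and `P`
# (PartitionSpec) as defined earlier, and that with_sharding_constraint is
# exposed as jax.lax.with_sharding_constraint in this JAX version.
from jax.sharding import NamedSharding

@jax.jit
def g(x):
  y = jnp.sin(x)
  # Ask the compiler to lay the result out sharded 4 ways over mesh axis 'a':
  return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P('a', None)))

out = g(x)
jax.debug.visualize_array_sharding(out)
```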

```{code-cell}
+:id: 8g_2Y8wp3vGT

@jax.jit
def f(x):
  x = x + 1
@@ -516,6 +641,9 @@ def f(x):
```

```{code-cell}
+:id: AiRFtVsR3vGT
+:outputId: f3e28a70-46cf-46fb-c801-82f0ddb447e4

jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
@@ -540,11 +668,15 @@ It's often a good practice to annotate the outputs of computations, for example
We can use `jax.device_put` and `jax.jit`'s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:

```{code-cell}
+:id: mEKF3zIF3vGU

import jax
import jax.numpy as jnp
```

```{code-cell}
+:id: Mocs3oGe3vGU

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b

@@ -558,11 +690,15 @@ def loss(params, batch):
```

```{code-cell}
+:id: glBB8tzW3vGU

loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
```

```{code-cell}
+:id: R0x62AIa3vGU

def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
@@ -590,19 +726,29 @@ params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)
### 8-way batch data parallelism

```{code-cell}
+:id: _Q5NbdOn3vGV

sharding = PositionalSharding(jax.devices()).reshape(8, 1)
```

```{code-cell}
+:id: 3KC6ieEe3vGV

batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate())
```

```{code-cell}
+:id: MUb-QE2b3vGV
+:outputId: 1f831ea5-5a30-49ad-8195-977ff7ed476a

loss_jit(params, batch)
```

```{code-cell}
+:id: HUkw0u413vGV
+:outputId: dfa2599c-9440-4657-9035-0dc3bbf625e1

step_size = 1e-5

for _ in range(30):
@@ -614,15 +760,23 @@ print(loss_jit(params, batch))
```

```{code-cell}
+:id: paCw6Zaj3vGV
+:outputId: 8ab1c32c-f2b1-465c-df71-f5a599e7f19e

%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
```

```{code-cell}
+:id: BF86UWpg3vGV

batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
```

```{code-cell}
+:id: Z1wgUKXk3vGV
+:outputId: 74df8892-c349-41dc-cb1b-e0843ec5c994

%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
```
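
The body of the training loop above applies a plain SGD update to the parameter pytree. An illustrative sketch of one such update (not the notebook's exact cell), assuming `gradfun`, `params`, `batch`, and `step_size` as defined earlier:

```python
# Illustrative sketch of one data-parallel SGD step (assumes `gradfun`,
# `params`, `batch`, and `step_size` as set up above; not the notebook's
# exact training-loop cell).
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
          for (W, b), (dW, db) in zip(params, grads)]
```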
@@ -631,16 +785,23 @@ params_single = jax.device_put(params, jax.devices()[0])
### 4-way batch data parallelism and 2-way model tensor parallelism

```{code-cell}
+:id: N5-zzgW03vGW

sharding = sharding.reshape(4, 2)
```

```{code-cell}
+:id: sgIWCjJK3vGW
+:outputId: b2fdc556-05cc-4e68-fa04-48643d194dee

batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
```

```{code-cell}
+:id: BqCjYCgg3vGW

(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

W1 = jax.device_put(W1, sharding.replicate())
@@ -659,18 +820,29 @@ params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
```

```{code-cell}
+:id: _lSJ63sh3vGW
+:outputId: 5b37aa8b-3226-4805-8282-876e8d06edda

jax.debug.visualize_array_sharding(W2)
```

```{code-cell}
+:id: fxkfWYkk3vGW
+:outputId: 8a1063c3-540b-47c1-d990-a6845da861f7

jax.debug.visualize_array_sharding(W3)
```

```{code-cell}
+:id: uPCVs-_k3vGW
+:outputId: de01cdfc-36cb-4823-c692-22c692ef4220

print(loss_jit(params, batch))
```

```{code-cell}
+:id: L9JebLK_3vGW

step_size = 1e-5

for _ in range(30):

@@ -680,16 +852,25 @@ for _ in range(30):
```

```{code-cell}
+:id: c9Sbl69e3vGX
+:outputId: 8272c5fa-e59f-4953-c2d5-658c42a28712

print(loss_jit(params, batch))
```

```{code-cell}
+:id: lkAF0dAb3vGX
+:outputId: acf0df31-c5e1-4683-b73f-b0cd1b0929f8

(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
```

```{code-cell}
+:id: I1Npor3i3vGX
+:outputId: 4099f6dd-7b46-4123-c1cb-5173c3d3278e

%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
```
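
In the 4-way/2-way setup above, the visualized `W2` and `W3` are inner-layer weights sharded over the 2-way model axis while being replicated over the 4-way batch axis. An illustrative sketch of how such a placement can be expressed with the same `PositionalSharding` (not the notebook's exact lines):

```python
# Illustrative sketch (assumes the (4, 2) `sharding` and the unpacked params
# from above): shard inner weight matrices over the 2-way model axis while
# replicating over the 4-way batch axis; 1-D biases use the squeezed (2,)
# form of the same sharding.
W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0, keepdims=False))
W3 = jax.device_put(W3, sharding.replicate(0))
b3 = jax.device_put(b3, sharding.replicate(0, keepdims=False))
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
```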
@@ -712,6 +893,8 @@ However, the existing stable RNG implementation is not automatically partitionab
Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise:

```{code-cell}
+:id: kwS-aQE_3vGX

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)

@@ -727,6 +910,9 @@ x = jax.device_put(jnp.arange(24), x_sharding)
On a partitioned input, the function `f` produces output that is also partitioned:

```{code-cell}
+:id: Oi97rpLz3vGY
+:outputId: 204a7e8d-dc88-4b77-b7e3-0e72f306c5d3

jax.debug.visualize_array_sharding(f(key, x))
```

@@ -735,6 +921,9 @@ jax.debug.visualize_array_sharding(f(key, x))
But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication:

```{code-cell}
+:id: 64wIZuSJ3vGY
+:outputId: 1054fe99-0476-44ec-9693-b0d8f98bf6a8

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
```

@@ -744,6 +933,9 @@ print('Communicating?', 'collective-permute' in f_exe.as_text())
One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the "collective permute" operation is now gone from the compiled computation:

```{code-cell}
+:id: 1I7bqxA63vGY
+:outputId: ec4c579d-f446-4b48-ceda-785c09ba299b

jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())

@@ -754,6 +946,9 @@ print('Communicating?', 'collective-permute' in f_exe.as_text())
The output is still partitioned:

```{code-cell}
+:id: zHPJzdn23vGY
+:outputId: a8904d20-4d04-4f59-8eae-281e47d29246

jax.debug.visualize_array_sharding(f(key, x))
```

@@ -762,6 +957,9 @@ jax.debug.visualize_array_sharding(f(key, x))
One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key:

```{code-cell}
+:id: nBUHBBal3vGY
+:outputId: f194c213-0688-4b7a-ffb8-c4453b82b1f1

jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))