mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 19:06:07 +00:00
2617 lines
109 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "PxHrg4Cjuapm"
|
|
},
|
|
"source": [
|
|
"# Distributed arrays and automatic parallelization\n",
|
|
"\n",
|
|
"<!--* freshness: { reviewed: '2024-04-16' } *-->"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "pFtQjv4SzHRj"
|
|
},
|
|
"source": [
|
|
"[Open in Colab](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [Open in Kaggle](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
|
|
"\n",
|
|
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"id": "FNxScTfq3vGF"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"from typing import Optional\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"import jax\n",
|
|
"import jax.numpy as jnp"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "eyHMwyEfQJcz"
|
|
},
|
|
"source": [
|
|
"⚠️ WARNING: This notebook requires 8 devices to run."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"id": "IZMLqOUV3vGG"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"if len(jax.local_devices()) < 8:\n",
|
|
" raise Exception(\"Notebook requires 8 devices to run\")"
|
|
]
|
|
},
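If you don't have 8 accelerators available, one way to try the sharding APIs in this notebook is to have XLA split the host CPU into several virtual devices. The sketch below assumes a CPU-only machine and sets the `--xla_force_host_platform_device_count` flag before JAX is imported. Such virtual devices are useful for checking layouts but not for measuring speedups, and the visualizations will show CPU devices rather than TPUs:

```python
import os

# Assumption: no accelerators are attached, so ask XLA's host (CPU) platform
# to expose 8 virtual devices. This must be set before JAX initializes a backend.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax

cpu_devices = jax.devices("cpu")
print(len(cpu_devices), cpu_devices[:2])
```

On a CPU-only machine, `jax.local_devices()` should then report 8 devices, so the check in the previous cell passes.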
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "3f37ca93"
|
|
},
|
|
"source": [
|
|
"## Intro and a quick example\n",
|
|
"\n",
|
|
"By reading this tutorial notebook, you'll learn about `jax.Array`, a unified\n",
"datatype for representing arrays, even with physical storage spanning multiple\n",
"devices. You'll also learn how using `jax.Array`s together with `jax.jit`\n",
"can provide automatic compiler-based parallelization.\n",
"\n",
"Before we go through the details step by step, here's a quick example.\n",
"First, we'll create a `jax.Array` sharded across multiple devices:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"id": "Gf2lO4ii3vGG"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from jax.experimental import mesh_utils\n",
|
|
"from jax.sharding import Mesh, PartitionSpec as P, NamedSharding"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"id": "q-XBTEoy3vGG"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a 4x2 Mesh of devices with named axes, used to distribute values across devices:\n",
|
|
"mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),\n",
|
|
" axis_names=('x', 'y'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 166
|
|
},
|
|
"id": "vI39znW93vGH",
|
|
"outputId": "4f702753-8add-4b65-a4af-0f18f098cc46"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Create an array of random values:\n",
|
|
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
|
|
"# and use jax.device_put to distribute it across devices:\n",
|
|
"y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "jZ0ZY9Um9Jg4"
|
|
},
|
|
"source": [
|
|
"Next, we'll apply a computation to it and visualize how the result values are\n",
|
|
"stored across multiple devices too:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 166
|
|
},
|
|
"id": "-qCnHZl83vGI",
|
|
"outputId": "0e131c23-5765-43ae-f232-6417ae1acbb2"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"z = jnp.sin(y)\n",
|
|
"jax.debug.visualize_array_sharding(z)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "5qccVQoE9tEi"
|
|
},
|
|
"source": [
|
|
"The evaluation of the `jnp.sin` application was automatically parallelized\n",
|
|
"across the devices on which the input values (and output values) are stored:"
|
|
]
|
|
},
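To confirm that the result follows the input's layout without relying on the visualization, we can inspect the output's sharding and its per-device shards. The following is a minimal sketch. It assumes 8 simulated CPU devices (via `XLA_FLAGS`) in place of the TPUs and a smaller 64x32 array. It uses the existing JAX APIs `Sharding.is_equivalent_to` and `jax.Array.addressable_shards`:

```python
import os

# Assumption: simulate 8 devices on the host CPU (must be set before importing jax).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 4x2 grid of devices like the notebook's mesh (plain reshape, no topology-aware order).
mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(4, 2), axis_names=("x", "y"))

x = jax.random.normal(jax.random.key(0), (64, 32))  # smaller than the notebook's array
y = jax.device_put(x, NamedSharding(mesh, P("x", "y")))
z = jnp.sin(y)

# The result keeps the input's layout: each device holds one (16, 16) block of z.
print(z.sharding.is_equivalent_to(y.sharding, z.ndim))
for shard in z.addressable_shards:
    print(shard.device, shard.data.shape)
```

No data was gathered onto a single device. Each device applied `sin` to the block it already held.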
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "_VTzN0r03vGI",
|
|
"outputId": "c03eecab-4c86-4dac-d776-5fc72cbb5273"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
|
|
"25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# `x` is present on a single device\n",
|
|
"%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "QuzhU1g63vGI",
|
|
"outputId": "8135cca0-871b-4b6a-a7e5-02e78c2028c7"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# `y` is sharded across 8 devices.\n",
|
|
"%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xWknFQbQ-bzV"
|
|
},
|
|
"source": [
|
|
"Now let's look at each of these pieces in more detail!\n",
|
|
"\n",
|
|
"\n",
|
|
"## `Sharding` describes how array values are laid out in memory across devices"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "W6HsXauGxL6w"
|
|
},
|
|
"source": [
|
|
"### Sharding basics, and the `NamedSharding` subclass"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "NWDyp_EjVHkg"
|
|
},
|
|
"source": [
|
|
"To parallelize computation across multiple devices, we must first lay out the input data across those devices.\n",
"\n",
"In JAX, `Sharding` objects describe distributed memory layouts. They can be used with `jax.device_put` to produce a value with that layout.\n",
"\n",
"For example, here's a value with a single-device `Sharding`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {
|
|
"id": "VmoX4SUp3vGJ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import jax\n",
|
|
"x = jax.random.normal(jax.random.key(0), (8192, 8192))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 199
|
|
},
|
|
"id": "vNRabO2J3vGJ",
|
|
"outputId": "40fd7172-a16c-4dd8-e2e1-17bb3afe5409"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.debug.visualize_array_sharding(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "HhCjhK0zXIqX"
|
|
},
|
|
"source": [
|
|
"Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n",
"\n",
"But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Device`s using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order. We then wrap that array in a `Mesh` with named axes and pass a `NamedSharding` to `jax.device_put`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 166
|
|
},
|
|
"id": "zpB1JxyK3vGN",
|
|
"outputId": "8e385462-1c2c-4256-c38a-84299d3bd02c"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"from jax.sharding import Mesh, PartitionSpec, NamedSharding\n",
|
|
"from jax.experimental import mesh_utils\n",
|
|
"\n",
|
|
"P = PartitionSpec\n",
|
|
"\n",
|
|
"devices = mesh_utils.create_device_mesh((4, 2))\n",
|
|
"mesh = Mesh(devices, axis_names=('a', 'b'))\n",
|
|
"y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
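A `Mesh` is a named, n-dimensional arrangement of devices. With `P('a', 'b')`, the block in row `i` and column `j` of the array lands on the device at position `(i, j)` of the mesh. The sketch below checks this with `addressable_shards`. It assumes 8 simulated CPU devices laid out with a plain `reshape` (no topology-aware ordering) and a 1024x512 array:

```python
import os

# Assumption: 8 simulated CPU devices stand in for the notebook's 8 TPUs.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices("cpu")[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("a", "b"))
print(dict(mesh.shape))  # mesh axis name -> number of devices along that axis

x = np.arange(1024 * 512, dtype=np.float32).reshape(1024, 512)
y = jax.device_put(x, NamedSharding(mesh, P("a", "b")))

rows, cols = 1024 // 4, 512 // 2  # block size along each array axis
for shard in y.addressable_shards:
    r, c = shard.index  # slices into the global array held by this device
    i, j = (r.start or 0) // rows, (c.start or 0) // cols
    # Block (i, j) of the array lives on the device at position (i, j) of the mesh.
    assert shard.device == mesh.devices[i, j]
    assert shard.data.shape == (rows, cols)
```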
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "OW_Cc92G1-nr"
|
|
},
|
|
"source": [
|
|
"We can define a helper function to make things simpler:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {
|
|
"id": "8g0Md2Gd3vGO"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"devices = mesh_utils.create_device_mesh((4, 2))\n",
|
|
"default_mesh = Mesh(devices, axis_names=('a', 'b'))\n",
|
|
"\n",
|
|
"def mesh_sharding(\n",
|
|
" pspec: PartitionSpec, mesh: Optional[Mesh] = None,\n",
|
|
" ) -> NamedSharding:\n",
|
|
" if mesh is None:\n",
|
|
" mesh = default_mesh\n",
|
|
" return NamedSharding(mesh, pspec)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 166
|
|
},
|
|
"id": "zp3MfS4Y3vGO",
|
|
"outputId": "032fdd7e-19a1-45da-e1ad-b3227fa43ee6"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = jax.device_put(x, mesh_sharding(P('a', 'b')))\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xZ88riVm1mv5"
|
|
},
|
|
"source": [
|
|
"Here, we use `P('a', 'b')` to express that the first and second axes of `x` should be sharded over the device mesh axes `'a'` and `'b'`, respectively. We can easily switch to `P('b', 'a')` to shard the axes of `x` over different devices:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 199
|
|
},
|
|
"id": "FigK5Zsa3vGO",
|
|
"outputId": "e488d073-9d02-4376-a6af-19d6d5509c7d"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────┬───────┬───────┬───────┐\n",
|
|
"│ │ │ │ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"├───────┼───────┼───────┼───────┤\n",
|
|
"│ │ │ │ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"└───────┴───────┴───────┴───────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────┬───────┬───────┬───────┐\n",
|
|
"│ │ │ │ │\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"├───────┼───────┼───────┼───────┤\n",
|
|
"│ │ │ │ │\n",
|
|
"│ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"└───────┴───────┴───────┴───────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = jax.device_put(x, mesh_sharding(P('b', 'a')))\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
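Swapping the mesh axes changes the shape of each block as well as its placement. An array dimension sharded over `'a'` is split 4 ways, while one sharded over `'b'` is split 2 ways. The following small sketch uses a non-square 1024x512 array and assumes 8 simulated CPU devices in place of the TPUs. The helper `shard_shapes` is hypothetical:

```python
import os

# Assumption: 8 simulated CPU devices stand in for the notebook's TPUs.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(4, 2), axis_names=("a", "b"))
x = np.zeros((1024, 512), dtype=np.float32)

def shard_shapes(spec):
    """Hypothetical helper: the set of per-device block shapes under `spec`."""
    y = jax.device_put(x, NamedSharding(mesh, spec))
    return {s.data.shape for s in y.addressable_shards}

print(shard_shapes(P("a", "b")))  # rows split 4 ways, columns 2 ways: 1024/4 x 512/2
print(shard_shapes(P("b", "a")))  # rows split 2 ways, columns 4 ways: 1024/2 x 512/4
```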
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 166
|
|
},
|
|
"id": "hI-HD0xN3vGO",
|
|
"outputId": "b0c2e863-3aee-4417-b45f-21b2187f6ef7"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────────────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└───────────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────────────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m │\n",
|
|
"└───────────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# This `None` means that `x` is not sharded on its second dimension,\n",
|
|
"# and since the Mesh axis name 'b' is not mentioned, shards are\n",
|
|
"# replicated across it.\n",
|
|
"y = jax.device_put(x, mesh_sharding(P('a', None)))\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "AqcAsNUgXCZz"
|
|
},
|
|
"source": [
|
|
"Here, because `P('a', None)` doesn't mention the `Mesh` axis name `'b'`, we get replication over the axis `'b'`. The `None` here is just acting as a placeholder to line up against the second axis of the value `x`, without expressing sharding over any mesh axis. (As a shorthand, trailing `None`s can be omitted, so that `P('a', None)` means the same thing as `P('a')`. But it doesn't hurt to be explicit!)\n",
|
|
"\n",
|
|
"To shard only over the second axis of `x`, we can use a `None` placeholder in the `PartitionSpec`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 199
|
|
},
|
|
"id": "EXBExMQC3vGP",
|
|
"outputId": "c80e6177-12a6-40ef-b4e4-934dad22da3d"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────┬───────────┐\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>│\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"└───────────┴───────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────┬───────────┐\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"└───────────┴───────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = jax.device_put(x, mesh_sharding(P(None, 'b')))\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 199
|
|
},
|
|
"id": "PjUpG8uz3vGP",
|
|
"outputId": "a0f59dc5-b509-4b8b-bd22-bcd69f696763"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────┬───────┬───────┬───────┐\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>│\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"└───────┴───────┴───────┴───────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────┬───────┬───────┬───────┐\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"└───────┴───────┴───────┴───────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = jax.device_put(x, mesh_sharding(P(None, 'a')))\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
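The claims about `None` can be checked directly with `Sharding.devices_indices_map`, which reports the global slice each device holds. The sketch below compares `P('a', None)` with `P('a')` and counts how many devices hold each row block. It assumes 8 simulated CPU devices and a 1024x512 array shape:

```python
import os

# Assumption: 8 simulated CPU devices stand in for the notebook's TPUs.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

from collections import Counter

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(4, 2), axis_names=("a", "b"))
shape = (1024, 512)

explicit = NamedSharding(mesh, P("a", None))
implicit = NamedSharding(mesh, P("a"))  # trailing None omitted

# Both specs assign the same global slice to every device.
same = dict(explicit.devices_indices_map(shape)) == dict(implicit.devices_indices_map(shape))

# 'b' is not mentioned, so each of the 4 row blocks is replicated on 2 devices.
copies = Counter(
    idx[0].start or 0 for idx in explicit.devices_indices_map(shape).values()
)
print(same, sorted(copies.items()))
```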
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "--AZgW1P3HFT"
|
|
},
|
|
"source": [
|
|
"For a fixed mesh, we can even partition one logical axis of `x` over multiple device mesh axes:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 298
|
|
},
|
|
"id": "fVcPbDUA3vGP",
|
|
"outputId": "da3f435d-dfc1-4a41-ec90-691cd7c748a0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────────────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└───────────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────────────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└───────────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
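When one array axis is sharded over a tuple of mesh axes, the blocks follow the mesh in row-major order with `'a'` as the major axis. This is why the visualization above lists the devices in the order of the mesh rows. The following sketch verifies this, assuming 8 simulated CPU devices and a 1024x512 array:

```python
import os

# Assumption: 8 simulated CPU devices stand in for the notebook's TPUs.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(4, 2), axis_names=("a", "b"))
x = np.arange(1024 * 512, dtype=np.float32).reshape(1024, 512)
y = jax.device_put(x, NamedSharding(mesh, P(("a", "b"), None)))

rows = 1024 // (4 * 2)  # the first axis is split over 4 * 2 = 8 devices
for shard in y.addressable_shards:
    k = (shard.index[0].start or 0) // rows  # which row block this device holds
    # 'a' is the major axis and 'b' the minor one, so block k sits on the
    # k-th device of the mesh in row-major order.
    assert shard.device == mesh.devices.flat[k]
    assert shard.data.shape == (rows, 512)
```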
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "c1tTFudr3Ae7"
|
|
},
|
|
"source": [
|
|
"Using `NamedSharding` makes it easy to define a device mesh once and give its axes names, then just refer to those names in `PartitionSpec`s for each `device_put` as needed."
|
|
]
|
|
},
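The same pattern extends to collections of values: one mesh, plus one `PartitionSpec` per value. The sketch below assumes 8 simulated CPU devices. The parameter names `kernel` and `bias` are hypothetical. It relies on `jax.device_put` accepting a pytree of shardings that matches the pytree of values:

```python
import os

# Assumption: 8 simulated CPU devices stand in for the notebook's TPUs.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Define the mesh and its axis names once...
mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(4, 2), axis_names=("a", "b"))

# ...then describe each value's layout only in terms of those names.
params = {"kernel": np.ones((512, 256), np.float32), "bias": np.zeros((256,), np.float32)}
specs = {"kernel": P("a", "b"), "bias": P("b")}
shardings = {name: NamedSharding(mesh, spec) for name, spec in specs.items()}

# device_put accepts a pytree of shardings matching the pytree of values.
placed = jax.device_put(params, shardings)
print({name: arr.addressable_shards[0].data.shape for name, arr in placed.items()})
```

Here `kernel` is split into 128x128 blocks, while `bias` is split in half along `'b'` and replicated along `'a'`.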
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "rhWzHgGf4mkg"
|
|
},
|
|
"source": [
|
|
"## Computation follows data sharding and is automatically parallelized"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "JukoaRhl4tXJ"
|
|
},
|
|
"source": [
|
|
"With sharded input data, the compiler can parallelize computation for us. In particular, functions decorated with `jax.jit` can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.\n",
|
|
"\n",
|
|
"For example, the simplest computation is an elementwise one:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {
|
|
"id": "_EmQwggc3vGQ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"devices = mesh_utils.create_device_mesh((4, 2))\n",
|
|
"mesh = Mesh(devices, axis_names=('a', 'b'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 349
|
|
},
|
|
"id": "LnT0vWjc3vGQ",
|
|
"outputId": "8e642049-61eb-458d-af79-ac449b58d11b"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"input sharding:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"output sharding:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))\n",
|
|
"print('input sharding:')\n",
|
|
"jax.debug.visualize_array_sharding(x)\n",
|
|
"\n",
|
|
"y = jnp.sin(x)\n",
|
|
"print('output sharding:')\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "7tY2gVRfazaT"
|
|
},
|
|
"source": [
|
|
"Here, for the elementwise operation `jnp.sin`, the compiler chose the output sharding to be the same as the input sharding. Moreover, the compiler automatically parallelized the computation, so that each device computed its output shard from its input shard in parallel.\n",
|
|
"\n",
|
|
"In other words, even though we wrote the `jnp.sin` computation as if a single machine were to execute it, the compiler splits up the computation for us and executes it on multiple devices.\n",
|
|
"\n",
|
|
"This works for more than just elementwise operations. Consider a matrix multiplication with sharded inputs:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 548
|
|
},
|
|
"id": "Dq043GkP3vGQ",
|
|
"outputId": "3eff7b67-d7f0-4212-c9d3-2cc271ac1f98"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"lhs sharding:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────────────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└───────────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────────────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m │\n",
|
|
"└───────────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"rhs sharding:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────┬───────────┐\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>│\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"└───────────┴───────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────┬───────────┐\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"└───────────┴───────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"out sharding:\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"y = jax.device_put(x, NamedSharding(mesh, P('a', None)))\n",
|
|
"z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))\n",
|
|
"print('lhs sharding:')\n",
|
|
"jax.debug.visualize_array_sharding(y)\n",
|
|
"print('rhs sharding:')\n",
|
|
"jax.debug.visualize_array_sharding(z)\n",
|
|
"\n",
|
|
"w = jnp.dot(y, z)\n",
|
|
"print('out sharding:')\n",
|
|
"jax.debug.visualize_array_sharding(w)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "_EPNaWzgazft"
|
|
},
|
|
"source": [
|
|
"Here the compiler chose the output sharding so that it could maximally parallelize the computation: without needing communication, each device already has the input shards it needs to compute its output shard.\n",
|
|
"\n",
|
|
"How can we be sure it's actually running in parallel? We can do a simple timing experiment:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 199
|
|
},
|
|
"id": "QjQ5u8qh3vGQ",
|
|
"outputId": "0aefc170-833c-4a6a-e003-5990d3db31d9"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"x_single = jax.device_put(x, jax.devices()[0])\n",
|
|
"jax.debug.visualize_array_sharding(x_single)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "8tn8lOj73vGR",
|
|
"outputId": "d9898c93-7afc-416b-8c40-4d9551613cd0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"True"
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"np.allclose(jnp.dot(x_single, x_single),\n",
|
|
" jnp.dot(y, z))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "D7PpZwhR3vGR",
|
|
"outputId": "4901a11b-2354-4d26-a897-b88def07a716"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "rgo_yVHF3vGR",
|
|
"outputId": "e51216cf-b073-4250-d422-67f9fd72f6aa"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()"
|
|
]
|
|
},
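{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `%timeit` comparison above relies on IPython. As a rough standalone sketch (not part of the original tutorial), the same measurement can be written with `time.perf_counter`, using a warm-up call so that compilation time is excluded and `block_until_ready` so that JAX's asynchronous dispatch does not hide the work. To run without TPUs, it assumes a CPU-only setup where `XLA_FLAGS=--xla_force_host_platform_device_count=8` fakes 8 devices (the flag must be set before `jax` is imported); the helper name `bench` is hypothetical. Fake CPU devices share one host, so expect matching results there but little or no speedup.\n",
"\n",
"```python\n",
"import os\n",
"import time\n",
"\n",
"# Assumption: fake 8 CPU devices; only effective if set before jax is imported.\n",
"os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
"\n",
"def bench(fn, *args, n=5):\n",
"    # Hypothetical helper: mean wall-clock seconds per call, excluding compilation.\n",
"    fn(*args).block_until_ready()  # warm-up call triggers compilation\n",
"    start = time.perf_counter()\n",
"    for _ in range(n):\n",
"        fn(*args).block_until_ready()  # wait for the asynchronous result\n",
"    return (time.perf_counter() - start) / n\n",
"\n",
"mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('a', 'b'))\n",
"# Small integer-valued floats keep the matmul exact, so results compare equal.\n",
"m = (jnp.arange(512 * 512) % 7).astype(jnp.float32).reshape(512, 512)\n",
"\n",
"m_single = jax.device_put(m, jax.devices()[0])\n",
"lhs = jax.device_put(m, NamedSharding(mesh, P('a', None)))\n",
"rhs = jax.device_put(m, NamedSharding(mesh, P(None, 'b')))\n",
"\n",
"t_single = bench(jnp.dot, m_single, m_single)\n",
"t_sharded = bench(jnp.dot, lhs, rhs)\n",
"print(f'single device: {t_single * 1e3:.2f} ms, sharded: {t_sharded * 1e3:.2f} ms')\n",
"```\n",
"\n",
"Converting both results with `np.asarray` before comparing them (as `np.allclose` does above) avoids combining committed arrays that live on different sets of devices."
]
},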
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "gglQIMXJnnJw"
|
|
},
|
|
"source": [
|
|
"Even copying a sharded `Array` produces a result with the sharding of the input:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 166
|
|
},
|
|
"id": "f1Zw-2lH3vGR",
|
|
"outputId": "43d7a642-fde4-47a6-901f-dfdc64d6a613"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"w_copy = jnp.copy(w)\n",
|
|
"jax.debug.visualize_array_sharding(w_copy)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "3qfPjJdhgerc"
|
|
},
|
|
"source": [
|
|
"So computation follows data placement: when we explicitly shard data with `jax.device_put` and apply functions to that data, the compiler attempts to parallelize the computation and decides the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)."
|
|
]
|
|
},
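{
"cell_type": "markdown",
"metadata": {},
"source": [
"The visualizations above can also be checked programmatically. The following is a minimal standalone sketch, assuming a CPU-only machine where `XLA_FLAGS=--xla_force_host_platform_device_count=8` fakes 8 devices (set before `jax` is imported); `mesh`, `x`, `y`, and `y_copy` are local to the sketch. It uses `Sharding.is_equivalent_to` to confirm that an elementwise op and `jnp.copy` keep the input's sharding.\n",
"\n",
"```python\n",
"import os\n",
"\n",
"# Assumption: fake 8 CPU devices; only effective if set before jax is imported.\n",
"os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
"\n",
"# A 4x2 mesh over 8 devices, like the one used above.\n",
"mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('a', 'b'))\n",
"x = jax.device_put(jnp.arange(64.0).reshape(8, 8), NamedSharding(mesh, P('a', 'b')))\n",
"\n",
"y = jnp.sin(x)        # elementwise: each device computes its own output shard\n",
"y_copy = jnp.copy(y)  # a copy also keeps the sharding\n",
"\n",
"# Compare shardings semantically (same layout over the same devices).\n",
"print(y.sharding.is_equivalent_to(x.sharding, x.ndim))\n",
"print(y_copy.sharding.is_equivalent_to(x.sharding, x.ndim))\n",
"```"
]
},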
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "QRB95LaWuT80"
|
|
},
|
|
"source": [
|
|
"### When explicit shardings disagree, JAX errors\n",
|
|
"\n",
|
|
"But what if two arguments to a computation are explicitly placed on different sets of devices, or with incompatible device orders?\n",
|
|
"In these ambiguous cases, an error is raised:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 94,
|
|
"metadata": {
|
|
"id": "1vAkZAOY3vGR"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import textwrap\n",
|
|
"from termcolor import colored\n",
|
|
"\n",
|
|
"def print_exception(e):\n",
|
|
" name = colored(f'{type(e).__name__}', 'red', force_color=True)\n",
|
|
" print(textwrap.fill(f'{name}: {str(e)}'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 95,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "DHh0N3vn3vGS",
|
|
"outputId": "8c4652f7-c484-423b-ad78-182134280187"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[31mValueError\u001b[0m: Received incompatible devices for jitted\n",
|
|
"computation. Got argument x1 of jax.numpy.add with shape int32[24] and\n",
|
|
"device ids [0, 1, 2, 3] on platform TPU and argument x2 of\n",
|
|
"jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on\n",
|
|
"platform TPU\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))\n",
|
|
"sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))\n",
|
|
"\n",
|
|
"y = jax.device_put(x, sharding1)\n",
|
|
"z = jax.device_put(x, sharding2)\n",
|
|
"try: y + z\n",
|
|
"except ValueError as e: print_exception(e)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 96,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "Im7DkoOl3vGS",
|
|
"outputId": "1b6fcd7a-762b-4366-a96d-aea63bad7fe0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[31mValueError\u001b[0m: Received incompatible devices for jitted\n",
|
|
"computation. Got argument x1 of jax.numpy.add with shape int32[24] and\n",
|
|
"device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of\n",
|
|
"jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,\n",
|
|
"4, 5] on platform TPU\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"devices = jax.devices()\n",
|
|
"permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]\n",
|
|
"\n",
|
|
"sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))\n",
|
|
"sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))\n",
|
|
"\n",
|
|
"y = jax.device_put(x, sharding1)\n",
|
|
"z = jax.device_put(x, sharding2)\n",
|
|
"try: y + z\n",
|
|
"except ValueError as e: print_exception(e)"
|
|
]
|
|
},
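{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to resolve such a conflict is to reshard one operand explicitly with `jax.device_put`, so that both arguments share a sharding before they are combined. A minimal sketch, assuming 8 fake CPU devices via `XLA_FLAGS=--xla_force_host_platform_device_count=8` (set before importing `jax`); the name `z_on_y` is hypothetical:\n",
"\n",
"```python\n",
"import os\n",
"\n",
"# Assumption: fake 8 CPU devices; only effective if set before jax is imported.\n",
"os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
"\n",
"devices = jax.devices()[:8]\n",
"x = jnp.arange(24)\n",
"\n",
"# Two committed arrays on disjoint sets of devices: y + z would raise.\n",
"sharding1 = NamedSharding(Mesh(np.array(devices[:4]), ('x',)), P('x'))\n",
"sharding2 = NamedSharding(Mesh(np.array(devices[4:]), ('x',)), P('x'))\n",
"y = jax.device_put(x, sharding1)\n",
"z = jax.device_put(x, sharding2)\n",
"\n",
"# Explicitly move z onto y's sharding; now both operands agree.\n",
"z_on_y = jax.device_put(z, y.sharding)\n",
"total = y + z_on_y\n",
"print(total.sharding.is_equivalent_to(sharding1, total.ndim))\n",
"```"
]
},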
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "6ZYcK8eXrn0p"
|
|
},
|
|
"source": [
|
|
"We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n",
|
|
"\n",
|
|
"When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device.\n",
|
|
"Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.\n",
|
|
"\n",
|
|
"For example, the outputs of `jnp.zeros`, `jnp.arange`, and `jnp.array` are uncommitted:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "_QvtKL8r3vGS",
|
|
"outputId": "761b1208-fe4b-4c09-a7d2-f62152183ef0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"no error!\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"y = jax.device_put(x, sharding1)\n",
|
|
"y + jnp.ones_like(y)\n",
|
|
"y + jnp.arange(y.size).reshape(y.shape)\n",
|
|
"print('no error!')"
|
|
]
|
|
},
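{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the same idea that also checks the result, again assuming 8 fake CPU devices via `XLA_FLAGS=--xla_force_host_platform_device_count=8`: an uncommitted `jnp.arange` result is combined with a committed, sharded array, and in this elementwise case we expect the output to take the committed array's sharding. Variable names are local to the sketch.\n",
"\n",
"```python\n",
"import os\n",
"\n",
"# Assumption: fake 8 CPU devices; only effective if set before jax is imported.\n",
"os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
"\n",
"mesh = Mesh(np.array(jax.devices()[:8]), ('x',))\n",
"sharding = NamedSharding(mesh, P('x'))\n",
"\n",
"y = jax.device_put(jnp.arange(24.0), sharding)  # committed, split over 8 devices\n",
"u = jnp.arange(24.0)                            # uncommitted, on the default device\n",
"\n",
"out = y + u  # u is moved and resharded automatically; no error\n",
"print(out.sharding.is_equivalent_to(sharding, out.ndim))\n",
"```"
]
},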
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "dqMKl79NaIWF"
|
|
},
|
|
"source": [
|
|
"## Constraining shardings of intermediates in `jit`ted code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "g4LrDDcJwkHc"
|
|
},
|
|
"source": [
|
|
"While the compiler will attempt to decide how a function's intermediate values and outputs should be sharded, we can also give it hints using `jax.lax.with_sharding_constraint`. Using `jax.lax.with_sharding_constraint` is much like `jax.device_put`, except that we use it inside staged-out (i.e., `jit`-decorated) functions:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"metadata": {
|
|
"id": "jniSFm5V3vGT"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"metadata": {
|
|
"id": "Q1wuDp-L3vGT"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
|
|
"x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"metadata": {
|
|
"id": "rqEDj0wB3vGT"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jax.jit\n",
|
|
"def f(x):\n",
|
|
" x = x + 1\n",
|
|
" y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))\n",
|
|
" return y"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 347
|
|
},
|
|
"id": "zYFS-n4r3vGT",
|
|
"outputId": "0ac96b8f-ed23-4413-aed9-edd00a841c37"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────┬───────┬───────┬───────┐\n",
|
|
"│ │ │ │ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"├───────┼───────┼───────┼───────┤\n",
|
|
"│ │ │ │ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"└───────┴───────┴───────┴───────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────┬───────┬───────┬───────┐\n",
|
|
"│ │ │ │ │\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"├───────┼───────┼───────┼───────┤\n",
|
|
"│ │ │ │ │\n",
|
|
"│ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"│ │ │ │ │\n",
|
|
"│ │ │ │ │\n",
|
|
"└───────┴───────┴───────┴───────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.debug.visualize_array_sharding(x)\n",
|
|
"y = f(x)\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {
|
|
"id": "8g_2Y8wp3vGT"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jax.jit\n",
|
|
"def f(x):\n",
|
|
" x = x + 1\n",
|
|
" y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))\n",
|
|
" return y"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 347
|
|
},
|
|
"id": "AiRFtVsR3vGT",
|
|
"outputId": "2edacc2c-ac80-4519-c9d1-bee364a22b31"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌──────────┬──────────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │\n",
|
|
"└──────────┴──────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌──────────┬──────────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"├──────────┼──────────┤\n",
|
|
"│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
|
|
"└──────────┴──────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.debug.visualize_array_sharding(x)\n",
|
|
"y = f(x)\n",
|
|
"jax.debug.visualize_array_sharding(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "_Y1P5wLTzJSz"
|
|
},
|
|
"source": [
|
|
"By adding `with_sharding_constraint`, we've constrained the sharding of the output. In addition to respecting the annotation on a particular intermediate, the compiler will use annotations to decide shardings for other values.\n",
|
|
"\n",
|
|
"It's often a good practice to annotate the outputs of computations, for example based on how the values are ultimately consumed."
|
|
]
|
|
},
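{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the final value in each function below passes through `with_sharding_constraint`, the result's sharding can be checked directly instead of visualized. A minimal sketch, assuming 8 fake CPU devices via `XLA_FLAGS=--xla_force_host_platform_device_count=8`; `f_transposed` and `f_replicated` are hypothetical names mirroring the two versions of `f` above:\n",
"\n",
"```python\n",
"import os\n",
"\n",
"# Assumption: fake 8 CPU devices; only effective if set before jax is imported.\n",
"os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
"\n",
"mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))\n",
"x = jax.device_put(jnp.ones((16, 16)), NamedSharding(mesh, P('x', 'y')))\n",
"\n",
"@jax.jit\n",
"def f_transposed(x):\n",
"    # Request rows split over 'y' and columns over 'x' (the reverse of the input).\n",
"    return jax.lax.with_sharding_constraint(x + 1, NamedSharding(mesh, P('y', 'x')))\n",
"\n",
"@jax.jit\n",
"def f_replicated(x):\n",
"    # An empty PartitionSpec requests a full copy of the result on every device.\n",
"    return jax.lax.with_sharding_constraint(x + 1, NamedSharding(mesh, P()))\n",
"\n",
"a = f_transposed(x)\n",
"b = f_replicated(x)\n",
"print(a.sharding.is_equivalent_to(NamedSharding(mesh, P('y', 'x')), a.ndim))\n",
"print(b.sharding.is_fully_replicated)\n",
"```"
]
},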
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "QUkXWG-baMUs"
|
|
},
|
|
"source": [
|
|
"## Examples: neural networks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "g7y0OJBSGoSW"
|
|
},
|
|
"source": [
|
|
"**⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with `jax.Array`, but it may not reflect best practices for real examples.** For instance, real examples may require more use of `with_sharding_constraint`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "3ii_UPkG3gzP"
|
|
},
|
|
"source": [
|
|
"We can use `jax.device_put` and `jax.jit`'s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"metadata": {
|
|
"id": "mEKF3zIF3vGU"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import jax\n",
|
|
"import jax.numpy as jnp"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 49,
|
|
"metadata": {
|
|
"id": "Mocs3oGe3vGU"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def predict(params, inputs):\n",
|
|
" for W, b in params:\n",
|
|
" outputs = jnp.dot(inputs, W) + b\n",
|
|
" inputs = jnp.maximum(outputs, 0)\n",
|
|
" return outputs\n",
|
|
"\n",
|
|
"def loss(params, batch):\n",
|
|
" inputs, targets = batch\n",
|
|
" predictions = predict(params, inputs)\n",
|
|
" return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 50,
|
|
"metadata": {
|
|
"id": "glBB8tzW3vGU"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"loss_jit = jax.jit(loss)\n",
|
|
"gradfun = jax.jit(jax.grad(loss))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 51,
|
|
"metadata": {
|
|
"id": "R0x62AIa3vGU"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def init_layer(key, n_in, n_out):\n",
|
|
" k1, k2 = jax.random.split(key)\n",
|
|
" W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n",
|
|
" b = jax.random.normal(k2, (n_out,))\n",
|
|
" return W, b\n",
|
|
"\n",
|
|
"def init_model(key, layer_sizes, batch_size):\n",
|
|
" key, *keys = jax.random.split(key, len(layer_sizes))\n",
|
|
" params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
|
|
"\n",
|
|
" key, *keys = jax.random.split(key, 3)\n",
|
|
" inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n",
|
|
" targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n",
|
|
"\n",
|
|
" return params, (inputs, targets)\n",
|
|
"\n",
|
|
"layer_sizes = [784, 8192, 8192, 8192, 10]\n",
|
|
"batch_size = 8192\n",
|
|
"\n",
|
|
"params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "sJv_h0AS2drh"
|
|
},
|
|
"source": [
|
|
"### 8-way batch data parallelism"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 52,
|
|
"metadata": {
|
|
"id": "mJLqRPpSDX0i"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 54,
|
|
"metadata": {
|
|
"id": "_Q5NbdOn3vGV"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"sharding = NamedSharding(mesh, P('batch'))\n",
|
|
"replicated_sharding = NamedSharding(mesh, P())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 55,
|
|
"metadata": {
|
|
"id": "3KC6ieEe3vGV"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch = jax.device_put(batch, sharding)\n",
|
|
"params = jax.device_put(params, replicated_sharding)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "MUb-QE2b3vGV",
|
|
"outputId": "5a27f007-c572-44f8-9f49-6e745ee739e8"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array(23.469475, dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 56,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"loss_jit(params, batch)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "HUkw0u413vGV",
|
|
"outputId": "07e481a1-97fb-4bd0-d754-cb6d8317bff6"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"10.760109\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"step_size = 1e-5\n",
|
|
"\n",
|
|
"for _ in range(30):\n",
|
|
" grads = gradfun(params, batch)\n",
|
|
" params = [(W - step_size * dW, b - step_size * db)\n",
|
|
" for (W, b), (dW, db) in zip(params, grads)]\n",
|
|
"\n",
|
|
"print(loss_jit(params, batch))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "paCw6Zaj3vGV",
|
|
"outputId": "ad4cce34-3a6a-4d44-9a86-477a7fee4841"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 59,
|
|
"metadata": {
|
|
"id": "BF86UWpg3vGV"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch_single = jax.device_put(batch, jax.devices()[0])\n",
|
|
"params_single = jax.device_put(params, jax.devices()[0])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 60,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "Z1wgUKXk3vGV",
|
|
"outputId": "d66767b7-3f17-482f-b811-919bb1793277"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()"
|
|
]
|
|
},
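{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see computation-follows-sharding at work in a data-parallel setup, one can inspect the compiled program for cross-device collectives, much as the Sharp bits section does for `collective-permute`. This sketch uses a small, hypothetical linear model (`tiny_loss`) instead of the network above so that it stands on its own. With a batch-sharded input and replicated parameters, the per-device gradient contributions have to be summed, which we would expect XLA to express as an `all-reduce`; the exact collectives depend on the backend and compiler version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
"\n",
"# A small, hypothetical linear model, used only to inspect the compiled program.\n",
"def tiny_loss(W, x, y):\n",
"  return jnp.mean((jnp.dot(x, W) - y) ** 2)\n",
"\n",
"tiny_mesh = Mesh(jax.devices(), ('batch',))\n",
"n_dev = len(jax.devices())\n",
"\n",
"# Shard the data along the batch dimension and replicate the parameters.\n",
"tiny_x = jax.device_put(jnp.ones((4 * n_dev, 3)), NamedSharding(tiny_mesh, P('batch')))\n",
"tiny_y = jax.device_put(jnp.ones((4 * n_dev, 1)), NamedSharding(tiny_mesh, P('batch')))\n",
"tiny_W = jax.device_put(jnp.ones((3, 1)), NamedSharding(tiny_mesh, P()))\n",
"\n",
"# The gradient w.r.t. the replicated W contracts over the sharded batch dimension,\n",
"# so the partial results from each device must be combined.\n",
"tiny_grad_exe = jax.jit(jax.grad(tiny_loss)).lower(tiny_W, tiny_x, tiny_y).compile()\n",
"print('all-reduce?', 'all-reduce' in tiny_grad_exe.as_text())"
]
},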
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "3AjeeB7B4NP6"
|
|
},
|
|
"source": [
|
|
"### 4-way batch data parallelism and 2-way model tensor parallelism"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 61,
|
|
"metadata": {
|
|
"id": "k1hxOfgRDwo0"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 62,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 314
|
|
},
|
|
"id": "sgIWCjJK3vGW",
|
|
"outputId": "8cb0f19f-3942-415c-c57a-31bb81784f46"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────┐\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>│\n",
|
|
"├───────┤\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>│\n",
|
|
"├───────┤\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>│\n",
|
|
"├───────┤\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>│\n",
|
|
"└───────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────┐\n",
|
|
"│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│\n",
|
|
"├───────┤\n",
|
|
"│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│\n",
|
|
"├───────┤\n",
|
|
"│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
|
|
"├───────┤\n",
|
|
"│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n",
|
|
"└───────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────┐\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>│\n",
|
|
"├───────┤\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>│\n",
|
|
"├───────┤\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>│\n",
|
|
"├───────┤\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>│\n",
|
|
"└───────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────┐\n",
|
|
"│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│\n",
|
|
"├───────┤\n",
|
|
"│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│\n",
|
|
"├───────┤\n",
|
|
"│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
|
|
"├───────┤\n",
|
|
"│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n",
|
|
"└───────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))\n",
|
|
"jax.debug.visualize_array_sharding(batch[0])\n",
|
|
"jax.debug.visualize_array_sharding(batch[1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "q9PQP-0eEAO6"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"replicated_sharding = NamedSharding(mesh, P())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 67,
|
|
"metadata": {
|
|
"id": "BqCjYCgg3vGW"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params\n",
|
|
"\n",
|
|
"W1 = jax.device_put(W1, replicated_sharding)\n",
|
|
"b1 = jax.device_put(b1, replicated_sharding)\n",
|
|
"\n",
|
|
"W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))\n",
|
|
"b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))\n",
|
|
"\n",
|
|
"W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))\n",
|
|
"b3 = jax.device_put(b3, replicated_sharding)\n",
|
|
"\n",
|
|
"W4 = jax.device_put(W4, replicated_sharding)\n",
|
|
"b4 = jax.device_put(b4, replicated_sharding)\n",
|
|
"\n",
|
|
"params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 68,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 199
|
|
},
|
|
"id": "_lSJ63sh3vGW",
|
|
"outputId": "bcd3e33e-36b5-4787-9cd2-60623fd6e5fa"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────┬───────────┐\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>│\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"└───────────┴───────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────┬───────────┐\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"└───────────┴───────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.debug.visualize_array_sharding(W2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 69,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 199
|
|
},
|
|
"id": "fxkfWYkk3vGW",
|
|
"outputId": "59e60b16-fe37-47d4-8214-96096ffbd79c"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ │\n",
|
|
"│ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.debug.visualize_array_sharding(W3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 70,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "uPCVs-_k3vGW",
|
|
"outputId": "618516e9-9736-4ca0-dd22-09d094ce57a2"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"10.760109\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(loss_jit(params, batch))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 71,
|
|
"metadata": {
|
|
"id": "L9JebLK_3vGW"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"step_size = 1e-5\n",
|
|
"\n",
|
|
"for _ in range(30):\n",
|
|
" grads = gradfun(params, batch)\n",
|
|
" params = [(W - step_size * dW, b - step_size * db)\n",
|
|
" for (W, b), (dW, db) in zip(params, grads)]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 72,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "c9Sbl69e3vGX",
|
|
"outputId": "2ee3d432-7172-46ca-e01a-614e83345808"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"10.752513\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(loss_jit(params, batch))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 73,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 380
|
|
},
|
|
"id": "lkAF0dAb3vGX",
|
|
"outputId": "6c1e317e-cded-4af4-8080-0de835fa4c71"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────┬───────────┐\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>│TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>│\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"└───────────┴───────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────┬───────────┐\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"│ │ │\n",
|
|
"└───────────┴───────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ │\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────────────────────┐\n",
|
|
"│ │\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"├───────────────────────┤\n",
|
|
"│ │\n",
|
|
"│ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
|
|
"│ │\n",
|
|
"│ │\n",
|
|
"└───────────────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params\n",
|
|
"jax.debug.visualize_array_sharding(W2)\n",
|
|
"jax.debug.visualize_array_sharding(W3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 74,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "I1Npor3i3vGX",
|
|
"outputId": "479c4d81-cb0b-40a5-89ba-394c10dc3297"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "3diqi5VRBy6S"
|
|
},
|
|
"source": [
|
|
"## Sharp bits"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "OTfoXNnxFYDJ"
|
|
},
|
|
"source": [
|
|
"### Generating random numbers\n",
|
|
"\n",
|
|
"JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.\n",
|
|
"\n",
|
|
"JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is trivially partitionable: it should require neither cross-device communication nor any redundant computation across devices.\n",
|
|
"\n",
|
|
"However, the existing stable RNG implementation is not automatically partitionable, for historical reasons."
|
|
]
|
|
},
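{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the pure-map idea concrete, here is a minimal sketch that assumes folding an element index into a key with `jax.random.fold_in` is a fair stand-in for a counter. This is not how `jax.random.uniform` is implemented internally; it only illustrates that when each output element depends on nothing but the key and its own index, any element can be computed in isolation, so a sharded version would not need to communicate. The names `demo_key` and `draw` are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"demo_key = jax.random.key(0)\n",
"\n",
"def draw(i):\n",
"  # Each value depends only on the key and its own index i.\n",
"  return jax.random.uniform(jax.random.fold_in(demo_key, i))\n",
"\n",
"# A pure map over indices 0..23, the same length as the array used below.\n",
"values = jax.vmap(draw)(jnp.arange(24))\n",
"\n",
"# Recomputing one element on its own should reproduce the mapped value.\n",
"print(values[5], draw(5))"
]
},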
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ht_zYFVXNrjN"
|
|
},
|
|
"source": [
|
|
"Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 75,
|
|
"metadata": {
|
|
"id": "kwS-aQE_3vGX"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jax.jit\n",
|
|
"def f(key, x):\n",
|
|
" numbers = jax.random.uniform(key, x.shape)\n",
|
|
" return x + numbers\n",
|
|
"\n",
|
|
"key = jax.random.key(42)\n",
|
|
"mesh = Mesh(jax.devices(), 'x')\n",
|
|
"x_sharding = NamedSharding(mesh, P('x'))\n",
|
|
"x = jax.device_put(jnp.arange(24), x_sharding)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ZgSA9x9NLMaP"
|
|
},
|
|
"source": [
|
|
"On a partitioned input, the function `f` produces output that is also partitioned:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 76,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 67
|
|
},
|
|
"id": "Oi97rpLz3vGY",
|
|
"outputId": "9dd63254-a483-4847-c0f5-5a4367bf08e9"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.debug.visualize_array_sharding(f(key, x))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "WnjlWDUYLkp6"
|
|
},
|
|
"source": [
|
|
"But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 77,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "64wIZuSJ3vGY",
|
|
"outputId": "fa166d45-ca9c-457a-be84-bcc9236d0730"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Communicating? True\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"f_exe = f.lower(key, x).compile()\n",
|
|
"print('Communicating?', 'collective-permute' in f_exe.as_text())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "AXp9i8fbL8DD"
|
|
},
|
|
"source": [
|
|
"One way to work around this is to enable the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the `collective-permute` operation no longer appears in the compiled computation:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 78,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "1I7bqxA63vGY",
|
|
"outputId": "756e0a36-ff14-438f-bbd4-3ef03f97a47b"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Communicating? False\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.config.update('jax_threefry_partitionable', True)\n",
|
|
"f_exe = f.lower(key, x).compile()\n",
|
|
"print('Communicating?', 'collective-permute' in f_exe.as_text())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "WV8ZccM5SXOU"
|
|
},
|
|
"source": [
|
|
"The output is still partitioned:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 79,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 67
|
|
},
|
|
"id": "zHPJzdn23vGY",
|
|
"outputId": "3332de0f-4827-4f0b-b9ef-69249b7c6bc6"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
|
|
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span> │\n",
|
|
"└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
|
|
"│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
|
|
"└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.debug.visualize_array_sharding(f(key, x))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "kaK--hPmSPpV"
|
|
},
|
|
"source": [
|
|
"One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may differ from those produced without the flag set_, even though they were generated by the same random key:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 80,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "nBUHBBal3vGY",
|
|
"outputId": "4b9be948-ccab-4a31-a06f-37ec9c7b5235"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Stable:\n",
|
|
"[ 0.72503686 1.8532515 2.983416 3.083253 4.0332246 5.4782867\n",
|
|
" 6.1720605 7.6900277 8.602836 9.810046 10.861367 11.907651\n",
|
|
" 12.330483 13.456195 14.808557 15.960099 16.067581 17.739723\n",
|
|
" 18.335474 19.46401 20.390276 21.116539 22.858128 23.223194 ]\n",
|
|
"\n",
|
|
"Partitionable:\n",
|
|
"[ 0.48870957 1.6797972 2.6162715 3.561016 4.4506445 5.585866\n",
|
|
" 6.0748096 7.775133 8.698959 9.818634 10.350306 11.87282\n",
|
|
" 12.925881 13.86013 14.477554 15.818481 16.711355 17.586697\n",
|
|
" 18.073738 19.777622 20.404566 21.119123 22.026257 23.63918 ]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.config.update('jax_threefry_partitionable', False)\n",
|
|
"print('Stable:')\n",
|
|
"print(f(key, x))\n",
|
|
"print()\n",
|
|
"\n",
|
|
"jax.config.update('jax_threefry_partitionable', True)\n",
|
|
"print('Partitionable:')\n",
|
|
"print(f(key, x))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "8BDPqgOrTMfK"
|
|
},
|
|
"source": [
|
|
"In `jax_threefry_partitionable` mode, the JAX PRNG remains deterministic, but its implementation is new (and under development). The random values generated for a given key will be the same at a given JAX version (or a given commit on the `main` branch), but may vary across releases."
|
|
]
|
|
}
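,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the determinism guarantee described above: with the flag enabled, drawing from the same key twice in one session gives identical values. Because the values may change across releases, it may also be worth recording `jax.__version__` alongside any random samples you need to reproduce later. The variable names here are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"jax.config.update('jax_threefry_partitionable', True)\n",
"\n",
"check_key = jax.random.key(42)\n",
"first = jax.random.uniform(check_key, (24,))\n",
"second = jax.random.uniform(check_key, (24,))\n",
"\n",
"# Same key, same JAX version: the draws should match exactly.\n",
"print('Deterministic?', bool(jnp.array_equal(first, second)))\n",
"print('Generated with JAX', jax.__version__)"
]
}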
|
|
],
|
|
"metadata": {
|
|
"accelerator": "TPU",
|
|
"colab": {
|
|
"gpuType": "V28",
|
|
"provenance": [],
|
|
"toc_visible": true
|
|
},
|
|
"jupytext": {
|
|
"formats": "ipynb,md:myst"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|