Merge pull request #13638 from google:fix

PiperOrigin-RevId: 495087417
jax authors 2022-12-13 11:59:40 -08:00
commit c6eb632f57
3 changed files with 579 additions and 192 deletions


@ -6,7 +6,7 @@ It provides composable transformations of Python+NumPy programs: differentiate,
parallelize, Just-In-Time compile to GPU/TPU, and more.
.. note::
JAX 0.4.0 introduces new parallelism APIs, including breaking changes to :func:`jax.experimental.pjit` and a new unified ``jax.Array`` type.
JAX 0.4.1 introduces new parallelism APIs, including breaking changes to :func:`jax.experimental.pjit` and a new unified ``jax.Array`` type.
Please see the `Distributed arrays and automatic parallelization <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_ tutorial and the :ref:`jax-array-migration`
guide for more information.


@ -17,13 +17,15 @@ kernelspec:
+++ {"id": "pFtQjv4SzHRj"}
**This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.0 and newer.**
**This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.**
See [`jax-array-migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide for migrating existing pre-v0.4.0 codebases to `jax.Array`.
See [`jax-array-migration`](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration) guide for migrating existing pre-v0.4.1 codebases to `jax.Array`.
**The features required by `jax.Array` are not supported by the Colab TPU runtime at this time, but are available on Cloud TPU.**
```{code-cell}
:id: FNxScTfq3vGF
import os
import functools
@ -33,8 +35,6 @@ import numpy as np
import jax
import jax.numpy as jnp
jax.config.update('jax_array', True)
```
+++ {"id": "eyHMwyEfQJcz"}
@ -42,6 +42,8 @@ jax.config.update('jax_array', True)
⚠️ WARNING: notebook requires 8 devices to run.
```{code-cell}
:id: IZMLqOUV3vGG
if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")
```
@ -59,16 +61,23 @@ Before we think step by step, here's a quick example.
First, we'll create a `jax.Array` sharded across multiple devices:
```{code-cell}
:id: Gf2lO4ii3vGG
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
```
```{code-cell}
:id: q-XBTEoy3vGG
# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
```
```{code-cell}
:id: vI39znW93vGH
:outputId: 3b518df8-5c29-4848-acc3-e41df939f30b
# Create an array of random values:
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
@ -82,6 +91,9 @@ Next, we'll apply a computation to it and visualize how the result values are
stored across multiple devices too:
```{code-cell}
:id: -qCnHZl83vGI
:outputId: 9da9c29e-ce88-4425-e1ec-e93e5bcf3106
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
```
@ -92,11 +104,17 @@ The evaluation of the `jnp.sin` application was automatically parallelized
across the devices on which the input values (and output values) are stored:
```{code-cell}
:id: _VTzN0r03vGI
:outputId: c9208010-984b-442b-d105-c8c6a3a010e6
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
```
```{code-cell}
:id: QuzhU1g63vGI
:outputId: d48fc76e-79a7-47b9-d392-b18a1c33c798
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
```
@ -121,11 +139,16 @@ In JAX, `Sharding` objects describe distributed memory layouts. They can be used
For example, here's a value with a single-device `Sharding`:
```{code-cell}
:id: VmoX4SUp3vGJ
import jax
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
```
```{code-cell}
:id: vNRabO2J3vGJ
:outputId: 73db7b6e-c2e7-467d-a0ef-c35e29e582dd
jax.debug.visualize_array_sharding(x)
```
@ -136,6 +159,8 @@ Here, we're using the `jax.debug.visualize_array_sharding` function to show wher
But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:
```{code-cell}
:id: VUIEIzRp3vGK
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))
```
@ -145,6 +170,9 @@ devices = mesh_utils.create_device_mesh((8,))
Then, we create a `PositionalSharding` and use it with `device_put`:
```{code-cell}
:id: jwrWfZeB3vGK
:outputId: e6f126bd-f6bd-48c7-c130-6f02757e3342
from jax.sharding import PositionalSharding
sharding = PositionalSharding(devices)
@ -158,6 +186,9 @@ jax.debug.visualize_array_sharding(x)
Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements:
```{code-cell}
:id: zxWB82Kz3vGK
:outputId: 11384a6b-fabc-4c4c-bcad-a3be51eb0465
sharding
```
@ -166,10 +197,16 @@ sharding
By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:
```{code-cell}
:id: PLsnpSzc3vGL
:outputId: 9f4db733-cafe-46ae-c057-dc31046a6f66
sharding.reshape(8, 1)
```
```{code-cell}
:id: iqKdI4LO3vGL
:outputId: 6aa10fc2-cec4-4401-a0df-343e71646e0a
sharding.reshape(4, 2)
```
@ -185,11 +222,17 @@ def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:
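The hunk header above references an `is_congruent` helper from the notebook whose body isn't shown here. A minimal sketch of such a congruence check, assuming the usual rule that each array dimension must divide evenly across the corresponding sharding dimension (the body below is an assumption; only the signature appears above):

```python
from typing import Sequence

def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:
  # Same rank, and every array dimension splits evenly across the
  # corresponding number of shards.
  return (len(x_shape) == len(sharding_shape) and
          all(d % s == 0 for d, s in zip(x_shape, sharding_shape)))

assert is_congruent((8192, 8192), (4, 2))      # an (8192, 8192) array fits a (4, 2) sharding
assert not is_congruent((8192, 8192), (3, 2))  # 8192 doesn't divide evenly into 3 shards
```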
For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`:
```{code-cell}
:id: SELr4xNi3vGL
:outputId: b2f4acec-0cd3-4829-ca16-cae2e0e8ca60
sharding = sharding.reshape(4, 2)
print(sharding)
```
```{code-cell}
:id: 8IVIsqfX3vGL
:outputId: 033d0e02-a643-4f4c-9d24-9cd8465bc69a
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
```
@ -201,11 +244,17 @@ Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are st
Different `PositionalSharding` shapes result in different distributed layouts (i.e. shardings) of the result:
```{code-cell}
:id: cCjt6QCz3vGM
:outputId: 4ad8a611-596d-424f-b6c5-fc00f1adc306
sharding = sharding.reshape(1, 8)
print(sharding)
```
```{code-cell}
:id: yTK4Nz3u3vGM
:outputId: e445c6bc-4fe3-4e9d-cc9e-d82858f58312
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
```
@ -217,11 +266,17 @@ In some cases, we don't just want to store each slice of `x` in a single device'
With `PositionalSharding`, we can express replication by calling the reducer method `replicate`:
```{code-cell}
:id: _jr6XYKx3vGM
:outputId: 59c8b9a4-b8af-493a-ba8d-da5931e88f93
sharding = sharding.reshape(4, 2)
print(sharding.replicate(axis=0, keepdims=True))
```
```{code-cell}
:id: S5vzjFuH3vGN
:outputId: b6ce2675-7261-4e57-fa8c-b4e87abf7e52
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y)
```
@ -233,11 +288,17 @@ Here the visualization shows that `x` is sharded two ways along its second dimen
The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis, performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed:
```{code-cell}
:id: DR7VV-6e3vGN
:outputId: f879fc2c-5723-4199-b306-295bc1b3681e
print(sharding.replicate(0).shape)
print(sharding.replicate(1).shape)
```
```{code-cell}
:id: agUtVUVx3vGN
:outputId: 0e9789ef-ce52-4ed6-8bd5-c876b95f66e6
y = jax.device_put(x, sharding.replicate(1))
jax.debug.visualize_array_sharding(y)
```
@ -253,14 +314,18 @@ So far we've worked with `PositionalSharding`, but there are alternative ways to
Another convenient way to express sharding is with the `NamedSharding`:
```{code-cell}
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental import mesh_utils
:id: zpB1JxyK3vGN
:outputId: 46d5da37-840c-49d8-8380-a162811bae8a
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils
P = PartitionSpec
devices = mesh_utils.create_device_mesh((4, 2))
mesh = maps.Mesh(devices, axis_names=('a', 'b'))
mesh = Mesh(devices, axis_names=('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
```
@ -270,11 +335,13 @@ jax.debug.visualize_array_sharding(y)
We can define a helper function to make things simpler:
```{code-cell}
:id: 8g0Md2Gd3vGO
devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = maps.Mesh(devices, axis_names=('a', 'b'))
default_mesh = Mesh(devices, axis_names=('a', 'b'))
def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[maps.Mesh] = None,
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
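  # (The tail of this helper is elided by the hunk; given the -> NamedSharding
  # annotation, it presumably just wraps the resolved mesh and spec:)
  return NamedSharding(mesh, pspec)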
@ -282,6 +349,9 @@ def mesh_sharding(
```
```{code-cell}
:id: zp3MfS4Y3vGO
:outputId: 2c2f7201-c2c1-49e5-f8a5-0730c124d89a
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)
```
@ -291,11 +361,17 @@ jax.debug.visualize_array_sharding(y)
Here, we use `P('a', 'b')` to express that the first and second axes of `x` should be sharded over the device mesh axes `'a'` and `'b'`, respectively. We can easily switch to `P('b', 'a')` to shard the axes of `x` over different devices:
```{code-cell}
:id: FigK5Zsa3vGO
:outputId: eca784e8-33fe-4e9b-a41d-21e9ee781a35
y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
```
```{code-cell}
:id: hI-HD0xN3vGO
:outputId: c3e7dc3c-4048-448a-ef0b-50683532fcdc
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
@ -310,11 +386,17 @@ Here, because `P('a', None)` doesn't mention the `Mesh` axis name `'b'`, we get
To shard only over the second axis of `x`, we can use a `None` placeholder in the `PartitionSpec`:
```{code-cell}
:id: EXBExMQC3vGP
:outputId: fe1c8d7e-3345-4438-b9d2-780e7854b4eb
y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
```
```{code-cell}
:id: PjUpG8uz3vGP
:outputId: 64d8224d-15d9-4ad4-d613-f7f85b1dc1af
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)
```
@ -324,6 +406,9 @@ jax.debug.visualize_array_sharding(y)
For a fixed mesh, we can even partition one logical axis of `x` over multiple device mesh axes:
```{code-cell}
:id: fVcPbDUA3vGP
:outputId: 7f524ba5-a6d8-4490-cda9-685ad11416f9
y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
```
@ -343,12 +428,17 @@ With sharded input data, the compiler can give us parallel computation. In parti
For example, the simplest computation is an elementwise one:
```{code-cell}
:id: _EmQwggc3vGQ
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
```
```{code-cell}
:id: LnT0vWjc3vGQ
:outputId: 8089effc-aa4c-49e3-dd19-7064881dbad0
x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding:')
jax.debug.visualize_array_sharding(x)
@ -367,6 +457,9 @@ In other words, even though we wrote the `jnp.sin` computation as if a single ma
We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:
```{code-cell}
:id: Dq043GkP3vGQ
:outputId: 350219a8-1e4a-4404-fe14-50f97ea3e7ba
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
print('lhs sharding:')
@ -386,20 +479,32 @@ Here the compiler chose the output sharding so that it could maximally paralleli
How can we be sure it's actually running in parallel? We can do a simple timing experiment:
```{code-cell}
:id: QjQ5u8qh3vGQ
:outputId: bd29edcd-b87c-486e-c568-906f06ae16be
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
```
```{code-cell}
:id: 8tn8lOj73vGR
:outputId: 5809b3c8-7333-4cd3-db97-a7aede943dce
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))
```
```{code-cell}
:id: D7PpZwhR3vGR
:outputId: 4f0bd43d-0b32-4089-d3da-c8f1449e3526
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
```
```{code-cell}
:id: rgo_yVHF3vGR
:outputId: 97f19052-f1c9-4d30-f453-07b3a7208aa9
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
```
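The copy example below uses a value `w` whose definition falls in an elided part of the diff; presumably it is the sharded result of the matmul above. A small sketch under that assumption:

```python
# Assumption: `w` is the sharded matmul result timed above.
w = jnp.dot(y, z)
jax.debug.visualize_array_sharding(w)
```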
@ -408,6 +513,9 @@ np.allclose(jnp.dot(x_single, x_single),
Even copying a sharded `Array` produces a result with the sharding of the input:
```{code-cell}
:id: f1Zw-2lH3vGR
:outputId: a796bed4-07b0-497d-8fd8-31a22ab9762e
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
```
@ -424,6 +532,8 @@ But what if two arguments to a computation are explicitly placed on different se
In these ambiguous cases, an error is raised:
```{code-cell}
:id: 1vAkZAOY3vGR
import textwrap
from termcolor import colored
@ -433,6 +543,9 @@ def print_exception(e):
```
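The body of `print_exception` is hidden by the hunk above; only its imports are visible. A plausible sketch consistent with those imports (the exact formatting is an assumption):

```python
import textwrap
from termcolor import colored

def print_exception(e):
  # Show the exception type in red and wrap the message for readability.
  name = colored(f'{type(e).__name__}', 'red')
  print(textwrap.fill(f'{name}: {e}'))
```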
```{code-cell}
:id: DHh0N3vn3vGS
:outputId: e7741882-0ebf-4237-e5d1-e48c9b9c178f
sharding1 = PositionalSharding(jax.devices()[:4])
sharding2 = PositionalSharding(jax.devices()[4:])
@ -443,6 +556,9 @@ except ValueError as e: print_exception(e)
```
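The middle of the cell above is elided; only the final `except` line is visible in the hunk header inside it. A self-contained sketch of the kind of device mismatch that triggers the error, assuming 8 available devices:

```python
import jax
from jax.sharding import PositionalSharding

x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))

# Commit the two operands to disjoint sets of four devices each.
y = jax.device_put(x, PositionalSharding(jax.devices()[:4]).reshape(2, 2))
z = jax.device_put(x, PositionalSharding(jax.devices()[4:]).reshape(2, 2))

try:
  y + z  # the operands' committed devices don't match
except ValueError as e:
  print(f'{type(e).__name__}: {e}')
```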
```{code-cell}
:id: Im7DkoOl3vGS
:outputId: 3adfe1cb-db52-4a9d-e98e-62c6455c3100
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
@ -465,6 +581,9 @@ Unlike committed arrays, uncommitted arrays can be moved and resharded automatic
For example, the outputs of `jnp.zeros`, `jnp.arange`, and `jnp.array` are uncommitted:
```{code-cell}
:id: _QvtKL8r3vGS
:outputId: e0078805-bdfd-436e-f94f-7cd256d2574f
y = jax.device_put(x, sharding1.reshape(4, 2))
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
@ -480,20 +599,21 @@ print('no error!')
While the compiler will attempt to decide how a function's intermediate values and outputs should be sharded, we can also give it hints using `jax.lax.with_sharding_constraint`. Using `jax.lax.with_sharding_constraint` is much like `jax.device_put`, except we use it inside staged-out (i.e. `jit`-decorated) functions:
```{code-cell}
# TODO(mattjj,yashkatariya): remove cell when with_sharding_constraint is in jax.lax
jax.lax.with_sharding_constraint = jax.experimental.pjit.with_sharding_constraint
```
:id: jniSFm5V3vGT
```{code-cell}
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
```
```{code-cell}
:id: Q1wuDp-L3vGT
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))
```
```{code-cell}
:id: rqEDj0wB3vGT
@jax.jit
def f(x):
  x = x + 1
@ -502,12 +622,17 @@ def f(x):
```
```{code-cell}
:id: zYFS-n4r3vGT
:outputId: d23a7938-cb7d-44b4-b9c7-83edf1d1145e
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
```
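The bodies of the two `f` definitions in this section are clipped by the diff hunks. Here is a self-contained sketch of the pattern they illustrate, assuming the 8-device sharding built earlier; the specific constraint chosen below (fully replicating the intermediate) is an assumption, not the notebook's exact choice:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))).reshape(4, 2)

@jax.jit
def f(x):
  x = x + 1
  # Hint to the compiler: lay this intermediate out fully replicated
  # across all 8 devices (a (1, 1) sharding).
  y = jax.lax.with_sharding_constraint(x, sharding.replicate())
  return y

x = jax.device_put(jax.random.normal(jax.random.PRNGKey(0), (8192, 8192)), sharding)
jax.debug.visualize_array_sharding(f(x))
```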
```{code-cell}
:id: 8g_2Y8wp3vGT
@jax.jit
def f(x):
  x = x + 1
@ -516,6 +641,9 @@ def f(x):
```
```{code-cell}
:id: AiRFtVsR3vGT
:outputId: f3e28a70-46cf-46fb-c801-82f0ddb447e4
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
@ -540,11 +668,15 @@ It's often a good practice to annotate the outputs of computations, for example
We can use `jax.device_put` and `jax.jit`'s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:
```{code-cell}
:id: mEKF3zIF3vGU
import jax
import jax.numpy as jnp
```
```{code-cell}
:id: Mocs3oGe3vGU
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
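    # (The rest of this cell is elided by the diff; a plausible sketch of the
    # remainder. The activation and loss below are assumptions.)
    inputs = jnp.maximum(outputs, 0)   # assumed ReLU nonlinearity
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))  # assumed squared error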
@ -558,11 +690,15 @@ def loss(params, batch):
```
```{code-cell}
:id: glBB8tzW3vGU
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
```
```{code-cell}
:id: R0x62AIa3vGU
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
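  # (The rest of this cell is elided by the diff; a plausible sketch of the
  # remainder. The layer sizes and batch size below are assumptions.)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init_model(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
  inputs_key, targets_key = jax.random.split(key)
  inputs = jax.random.normal(inputs_key, (batch_size, layer_sizes[0]))
  targets = jax.random.normal(targets_key, (batch_size, layer_sizes[-1]))
  return params, (inputs, targets)

layer_sizes = [784, 1024, 1024, 1024, 10]  # assumed layer sizes
batch_size = 8192                          # assumed batch size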
@ -590,19 +726,29 @@ params, batch = init_model(jax.random.PRNGKey(0), layer_sizes, batch_size)
### 8-way batch data parallelism
```{code-cell}
:id: _Q5NbdOn3vGV
sharding = PositionalSharding(jax.devices()).reshape(8, 1)
```
```{code-cell}
:id: 3KC6ieEe3vGV
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate())
```
```{code-cell}
:id: MUb-QE2b3vGV
:outputId: 1f831ea5-5a30-49ad-8195-977ff7ed476a
loss_jit(params, batch)
```
```{code-cell}
:id: HUkw0u413vGV
:outputId: dfa2599c-9440-4657-9035-0dc3bbf625e1
step_size = 1e-5
for _ in range(30):
@ -614,15 +760,23 @@ print(loss_jit(params, batch))
```
```{code-cell}
:id: paCw6Zaj3vGV
:outputId: 8ab1c32c-f2b1-465c-df71-f5a599e7f19e
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
```
```{code-cell}
:id: BF86UWpg3vGV
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
```
```{code-cell}
:id: Z1wgUKXk3vGV
:outputId: 74df8892-c349-41dc-cb1b-e0843ec5c994
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
```
@ -631,16 +785,23 @@ params_single = jax.device_put(params, jax.devices()[0])
### 4-way batch data parallelism and 2-way model tensor parallelism
```{code-cell}
:id: N5-zzgW03vGW
sharding = sharding.reshape(4, 2)
```
```{code-cell}
:id: sgIWCjJK3vGW
:outputId: b2fdc556-05cc-4e68-fa04-48643d194dee
batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
```
```{code-cell}
:id: BqCjYCgg3vGW
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
W1 = jax.device_put(W1, sharding.replicate())
@ -659,18 +820,29 @@ params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
```
```{code-cell}
:id: _lSJ63sh3vGW
:outputId: 5b37aa8b-3226-4805-8282-876e8d06edda
jax.debug.visualize_array_sharding(W2)
```
```{code-cell}
:id: fxkfWYkk3vGW
:outputId: 8a1063c3-540b-47c1-d990-a6845da861f7
jax.debug.visualize_array_sharding(W3)
```
```{code-cell}
:id: uPCVs-_k3vGW
:outputId: de01cdfc-36cb-4823-c692-22c692ef4220
print(loss_jit(params, batch))
```
```{code-cell}
:id: L9JebLK_3vGW
step_size = 1e-5
for _ in range(30):
@ -680,16 +852,25 @@ for _ in range(30):
```
```{code-cell}
:id: c9Sbl69e3vGX
:outputId: 8272c5fa-e59f-4953-c2d5-658c42a28712
print(loss_jit(params, batch))
```
```{code-cell}
:id: lkAF0dAb3vGX
:outputId: acf0df31-c5e1-4683-b73f-b0cd1b0929f8
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
```
```{code-cell}
:id: I1Npor3i3vGX
:outputId: 4099f6dd-7b46-4123-c1cb-5173c3d3278e
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
```
@ -712,6 +893,8 @@ However, the existing stable RNG implementation is not automatically partitionab
Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise:
```{code-cell}
:id: kwS-aQE_3vGX
@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
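  # (Continuation elided by the diff; a plausible sketch of the rest.
  # The key value and sharding choice below are assumptions.)
  return x + numbers

key = jax.random.PRNGKey(42)
x_sharding = PositionalSharding(jax.devices())  # one shard per device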
@ -727,6 +910,9 @@ x = jax.device_put(jnp.arange(24), x_sharding)
On a partitioned input, the function `f` produces output that is also partitioned:
```{code-cell}
:id: Oi97rpLz3vGY
:outputId: 204a7e8d-dc88-4b77-b7e3-0e72f306c5d3
jax.debug.visualize_array_sharding(f(key, x))
```
@ -735,6 +921,9 @@ jax.debug.visualize_array_sharding(f(key, x))
But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication:
```{code-cell}
:id: 64wIZuSJ3vGY
:outputId: 1054fe99-0476-44ec-9693-b0d8f98bf6a8
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
```
@ -744,6 +933,9 @@ print('Communicating?', 'collective-permute' in f_exe.as_text())
One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the "collective permute" operation is now gone from the compiled computation:
```{code-cell}
:id: 1I7bqxA63vGY
:outputId: ec4c579d-f446-4b48-ceda-785c09ba299b
jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
@ -754,6 +946,9 @@ print('Communicating?', 'collective-permute' in f_exe.as_text())
The output is still partitioned:
```{code-cell}
:id: zHPJzdn23vGY
:outputId: a8904d20-4d04-4f59-8eae-281e47d29246
jax.debug.visualize_array_sharding(f(key, x))
```
@ -762,6 +957,9 @@ jax.debug.visualize_array_sharding(f(key, x))
One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different from those produced without the flag set_, even though they were generated by the same random key:
```{code-cell}
:id: nBUHBBal3vGY
:outputId: f194c213-0688-4b7a-ffb8-c4453b82b1f1
jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
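# (The cell presumably continues by flipping the flag back on and printing the
# values again for comparison; a sketch of that continuation:)
jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x))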