---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.4
kernelspec:
  display_name: Python 3
  name: python3
---

(sharded-computation)=

# Introduction to parallel programming

This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.

The tutorial covers three modes of parallel computation:

- Automatic parallelism via {func}`jax.jit`: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel").
- Semi-automated parallelism using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`
- Fully manual parallelism using {func}`jax.experimental.shard_map.shard_map`: `shard_map` enables per-device code and explicit communication collectives

Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.

If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: Runtime > Change runtime type > Hardware accelerator > TPU v2 (which provides eight devices to work with).

```{code-cell}
:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456

import jax
jax.devices()
```

## Key concept: Data sharding

Key to all of the distributed computation approaches below is the concept of data sharding, which describes how data is laid out on the available devices.

How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {class}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.

In the simplest cases, arrays are sharded on a single device, as demonstrated below:

```{code-cell}
:outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a

import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
```

```{code-cell}
:outputId: 536f773a-7ef4-4526-c58b-ab4d486bf5a1

arr.sharding
```

For a more visual representation of the storage layout, the {mod}`jax.debug` module provides some helpers to visualize the sharding of an array. For example, {func}`jax.debug.visualize_array_sharding` displays how the array is stored in the memory of a single device:

```{code-cell}
:outputId: 74a793e9-b13b-4d07-d8ec-7e25c547036d

jax.debug.visualize_array_sharding(arr)
```

To create an array with a non-trivial sharding, you can define a {mod}`jax.sharding` specification for the array and pass this to {func}`jax.device_put`.

Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimensional grid of devices with named axes, where {class}`jax.sharding.Mesh` allows for precise device placement:

```{code-cell}
:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5

from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
```

Passing this `Sharding` object to {func}`jax.device_put`, you can obtain a sharded array:

```{code-cell}
:outputId: c8ceedba-05ca-4156-e6e4-1e98bb664a66

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```

+++ {"id": "UEObolTqw4pp"}

The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.
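If you are curious which devices the mesh placed where, you can inspect the mesh directly. The short sketch below (not part of the original notebook, reusing the `mesh` defined above) prints the device grid and the flattened device ids:

```python
# `mesh.devices` is an ndarray of jax.Device objects arranged along the
# named mesh axes; its order explains the layout shown above.
print(mesh.devices)
print([d.id for d in mesh.devices.flatten()])
```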

The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place the data in host memory.

To create a new sharding that differs from an existing sharding only in its memory kind, use the `with_memory_kind` method on the existing sharding.

```{code-cell}
---
colab:
  base_uri: https://localhost:8080/
id: aKNeOHTJnqmS
outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2
---
s_host = jax.sharding.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
s_dev = s_host.with_memory_kind('device')
arr_host = jax.device_put(arr, s_host)
arr_dev = jax.device_put(arr, s_dev)
print(arr_host.sharding.memory_kind)
print(arr_dev.sharding.memory_kind)
```

+++ {"id": "jDHYnVqHwaST"}

## 1. Automatic parallelism via `jit`

Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.

The XLA compiler behind `jit` includes heuristics for optimizing computations across multiple devices. In the simplest of cases, those heuristics boil down to *computation follows data*.

To demonstrate how auto-parallelization works in JAX, below is an example that uses a {func}`jax.jit`-decorated staged-out function: it's a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:

```{code-cell}
:outputId: de46f86a-6907-49c8-f36c-ed835e78bc3d

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding)
```

As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.

Here, you sum along the leading axis of `x`, and visualize how the result values are stored across multiple devices (with {func}`jax.debug.visualize_array_sharding`):

```{code-cell}
:outputId: 90c3b997-3653-4a7b-c8ff-12a270f11d02

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
```

+++ {"id": "Q4N5mrr9i_ki"}

The result is partially replicated: that is, the first two elements of the array are replicated on devices 0 and 6, the next two on devices 1 and 7, and so on.
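To convince yourself that partial replication does not change the numbers, a quick check like the following (a sketch, reusing `arr_sharded` and `result` from the cells above) compares the sharded result against the same sum computed on unsharded data, and prints the output's sharding:

```python
import numpy as np

# The partially replicated result should match a plain single-device sum.
expected = np.asarray(arr_sharded).sum(axis=0)
print(np.allclose(np.asarray(result), expected))  # True
print(result.sharding)  # how the compiler chose to lay out the output
```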

### 1.1 Sharding transformation between memory types

The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.

#### Example 1: Pinned host to device memory

In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory.

```{code-cell}
---
colab:
  base_uri: https://localhost:8080/
id: PXu3MhafyRHo
outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b
---
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_dev = f(arr_host)
print(out_dev)
print(out_dev.sharding.memory_kind)
```

+++ {"id": "LuYFqpcBySiX"}

#### Example 2: Device to pinned_host memory

In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory.

```{code-cell}
---
colab:
  base_uri: https://localhost:8080/
id: qLsgNlKfybRw
outputId: a16448b9-7e39-408f-b200-505f65ad4464
---
g = jax.jit(lambda x: x, out_shardings=s_host)
out_host = g(arr_dev)
print(out_host)
print(out_host.sharding.memory_kind)
```

+++ {"id": "7BGD31-owaSU"}

## 2. Semi-automated sharding with constraints

If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put`) together with {func}`jax.jit` for more control over how the compiler distributes intermediate values and outputs.

For example, suppose that within f_contract above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:

```{code-cell}
:outputId: 8468f5c6-76ca-4367-c9f2-93c723687cfd

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
```

This gives you a function with the particular output sharding you'd like.
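You can verify that the constraint was honored by comparing the output's sharding against the one requested inside the function; a small sketch, assuming the `mesh` and `result` from the previous cell are still in scope:

```python
# The output sharding should be equivalent to the requested NamedSharding.
requested = jax.sharding.NamedSharding(mesh, P('x'))
print(result.sharding)
print(result.sharding.is_equivalent_to(requested, result.ndim))  # expect True
```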

## 3. Manual parallelism with `shard_map`

In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.

`shard_map` works by mapping a function across a particular mesh of devices (`shard_map` maps over shards). In the example below:

- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the `axis_names` parameter for logical and physical axis names.
- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.

Note: {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it.

```{code-cell}
:outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2

from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
```
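As the note above mentions, `shard_map`-transformed code also composes with `jit`; a minimal sketch using the function just defined:

```python
# Wrapping the shard_map-ped function in jax.jit (or calling it from a
# jitted function) works as well.
jax.jit(f_elementwise_sharded)(arr)
```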

The function you write only "sees" a single batch of the data, which you can check by printing the device local shape:

```{code-cell}
:outputId: 99a3dc6e-154a-4ef6-8eaa-3dd0b68fb1da

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```

Because your function only "sees" the device-local part of the data, aggregation-like functions require some extra thought.

For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like:

```{code-cell}
:outputId: 1e9a45f5-5418-4246-c75b-f9bc6dcbbe72

def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
```

Your function `f` operates separately on each shard, so the result is the per-shard sums rather than a single global sum.

If you want to sum across shards, you need to explicitly request it using collective operations like {func}`jax.lax.psum`:

```{code-cell}
:outputId: 4fd29e80-4fee-42b7-ff80-29f9887ab38d

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
```

Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`).
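As a quick sanity check (a sketch reusing `f`, `mesh`, and `x` from the cells above), the collective sum agrees with an ordinary sum over the full array:

```python
# The psum across all shards equals a plain sum of the unsharded array.
total = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
print(total, jnp.sum(x))  # both are 496 for x = jnp.arange(32)
```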

## Comparing the three approaches

With these concepts fresh in mind, let's compare the three approaches for a simple neural network layer.

Start by defining your canonical function like this:

```{code-cell}
:id: 1TdhfTsoiqS1

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
```

```{code-cell}
:outputId: f3007fe4-f6f3-454e-e7c5-3638de484c0a

import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
```

You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.

If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:

```{code-cell}
:outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5

mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias)
```

Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:

```{code-cell}
:outputId: bb63e8da-ff4f-4e95-f083-10584882daf4

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)  # pass in unsharded inputs
```

Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product:

```{code-cell}
:outputId: 568d1c85-39a7-4dba-f09a-0e4f7c2ea918

from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
```
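All three versions compute the same layer, so a short check (a sketch reusing the names defined in the cells above) should show that they agree numerically, up to floating-point reduction order:

```python
import numpy as np

out_jit = layer(x_sharded, weights_sharded, bias)   # automatic parallelism
out_constrained = layer_auto(x, weights, bias)      # sharding constraint
out_manual = layer_sharded(x, weights, bias)        # shard_map + psum
print(np.allclose(out_jit, out_constrained), np.allclose(out_jit, out_manual))
```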

## Next steps

This tutorial serves as a brief introduction to sharded and parallel computation in JAX.

To learn about each SPMD method in-depth, check out these docs:

- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`
- {doc}`../notebooks/shard_map`