shard_map (shmap) prototype and JEP

Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
This commit is contained in:
Matthew Johnson 2022-11-04 15:29:10 -07:00
parent fcb9dfb080
commit ff1e9b3973
12 changed files with 1971 additions and 15 deletions

640
docs/jep/14273-shard-map.md Normal file
View File

@ -0,0 +1,640 @@
# `shmap` (`shard_map`) for simple per-device code
*sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@*
*January 2023*
## Motivation
JAX supports two schools of thought for multi-device programming:
1. **Compiler, take the wheel!** Let the compiler automatically partition bulk
array functions over devices.
2. **Just let me write what I mean, damnit!** Give me per-device code and
explicit communication collectives.
We need great APIs for both, and rather than being mutually exclusive
alternatives, they need to compose with each other.
With `pjit` (now just `jit`) we have [a next-gen
API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
for the first school. But we haven't quite leveled-up the second school. `pmap`
follows the second school, but over time we found it has [fatal
flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws,
but it doesn't quite give us per-device shapes, and it includes several other
big ideas too. Meanwhile, new demands for per-device explicit-collectives
programming have emerged, like in [Efficiently Scaling Transformer
Inference](https://arxiv.org/abs/2211.05102).
We can level-up the second school with `shmap`. `shmap` is:
* a simple multi-device parallelism API which lets us write per-device code with
explicit collectives, where logical shapes match per-device physical buffer
shapes and collectives correspond exactly to cross-device communication;
* a specialization of `xmap` with scaled-back features and a few tweaks;
* a fairly direct surfacing of the XLA SPMD Partitioner's 'manual' mode;
* a fun-to-say Seussian name which could stand for `shard_map`,
`shpecialized_xmap`, `sholto_map`, or `sharad_map`.
**For `pjit` users**, `shmap` is a complementary tool. It can be used inside a
`pjit` computation to drop temporarily into a "manual collectives" mode, like an
escape hatch from the compiler's automatic partitioning. That way, users get the
convenience and familiar just-NumPy programming model of `pjit` for most of their
code, along with the ability to hand-optimize collective communication with
`shmap` wherever it's needed. It's the best of both worlds!
**For `pmap` users**, `shmap` is a strict upgrade. It's more expressive,
performant, and composable with other JAX APIs, without making basic batch data
parallelism any harder.
For more on practical use, you can jump to [When should you use `shmap` and when
should you use `pjit`?](#when-should-you-use-shmap-and-when-should-you-use-pjit).
If you're wondering why we need a new thing at all, or what
the problems with `pmap` are, jump to [Why don't `pmap` or `xmap` already solve
this?](#why-dont-pmap-or-xmap-already-solve-this).
Or keep reading the next section to see some `shmap` examples and the API spec.
## So, let's see `shmap`!
### TL;DR example (with a more detailed explanation to follow)
Sho shick:
```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map as shmap

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'y')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]
```
Notice:
* no nesting needed (or `axis_index_groups`) for multiple axes of parallelism,
unlike `pmap`;
* no reshapes in the caller, unlike `pmap` and hard-`xmap`, and logical shapes
correspond to per-device physical shapes, unlike (non-hard) `xmap`;
* precise device placement control by using `mesh`, unlike `pmap`;
* there's only one set of axis names for logical and physical, unlike `xmap`;
* the result is a `jax.Array` which could be efficiently passed to a `pjit`,
unlike `pmap`;
* this same code works efficiently inside a `pjit`/`jit`, unlike `pmap`;
* this code works eagerly, so we can `pdb` in the middle and print values,
unlike `xmap`'s current implementation (though by design `xmap` without the
sequential schedule can in principle work eagerly too).
Here's another matmul variant with a fully sharded result:
```python
@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', 'y'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
  c_block = jax.lax.psum_scatter(c_partialsum, 'y', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)
```
### Slow down, start with the basics!
#### Rank-reducing vs rank-preserving maps over array axes
We can think of `pmap` (and `vmap` and `xmap`) as unstacking each array input
along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body
function to each piece, and stacking the results back together, at least when
collectives aren't involved:
```python
pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])
```
For example, if `xs` had shape `f32[8,5]` then each `x` has shape `f32[5]`, and
if each `f(x)` has shape `f32[3,7]` then the final stacked result `pmap(f)(xs)`
has shape `f32[8,3,7]`. That is, each application of the body function `f` takes
as argument inputs with one fewer axis than the corresponding argument to
`pmap(f)`. We can say these are *rank-reducing maps* with unstacking/stacking of
inputs/outputs.
The number of logical applications of `f` is determined by the size of the input
axis being mapped over: for example, if we map over an input axis of size 8,
semantically we get 8 logical applications of the function, which for pmap
always correspond to 8 devices physically computing them.
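To make the shape bookkeeping concrete, here's a quick sketch using `jax.vmap` (which shares the rank-reducing, unstack/stack semantics on a single device, so it runs anywhere):
```python
import jax
import jax.numpy as jnp

def f(x):                                # x: f32[5], one fewer axis than xs
  return jnp.zeros((3, 7)) + x.sum()     # f(x): f32[3,7]

xs = jnp.ones((8, 5))                    # xs: f32[8,5]
ys = jax.vmap(f, in_axes=0, out_axes=0)(xs)
assert ys.shape == (8, 3, 7)             # 8 applications, stacked along a new axis
```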
In contrast, `shmap` does not have this rank-reducing behavior. Instead, we can
think of it as slicing (or "unconcatenating") along input axes into blocks,
applying the body function, and concatenating the results back together (again
when collectives aren't involved):
```python
devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
```
Recall that `jnp.split` slices its input into equally-sized blocks with the same
rank, so that if in the above example `y` has shape `f32[8,5]` then each `y_blk`
has shape `f32[2,5]`, and if each `f(y_blk)` has shape `f32[3,7]` then the final
concatenated result `shmap(f, ...)(y)` has shape `f32[12,7]`. So `shmap`
(`shard_map`) maps over shards, or blocks, of its inputs. We can say it's a
*rank-preserving map* with unconcatenating/concatenating of its inputs/outputs.
The number of logical applications of `f` is determined by the mesh size, not by
any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4
devices) then semantically we get 4 logical applications of the function,
corresponding to the 4 devices physically computing them.
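Here's the corresponding shape check for the rank-preserving case, written as a sketch against the pure-`jnp` reference semantics above (no devices or collectives needed):
```python
import jax.numpy as jnp

def f(y_blk):                             # y_blk: f32[2,5], same rank as y
  return jnp.zeros((3, 7)) + y_blk.sum()  # f(y_blk): f32[3,7]

y = jnp.ones((8, 5))                      # y: f32[8,5]
out = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
assert out.shape == (12, 7)               # 4 blocks of shape (3, 7), concatenated
```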
#### Controlling how each input is split (unconcatenated) and tiled with `in_specs`
Each of the `in_specs` identifies some of the corresponding input array's axes
with mesh axes by name using `PartitionSpec`s, representing how to split (or
unconcatenate) that input into the blocks to which the body function is applied.
That identification determines the shard sizes; when an input axis is identified
with a mesh axis, the input is split (unconcatenated) along that logical axis
into a number of pieces equal to the corresponding mesh axis size. (It's an
error if the corresponding mesh axis size does not evenly divide the input array
axis size.) If an input's pspec does not mention a mesh axis name, then there's
no splitting over that mesh axis. For example:
```python
devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))
@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape)
return x_block
x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1) # prints (3,12)
```
Here, because the input pspec did not mention the mesh axis name `'j'`, no input
array axis is split over that mesh axis; similarly, because the second axis of
the input array is not identified with (and hence split over) any mesh axis,
application of `f1` gets a full view of the input along that axis.
When a mesh axis is not mentioned in an input pspec, we can always rewrite to a
less efficient program where all mesh axes are mentioned but the caller performs
a `jnp.tile`, for example:
```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, m.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
```
In other words, because each input pspec can mention each mesh axis name zero or
one times, rather than having to mention each name exactly once, we can say that
in addition to the `jnp.split` built into its input, `shard_map` also has a
`jnp.tile` built into its input, at least logically (though the tiling may not
need to be carried out physically, depending on the arguments' physical sharding
layout). The tiling to use is not unique; we could also have tiled along the
first axis, and used the pspec `P(('j', 'i'), None)`.
Physical data movement is possible on inputs, as each device needs to have a
copy of the appropriate data.
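To spell out the `f1` example's reference semantics, here's a sketch of the splitting implied by that `in_specs` value, using only `jnp.split` (the unmentioned mesh axis `'j'` means the second array axis isn't split at all, and each block is effectively replicated over `'j'`):
```python
import numpy as np
import jax.numpy as jnp

x1 = np.arange(12 * 12).reshape(12, 12)
# in_specs=P('i', None) on a mesh with axis sizes {'i': 4, 'j': 2}:
# split axis 0 into 4 pieces over 'i'; leave axis 1 whole.
blocks = jnp.split(x1, 4, axis=0)
assert all(blk.shape == (3, 12) for blk in blocks)  # matches the printed (3, 12)
```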
#### Controlling how each output is assembled by concatenation, block transposition, and untiling using `out_specs`
Analogously to the input side, each of the `out_specs` identifies some of the
corresponding output array's axes with mesh axes by name, representing how the
output blocks (one for each application of the body function, or equivalently
one for each physical device) should be assembled back together to form the
final output value. For example, in both the `f1` and `f2` examples above the
`out_specs` indicate we should form the final output by concatenating together
the block results along both axes, resulting in both cases in an array `y` of shape
`(12,24)`. (It's an error if an output shape of the body function, i.e. an
output block shape, has a rank too small for the concatenation described by the
corresponding output pspec.)
When a mesh axis name is not mentioned in an output pspec, it represents an
*un-tiling*: when the user writes an output pspec which does not mention one of
the mesh axis names, they promise that the output blocks are equal along that
mesh axis, and so only one block along that axis is used in the output (rather
than concatenating all the blocks together along that mesh axis). For example,
using the same mesh as above:
```python
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
```
Notice that the body function closing over an array value is equivalent to
passing it as an argument with a corresponding input pspec of `P(None, None)`. As
another example, following the other examples above more closely:
```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape) # (12,6)
```
Notice that the result has a second axis size of 6, half the size of the input's
second axis. In this case, the un-tile expressed by not mentioning the mesh axis
name `'j'` in the output pspec was safe because of the collective `psum`, which
ensures each output block is equal along the corresponding mesh axis. Here are
two more examples where we vary which mesh axes are mentioned in the output
pspec:
```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
```
On the physical side, not mentioning a mesh axis name in an output pspec
assembles an `Array` from the output device buffers with replicated layout along
that mesh axis.
There is no runtime check that the output blocks are actually equal along a mesh
axis to be un-tiled along, or equivalently that the corresponding physical
buffers have equal values and thus can be interpreted as a replicated layout for
a single logical array. But we can provide a static check mechanism which raises
an error on all potentially-incorrect programs.
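For example, with this prototype's default `check_rep=True`, an un-tiling `out_specs` that isn't justified by a collective is rejected when the function is applied (a sketch, reusing the mesh `m` from the examples above; the exact error message may differ):
```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f_bad(x_block):
  return x_block   # no psum over 'j', so the blocks need not be equal along 'j'

try:
  f_bad(np.arange(12 * 12).reshape(12, 12))
except ValueError as e:
  print(e)         # out_specs requires replication over 'j' that can't be inferred
```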
Because the `out_specs` can mention mesh axis names zero or one times, and
because they can be mentioned in any order, we can say that in addition to the
`jnp.concatenate` built into its output, `shard_map` also has both an untile and
a block transpose built into its output.
Physical data movement is not possible on outputs, no matter the output pspec.
Instead, `out_specs` just encodes how to assemble the block outputs into
`Array`s, or physically how to interpret the buffers across devices as the
physical layout of a single logical `Array`.
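As a small illustration of the built-in block transpose, mentioning the mesh axes in the opposite order in `out_specs` changes which array axis each group of blocks is concatenated along (again a sketch, reusing the mesh `m` from above):
```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('j', 'i'))
def f6(x_block):
  return x_block   # each block has shape (3, 6)

x = np.arange(12 * 12).reshape(12, 12)
y6 = f6(x)
print(y6.shape)    # (6, 24): axis 0 assembled over 'j' (size 2), axis 1 over 'i' (size 4)
```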
### API Specification
```python
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shmap(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
) -> Callable:
...
```
where:
* `mesh` encodes devices arranged in an array and with associated axis names,
just like it does for `xmap` and for `sharding.NamedSharding`;
* `in_specs` and `out_specs` are `PartitionSpec`s which can
[affinely](https://en.wikipedia.org/wiki/Substructural_type_system) mention
axis names from `mesh` (not separate logical names as in `xmap`) to express
slicing/unconcatenation and concatenation of inputs and outputs, respectively
(not unstacking and stacking like `pmap` and `xmap` do), with unmentioned
names corresponding to replication and untiling
(assert-replicated-so-give-me-one-copy), respectively;
* the shapes of the arguments passed to `f` have the same ranks as the arguments
passed to `shard_map`-of-`f` (unlike `pmap` and `xmap` where the ranks are
reduced), and the shape of an argument to `f` is computed from the shape
`shape` of the corresponding argument to `shard_map`-of-`f` and the
corresponding `PartitionSpec` spec as roughly
`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`;
* the body of `f` can apply collectives using names from `mesh`.
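As a check of the block-shape rule described above, here's a sketch of a hypothetical helper (not part of the API) that computes the block shape `f` sees, assuming an 8-device runtime as in the TL;DR example:
```python
import math
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P

def block_shape(shape, spec, mesh):
  """Per-block shape for a global shape, a PartitionSpec, and a Mesh."""
  def factor(names):
    if names is None:
      return 1
    names = names if isinstance(names, tuple) else (names,)
    return math.prod(mesh.shape[n] for n in names)
  spec = tuple(spec) + (None,) * (len(shape) - len(spec))
  return tuple(sz // factor(n) for sz, n in zip(shape, spec))

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
assert block_shape((8, 16), P('x', 'y'), mesh) == (2, 8)     # a_block in the TL;DR
assert block_shape((16, 32), P('y', None), mesh) == (8, 32)  # b_block in the TL;DR
```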
`shmap` is eager by default, meaning that we dispatch computations
primitive-by-primitive, so that the user can employ Python control flow on fully
replicated values and interactive `pdb` debugging to print any values. To stage
out and end-to-end compile a `shmap`ped function, just put a `jit` around it. A
consequence is that `shmap` doesn't have its own dispatch and compilation paths
like `xmap` and `pmap` currently do; it's just the `jit` path.
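For example, staging out the TL;DR `matmul_basic` from above is just (a sketch reusing those definitions):
```python
matmul_basic_jit = jax.jit(matmul_basic)  # stage out and compile end-to-end
c = matmul_basic_jit(a, b)                # same semantics, one compiled executable
```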
When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to MHLO
is trivial: it just involves switching into 'manual SPMD mode' on the inputs,
and switching back on the outputs. (We don't currently plan to support
partially-manual-partially-automatic modes.)
The interaction with effects is the same as with `pmap`.
The interaction with autodiff is also just like `pmap` (rather than attempting
the new semantics that `xmap` did, corresponding to having unmapped
intermediates and hence `grad`'s `reduce_axes` as well as making `psum`
transpose to `pbroadcast` rather than `psum`). But it thus inherits an unsolved
problem from `pmap`: in some cases, instead of transposing `psum` to `psum`, and
thus performing a backward pass `psum` corresponding to the forward pass `psum`,
it can be beneficial to move the backward pass `psum` to elsewhere in the
backward pass, exploiting linearity. Many advanced `pmap` users addressed this
challenge by using `custom_vjp` to implement `psum_idrev` and `id_psumrev`
functions, but since it's easy to accidentally leave those imbalanced, that
technique is a foot-cannon. We have some ideas on how to provide this
functionality in a safer way.
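For reference, here is a minimal sketch of the `psum_idrev` half of that `custom_vjp` workaround (hypothetical code, meant to be called inside a `shmap` or `pmap` body; the foot-cannon is forgetting the matching `id_psumrev` elsewhere in the model):
```python
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def psum_idrev(x, axis_name):
  return jax.lax.psum(x, axis_name)

def _psum_idrev_fwd(x, axis_name):
  return jax.lax.psum(x, axis_name), None

def _psum_idrev_bwd(axis_name, _, g):
  return (g,)  # identity on the cotangent; the balancing psum must appear elsewhere

psum_idrev.defvjp(_psum_idrev_fwd, _psum_idrev_bwd)
```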
## When should you use `shmap` and when should you use `pjit`?
One philosophy is: it is almost always simpler to write a program in `jit==pjit`
&mdash; but if a given part of the program is less optimized by the compiler than it
could be, drop into `shmap`!
### A realistic transformer example
In fact, we can implement a simple version of the ["collective
matmul"](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959) algorithm
recently introduced in XLA to overlap communication and computation using `shmap`
and 30 lines of Python. The basic idea of the algorithm can be grasped with a
simple example.
Suppose we want to compute `C = A @ B` where `A` is sharded by a 1D mesh on the
0-th dimension while `B` and `C` are replicated.
```python
import numpy as np
from jax.sharding import NamedSharding

M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
A_x = jax.device_put(A, NamedSharding(mesh, P('x', None)))

@jax.jit
def f(lhs, rhs):
  return lhs @ rhs

C = f(A_x, B)
```
A profile shows the blocking all-gather across 8 devices before the matmul can
start. This is suboptimal because `A` is sharded on a non-contracting dimension,
and each shard of `A` can be matmul'ed with `B` independently and this chunked
computation can be overlapped with fetching of the next shard of `A` from
another device.
<img width="1147" alt="image" src="https://user-images.githubusercontent.com/1458824/216507011-e854fb11-43d5-484d-993b-19a3349ed4b9.png">
This overlap can be implemented using `shmap` and explicit collectives.
```python
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
# lhs is the looped operand; rhs is the local operand
axis_size = jax.lax.psum(1, axis_name='x')
axis_index = jax.lax.axis_index(axis_name='x')
chunk_size = lhs.shape[0]
def f(i, carrys):
accum, lhs = carrys
# matmul for a chunk
update = lhs @ rhs
# circular shift to the left
lhs = jax.lax.ppermute(
lhs,
axis_name='x',
perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
)
# device 0 computes chunks 0, 1, ...
# device 1 computes chunks 1, 2, ...
update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
accum = jax.lax.dynamic_update_slice(accum, update, update_index)
return accum, lhs
accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
  # fori_loop causes a crash: hlo_sharding.cc:817 Check failed: !IsManual()
# accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
for i in range(0, axis_size - 1):
accum, lhs = f(i, (accum, lhs))
# compute the last chunk, without the ppermute
update = lhs @ rhs
i = axis_size - 1
update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
accum = jax.lax.dynamic_update_slice(accum, update, update_index)
return accum
```
```python
jit_sharded_f = jax.jit(shard_map(
collective_matmul_allgather_lhs_non_contracting, mesh,
in_specs=(P('x', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
```
A profile shows that the all-gather is gone, and replaced with overlapped matmul
with async collective permute. This profile matches very closely with the
collective matmul paper result.
<img width="1147" alt="image" src="https://user-images.githubusercontent.com/1458824/216507064-139f032c-d869-4b67-9e11-1587d4fd2de9.png">
This collective matmul technique can be used to speed up feedforward blocks in
transformer layers. Such a block typically consists of two matrix multiplications
followed by a `ReduceScatter` (to resolve partial sums from a parallelized
matrix multiplication) and preceded by an `AllGather` (to collect the sharded
dimensions along some axes and allow partial sum computation). Together, the
`ReduceScatter` from one layer and the `AllGather` for the next amount to an
`AllReduce`.
In a typical profile, the two matmuls will be followed by an `AllReduce`, and
they will not be overlapped. Collective matmul can be used to achieve the
overlap, but is difficult to trigger, has a minimum slice size and does not yet
cover all topologies, tensor shapes and variants of collective matmul (i.e.
latency and throughput optimized variants). [In a recent
paper](https://arxiv.org/abs/2211.05102), we found a ~40% gain in many
circumstances from manually implementing collective matmul variants in `shmap`
style.
But it isn't always more complex! We expect this to be a much more natural way
to think about pipelined computation, and plan to do some demos of that soon!
### Another realistic example
Here's how `shmap` might look in a transformer layer pass with a 2D weight
gathered pattern ([paper](https://arxiv.org/abs/2211.05102), Sec 3.2.3 on p. 5):
```python
def matmul_2D_wg_manual(xnorm, q_wi, layer):
'''Calls a custom manual implementation of matmul_reducescatter'''
# [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
# -> (matmul)
# -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
# -> (reducescatter over x into X heads, B batches)
# -> [batch, maxlen, heads.YZX, q_wi_per_head]
with jax.named_scope('q_wi'):
xnorm = intermediate_dtype(xnorm)
q_wi = matmul_reducescatter(
'bte,hed->bthd',
xnorm,
params.q_wi,
scatter_dimension=(0, 2),
axis_name='x',
layer=layer)
return q_wi
import partitioning.logical_to_physical as l2phys
def pjit_transformer_layer(
hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Forward pass through a single layer, returning output, K, V."""
def my_layer(t, axis=0):
"""Gets the parameters corresponding to a given layer."""
return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)
# 2D: [batch.Z, time, embed.XY]
x = _with_sharding_constraint(
x, ('residual_batch', 'residual_time', 'residual_embed'))
xnorm = _layernorm(x)
# 2D: [batch, time, embed.X]
xnorm = _with_sharding_constraint(
xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
# jump into manual mode where you want to optimise
if manual:
    q_wi = shard_map(matmul_2D_wg_manual, mesh,
in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
else:
q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
# 2D: [batch, time, heads.YZX, None]
q_wi = _with_sharding_constraint(q_wi,
('post_norm_batch', 'time', 'heads', 'qkv'))
q = q_wi[:, :, :, :hparams.qkv]
q = _rope(sin, cos, q)
# unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
# swiGLU with full d_ff dimension, rather than 2/3 scaled
wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
k = kv[:, :, 0, :hparams.qkv]
v = kv[:, :, 0, hparams.qkv:]
k = _rope(sin, cos, k)
y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))
y_mlp = special2.swish2(wi0) * wi1
# 2D: [batch, time, heads.YZX, None]
y_mlp = _with_sharding_constraint(y_mlp,
('post_norm_batch', 'time', 'heads', None))
y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
# do the second half of the mlp and the self-attn projection in parallel
y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
# 2D: [batch.Z, time, embed.XY]
y_out = _with_sharding_constraint(
y_out, ('residual_batch', 'residual_time', 'residual_embed'))
z = y_out + x
z = _with_sharding_constraint(
z, ('residual_batch', 'residual_time', 'residual_embed'))
return z, k, v
```
In the profile below, both the first and second matmul were replaced by manually
lowered versions, where the compute (fusions) are fully overlapped with the
communication (ppermute)! One fun hint that we are using a latency optimised
variant is that the ppermute pixels are jittered &mdash; because there are two
overlapping ppermutes using opposite ICI axes at the same time!
All-to-all is much harder to overlap, so it was left on the table.
<img width="1085" alt="image" src="https://user-images.githubusercontent.com/1458824/216507137-adc35a1f-a76c-4704-a62d-389b42771090.png">
## Why don't `pmap` or `xmap` already solve this?
`pmap` was our first multi-device parallelism API. It follows the
per-device-code-and-explicit-collectives school. But it had major shortcomings
which make it unsuitable for today's programs:
* **Mapping multiple axes required nested `pmap`s.** Not only are nested `pmap`s
cumbersome to write, but also they make it difficult to control (or even
predict) the device placement of data and computation, and difficult to
preserve data sharding (see the next two bullets). Today's programs require
multiple axes of parallelism.
* **Controlling device placement was impossible.** Especially with multiple axes
of parallelism, programmers need to control how those axes are aligned with
hardware resources and their communication topologies. But (nested) `pmap`
doesn't offer control over how mapped program instances are placed on
hardware; there's just an automatic device order which the user can't control.
([Gopher](https://arxiv.org/abs/2112.11446)'s use of `axis_index_groups` and a
single un-nested `pmap` was essentially a hack to get around this by
flattening multiple axes of parallelism down to one.)
* **`jit`/`pjit` composability.** `jit`-of-`pmap` is a performance footgun, as
is nesting `pmap`s, as is e.g. `scan`-of-`pmap`, because sharding is not
preserved when returning from an inner `pmap`. To preserve sharding we would
need pattern matching on jaxprs to ensure we're working with perfectly nested
pmaps, or a pmap just inside a `jit`. Moreover, `pjit` was no help here
because `pmap` targets XLA replicas while `pjit` targets the XLA SPMD
Partitioner, and composing those two is hard.
* **`jax.Array` compatibility (and hence `pjit` compatibility).** Because the
sharding of `pmap` outputs can't be expressed as `Shardings` / `OpShardings`,
due to `pmap`'s stacking rather than concatenative semantics, the output of a
`pmap` computation can't currently be passed to a `pjit` computation without
bouncing to host (or dispatching a reshaping computation).
* **Multi-controller semantics (and hence `pjit` compatibility).**
Multi-controller `pmap` concatenates values across controllers, which works well
but differs from single-controller `pmap`'s stacking semantics. More
practically, it precludes the use of non-fully-addressable `jax.Array` inputs
and outputs as we use with multi-controller `pjit`.
* **Eager mode.** We didn't make `pmap` eager-first, and though we eventually
(after 4+ years!) added eager operation with `disable_jit()`, the fact that
`pmap` has `jit` fused into it means it has its own compilation and dispatch
path (actually two dispatch paths: in Python for handling `Tracer`s, and in
C++ for performance on raw `Array` inputs!), a heavy implementation burden.
* **Reshapes needed in the caller.** A typical use case with `pmap` on 8 devices
might look like starting with a batch axis of size 128, reshaping it to split
into two axes with sizes (8, 16), and then `pmap`ping over the first. These
reshapes are awkward and the compiler often interprets them as copies instead
of views &mdash; increasing memory and time usage.
These shortcomings aren't so bad when only doing batch data parallelism. But
when more parallelism is involved, `pmap` just can't cut it!
`xmap` paved the way as a next-gen evolution of `pmap` and solved (almost) all these
issues. `shmap` follows in `xmap`'s footsteps and solves these problems in
essentially the same ways; indeed, `shmap` is like a specialized subset of `xmap`
(what some call the "hard `xmap`" subset), with a few tweaks.
For the initial prototype, we chose to implement `shmap` as a separate primitive
from `xmap`, because limiting the set of features it supports makes it easier to
focus on the core functionality. For example, `shmap` doesn't allow unmapped
intermediates, making it easier not to worry about the interactions between
named axes and autodiff. Furthermore, not having to reason about interactions of
all pairs of features makes it easier to add capabilities beyond what's
implemented in `xmap` today, such as support for eager mode.
Both `shmap` and `xmap` share significant portions of the lowering code. We
could consider merging both in the future, or even focusing solely on `shmap`,
depending on how the usage will evolve.

View File

@ -46,6 +46,7 @@ Then create a pull request that adds a file named
10657: Sequencing side-effects in JAX <10657-sequencing-effects>
11830: `jax.remat` / `jax.checkpoint` new implementation <11830-new-remat-checkpoint>
12049: Type Annotation Roadmap for JAX <12049-type-annotations>
14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map>
Several early JEPs were converted in hindsight from other documentation,

View File

@ -100,6 +100,7 @@ py_library_providing_imports_info(
"experimental/pjit.py",
"experimental/global_device_array.py",
"experimental/multihost_utils.py",
"experimental/shard_map.py",
# until checkify is moved out of experimental
"experimental/checkify.py",
# to avoid circular dependencies

View File

@ -63,9 +63,6 @@ def _initial_style_jaxpr(fun, in_avals):
def _close_jaxpr(jaxpr):
return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
def _initial_style_staging() -> bool:
return core.thread_local_state.trace_state.initial_style
def _sum_tangents(_, x, *xs):
return reduce(ad.add_tangents, xs, x)

View File

@ -49,7 +49,8 @@ from jax._src import profiler
from jax._src import stages
from jax._src import traceback_util
from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
OpShardingSharding, Sharding)
OpShardingSharding, NamedSharding, PartitionSpec,
Sharding)
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.lib.mlir import ir
@ -563,7 +564,7 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):
def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]:
from jax.experimental import pjit
from jax.experimental import pjit, shard_map
for eqn in jaxpr.eqns:
if eqn.primitive is pjit.sharding_constraint_p:
@ -571,6 +572,12 @@ def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]:
elif eqn.primitive is pjit.pjit_p:
yield from eqn.params['in_shardings']
yield from eqn.params['out_shardings']
elif eqn.primitive is shard_map.shard_map_p:
def _names_to_pspec(names):
ndmin = max(names) + 1 if names else 0
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
yield from (NamedSharding(eqn.params['mesh'], _names_to_pspec(names))
for names in eqn.params['in_names'])
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_shardings(subjaxpr)

View File

@ -403,20 +403,21 @@ def tuple_delete(t, idx):
class HashableFunction:
"""Decouples function equality and hash from its identity.
Local lambdas and functiond defs are reallocated on each function call, making
Local lambdas and function defs are reallocated on each function call, making
the functions created on different calls compare as unequal. This breaks our
caching logic, which should really only care about comparing the semantics and
not actual identity.
This class makes it possible to compare different functions based on their
semantics. The parts that are taken into account are: the bytecode of
the wrapped function (which is cached by the CPython interpreter and is stable
across the invocations of the surrounding function), and `closure` which should
contain all values in scope that affect the function semantics. In particular
`closure` should contain all elements of the function closure, or it should be
possible to derive the relevant elements of the true function closure based
solely on the contents of the `closure` argument (e.g. in case some closed-over
values are not hashable, but are entirely determined by hashable locals).
semantics. The parts that are taken into account are: the bytecode of the
wrapped function (which is cached by the CPython interpreter and is stable
across the invocations of the surrounding function), and `closure` which
should contain all values in scope that affect the function semantics. In
particular `closure` should contain all elements of the function closure, or
it should be possible to derive the relevant elements of the true function
closure based solely on the contents of the `closure` argument (e.g. in case
some closed-over values are not hashable, but are entirely determined by
hashable locals).
"""
def __init__(self, f, closure):

View File

@ -1316,6 +1316,7 @@ tf_not_yet_impl = [
"for",
"inspect_sharding",
"io_callback",
"shard_map",
# Not high priority?
"after_all",

View File

@ -0,0 +1,918 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import enum
from functools import partial, lru_cache
import inspect
import operator as op
from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
Set, Tuple, TypeVar, Union, Protocol)
import numpy as np
import jax
from jax import core
from jax.core import Tracer
from jax.sharding import NamedSharding, PartitionSpec, Mesh
from jax._src import ad_util
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.lax import lax, parallel as lax_parallel
from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function,
memoize, partition_list, merge_lists)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax.experimental import maps
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import ad
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_leaves)
from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
_generate_key_paths, KeyPath)
P = PartitionSpec
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
traceback_util.register_exclusion(__file__)
# API
Specs = Any # PyTree[PartitionSpec]
@traceback_util.api_boundary
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
check_rep: bool = True):
if not callable(f):
raise TypeError("shard_map requires a callable for its first argument, "
f"but got {f} of type {type(f)}.")
if not isinstance(mesh, Mesh):
raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its "
f"second argument, but got {mesh} of type {type(mesh)}.")
_check_specs(SpecErrorType.input, in_specs)
_check_specs(SpecErrorType.out, out_specs)
@traceback_util.api_boundary
def wrapped(*args):
fun = lu.wrap_init(f)
args_flat, in_tree = tree_flatten(args)
try: in_specs_flat = broadcast_prefix(in_specs, args)
except ValueError:
e, *_ = prefix_errors(in_specs, args)
raise e('shard_map in_specs') from None
_check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
@memoize
def out_names_thunk():
dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves)
try: out_specs_flat = broadcast_prefix(out_specs, dummy)
except ValueError:
e, *_ = prefix_errors(out_specs, dummy)
raise e('shard_map out_specs') from None
return tuple(map(_canonicalize_spec, out_specs_flat))
try:
out_flat = shard_map_p.bind(
flat_fun, *args_flat, mesh=mesh, in_names=in_names_flat,
out_names_thunk=out_names_thunk, check_rep=check_rep)
except _SpecError as e:
fails, = e.args
msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails)
if any(fail is not no_fail and not fail.shape for fail in fails):
msg += (" In particular, for rank 0 outputs which are not constant "
"over the mesh, add at least one (singleton) axis to them so "
"that they can be concatenated using out_specs.")
raise ValueError(msg) from None
except _RepError as e:
fails, = e.args
msg = _rep_error(f, mesh, out_tree(), out_specs, fails)
raise ValueError(msg) from None
return tree_unflatten(out_tree(), out_flat)
return wrapped
# Internally use AxisNames = Dict[int, Tuple[AxisName, ...]], not PartitionSpecs
AxisName = Hashable
AxisNames = Dict[int, Tuple[AxisName, ...]] # TODO(mattjj): make it hashable
def _canonicalize_spec(spec: PartitionSpec) -> AxisNames:
if isinstance(spec, PartitionSpec):
return {i: names if isinstance(names, tuple) else (names,)
for i, names in enumerate(spec) if names is not None}
else:
return spec
# Error checking and messages
SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return
prefix = 'in' if error_type == SpecErrorType.input else 'out'
msgs = [f" {prefix}_specs{key.pprint()} is {x} of type {type(x).__name__}, "
for key, x in _generate_key_paths(specs) if not isinstance(x, P)]
raise TypeError(
f"shard_map {prefix}_specs argument must be a pytree of "
f"`jax.sharding.PartitionSpec` instances, but:\n\n"
+ '\n\n'.join(msgs) + '\n\n'
f"Check the {prefix}_specs values passed to shard_map.")
class NoFail: pass
no_fail = NoFail()
def _check_specs_vs_args(
f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs,
in_specs_flat: List[P], xs: List) -> None:
in_avals = map(shaped_abstractify, xs)
fail = [a if not len(p) <= a.ndim else no_fail
for p, a in zip(in_specs_flat, in_avals)]
if any(f is not no_fail for f in fail):
msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
raise ValueError(msg)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
for d, ns in names.items()) else no_fail
for a, names in zip(in_avals, in_names_flat)]
if any(f is not no_fail for f in fail):
msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
raise ValueError(msg)
def _spec_rank_error(
error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
if error_type == SpecErrorType.input:
prefix, base = 'in', 'args'
ba = _try_infer_args(f, tree)
else:
prefix, base = 'out', f'{f.__name__}(*args)'
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
if error_type == SpecErrorType.input and ba is not None:
arg_key, *_ = fail_key.keys
extra = (f", where {base}[{arg_key.key}] is bound to {f.__name__}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
else:
extra = ""
msgs.append(
f"{prefix}_specs{spec_key.pprint()} is {spec} which has length "
f"{len(spec)}, but "
f"{base}{fail_key.pprint()}{extra} has shape {aval.str_short()}, "
f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
assert msgs
msg = (f"shard_map applied to the function '{f.__name__}' was given an "
f"{prefix}_specs entry which is too long to be compatible with the "
f"corresponding {prefix}put value from the function:\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
f"Entries in {prefix}_specs must be of length no greater than the "
f"number of axes in the corresponding {prefix}put value.\n\n"
f"Either revise the spec to be shorter, or modify '{f.__name__}' so "
f"that its {prefix}puts have sufficient rank.")
return msg
def _spec_divisibility_error(
f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
ba = _try_infer_args(f, tree)
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
if ba is not None:
arg_key, *_ = fail_key.keys
extra = (f", where args[{arg_key.key}] is bound to {f.__name__}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
names = _canonicalize_spec(spec)
for d, ns in names.items():
if aval.shape[d] % prod(mesh.shape[n] for n in ns):
axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
total = 'total ' if len(ns) > 1 else ''
sz = prod(mesh.shape[n] for n in ns)
msgs.append(
f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} "
f"corresponds to in_specs{spec_key.pprint()} of value {spec}, "
f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
f"{aval.shape[d]}")
assert msgs
msg = (f"shard_map applied to the function '{f.__name__}' was given argument "
f"arrays with axis sizes that are not evenly divisible by the "
f"corresponding mesh axis sizes:\n\n"
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
f"axis names {mesh.axis_names}.\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
f"Array arguments' axis sizes must be evenly divisible by the mesh "
f"axis or axes indicated by the corresponding elements of the "
f"argument's in_specs entry. Consider checking that in_specs are "
f"correct, and if so consider changing the mesh axis sizes or else "
f"padding the input and adapting '{f.__name__}' appropriately.")
return msg
def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
fails: List[Union[Set, NoFail]]) -> str:
msgs = []
for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails):
dst = _canonicalize_spec(spec)
unmentioned = _unmentioned(mesh, dst)
if len(unmentioned) > 1:
need_rep = ','.join(map(str, unmentioned))
got_rep = ','.join(map(str, rep))
diff = ','.join(map(str, unmentioned - rep))
msgs.append(
f"out_specs{spec_key.pprint()} is {spec} which implies that the "
f"corresponding output value is replicated across mesh axes "
f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, "
f"which is missing the required axes {diff}")
else:
need_rep_, = unmentioned
msgs.append(
f"out_specs{spec_key.pprint()} is {spec} which implies that the "
f"corresponding output value is replicated across mesh axis "
f"'{need_rep_}', but could not infer replication over any axes")
assert msgs
msg = (f"shard_map applied to the function '{f.__name__}' was given "
f"out_specs which require replication which can't be statically "
f"inferred given the mesh:\n\n"
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
f"axis names {mesh.axis_names}.\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
"Check if these output values are meant to be replicated over those "
"mesh axes. If not, consider revising the corresponding out_specs "
"entries. If so, consider ddisabling the check by passing the "
"check_rep=False argument to shard_map.")
return msg
def _unmentioned(mesh: Mesh, names: AxisNames) -> Set[AxisName]:
return set(mesh.axis_names) - {n for ns in names.values() for n in ns}
def _try_infer_args(f, tree):
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
try:
return inspect.signature(f).bind(*dummy_args)
except (TypeError, ValueError):
return None
T = TypeVar('T')
def _iter_paths(tree: PyTreeDef, specs: Specs, fails: List[Union[T, NoFail]]
) -> List[Tuple[Tuple[KeyPath, P], Tuple[KeyPath, T]]]:
failures = tree_unflatten(tree, fails)
failures_aug = _generate_key_paths(failures)
specs_ = tree_unflatten(tree_structure(specs), _generate_key_paths(specs))
leaf = lambda x: type(x) is tuple and len(x) == 2 and type(x[1]) is P
specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf)
return [((spec_key, spec), (fail_key, fail_data))
for (spec_key, spec), (fail_key, fail_data)
in zip(specs_aug, failures_aug) if fail_data is not no_fail]
# Primitive
JaxType = Any
MaybeTracer = Union[JaxType, Tracer]
class ShardMapPrimitive(core.Primitive):
multiple_results = True
def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh,
in_names: Tuple[AxisNames, ...],
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
check_rep: bool) -> Sequence[MaybeTracer]:
top_trace = core.find_top_trace(args)
fun, env_todo = process_env_traces(fun, top_trace.level, mesh,
in_names, out_names_thunk, check_rep)
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
out_names = out_names_thunk()
_, xforms = env_todo()
for t in xforms:
out_names = t(out_names)
return out_names
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_shard_map( # pytype: disable=attribute-error
shard_map_p, fun, tracers, mesh=mesh, in_names=in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
todos, _ = env_todo()
return map(core.full_lower, core.apply_todos(todos, outs))
def get_bind_params(self, params):
"""Goes from jaxpr form to python traceable form."""
new_params = dict(params)
jaxpr = new_params.pop('jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ())
axes = new_params.pop('out_names')
new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
shard_map_p = ShardMapPrimitive('shard_map')
@lu.transformation_with_aux
def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
*args: Any):
outs = yield args, {}
todos, out_names_transforms = [], []
while True:
tracers = [x for x in outs if isinstance(x, core.Tracer)
and (level is None or x._trace.level > level)]
if tracers:
ans = max(tracers, key=op.attrgetter('_trace.level'))
else:
break
trace = ans._trace.main.with_cur_sublevel()
outs = map(trace.full_raise, outs)
outs, (todo, xform) = trace.post_process_shard_map(
outs, mesh, in_names, out_names_thunk, check_rep)
todos.append(todo)
out_names_transforms.append(xform)
yield outs, (tuple(todos), tuple(out_names_transforms))
# Staging
def _shard_map_staging(
trace: pe.DynamicJaxprTrace, prim: core.Primitive, fun: lu.WrappedFun,
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
in_names: Tuple[AxisNames, ...],
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
check_rep: bool,
) -> Sequence[pe.DynamicJaxprTracer]:
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(
fun, trace.main, in_avals_)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
out_rep = _output_rep(mesh, jaxpr, in_rep)
_check_reps(mesh, out_names_thunk(), out_rep)
out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_)
source_info = source_info_util.current()
out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals]
invars = map(trace.getvar, in_tracers)
constvars = map(trace.getvar, map(trace.instantiate_const, consts))
outvars = map(trace.makevar, out_tracers)
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
params = dict(mesh=mesh, in_names=in_names_staged,
out_names=out_names_thunk(), jaxpr=jaxpr, check_rep=check_rep)
eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
jaxpr.effects, source_info)
trace.frame.add_eqn(eqn)
return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)))
else:
raise NotImplementedError # TODO(mattjj): add table with handlers
def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)),
named_shape={k: v for k, v in aval.named_shape.items()
if k not in mesh.shape})
else:
raise NotImplementedError # TODO(mattjj): add table with handlers
# Type-checking
def _shard_map_typecheck(*in_atoms, jaxpr, mesh, in_names, out_names,
check_rep):
for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names):
if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)):
raise core.JaxprTypeError("shard_map argument avals not compatible with "
"jaxpr binder avals and in_names")
with core.extend_axis_env_nd(tuple(mesh.shape.items())):
core.check_jaxpr(jaxpr)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
out_rep = _output_rep(mesh, jaxpr, in_rep)
for rep, dst in zip(out_rep, out_names):
if not _valid_repeats(mesh, rep, dst):
raise core.JaxprTypeError("shard_map can't prove output is sufficiently "
"replicated")
out_avals_sharded = [x.aval for x in jaxpr.outvars]
out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded)
return out_avals, jaxpr.effects
core.custom_typechecks[shard_map_p] = _shard_map_typecheck
def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> Set[AxisName]:
return set(mesh.axis_names) - set(n for ns in names.values() for n in ns)
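# Forward pass over the jaxpr tracking, for each value, the set of mesh axis
# names over which it is known to be replicated; per-primitive rules in
# _rep_rules below propagate these sets from inputs to outputs.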
def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
) -> Sequence[Set[AxisName]]:
env: Dict[core.Var, Set[AxisName]] = {}
def read(x: core.Atom) -> Set[AxisName]:
return env[x] if type(x) is core.Var else set(mesh.axis_names)
def write(v: core.Var, val: Set[AxisName]) -> None:
env[v] = val
map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
map(write, jaxpr.invars, in_rep)
for e in jaxpr.eqns:
rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
out_rep = rule(mesh, *map(read, e.invars), **e.params)
if e.primitive.multiple_results:
out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
map(write, e.outvars, out_rep)
else:
write(e.outvars[0], out_rep)
return map(read, jaxpr.outvars)
def _valid_repeats(mesh: Mesh, rep: Set[AxisName], dst: AxisNames) -> bool:
return _unmentioned(mesh, dst).issubset(rep)
# Lowering
def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
check_rep):
del check_rep
sharded_avals = [v.aval for v in jaxpr.invars]
in_nodes_ = map(partial(_xla_shard, mesh), in_names, ctx.avals_in,
sharded_avals, in_nodes)
new_axis_context = mlir.SPMDAxisContext(mesh, frozenset(mesh.axis_names))
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
with core.extend_axis_env_nd(tuple(mesh.shape.items())):
out_nodes_, _ = mlir.jaxpr_subcomp(sub_ctx, jaxpr, mlir.TokenSet(),
(), *in_nodes_,
dim_var_values=ctx.dim_var_values)
sharded_avals = [v.aval for v in jaxpr.outvars]
return map(partial(_xla_unshard, mesh), out_names, sharded_avals,
ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)
def _xla_shard(mesh, names, aval_in, aval_out, x):
manual_proto = pxla._manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
axes = {name: i for i, ns in names.items() for name in ns}
sharding_proto = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
aval_in, axes).sharding_proto()
sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=set())
return [mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, set())]
def _xla_unshard(mesh, names, aval_in, aval_out, xs):
x, = xs
manual_proto = pxla._manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set())
axes = {name: i for i, ns in names.items() for name in ns}
sharding_proto = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
aval_out, axes).sharding_proto()
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, set())
# Eager evaluation
def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
check_rep):
del prim
args = map(partial(_unmatch_spec, mesh), in_names, args)
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
with core.new_base_main(ShardMapTrace, mesh=mesh) as main:
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
t = main.with_cur_sublevel()
in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(t.full_raise, ans)
outs_, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
del main, t, in_tracers, ans, out_tracers
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
_check_names(out_names_thunk(), out_avals)
if check_rep: _check_reps(mesh, out_names_thunk(), out_rep)
return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_)
core.EvalTrace.process_shard_map = _shard_map_impl
def _names_to_pspec(names: AxisNames) -> PartitionSpec:
ndmin = max(names) + 1 if names else 0
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType:
with core.eval_context():
return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x)
def _unmatch(mesh, src_tup, x):
src = _names_to_pspec(dict(src_tup))
dst = P(mesh.axis_names)
return shard_map(_add_singleton, mesh, (src,), dst)(x)
def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray]
) -> None:
fail = [a if not max(n, default=0) < a.ndim else no_fail
for n, a in zip(names, avals)]
if any(f is not no_fail for f in fail): raise _SpecError(fail)
class _SpecError(Exception): pass
def _check_reps(mesh, names, reps):
fail = [r if not _valid_repeats(mesh, r, n) else no_fail
for n, r in zip(names, reps)]
if any(f is not no_fail for f in fail): raise _RepError(fail)
class _RepError(Exception): pass
def _match_spec(mesh: Mesh, rep: Set[AxisName], dst: AxisNames, x: JaxType
) -> JaxType:
with core.eval_context():
return jax.jit(HashablePartial(_match, mesh, tuple(dst.items())))(x)
def _match(mesh, dst_tup, x):
src = P(mesh.axis_names)
dst = _names_to_pspec(dict(dst_tup))
return shard_map(_rem_singleton, mesh, (src,), dst, check_rep=False)(x)
def _rem_singleton(x): return x.reshape(x.shape[1:])
def _add_singleton(x): return x.reshape(1, *x.shape)
class ShardMapTrace(core.Trace):
mesh: Mesh
def __init__(self, *args, mesh):
super().__init__(*args)
self.mesh = mesh
def pure(self, val):
val_ = _unmatch_spec(self.mesh, {}, val)
return ShardMapTracer(self, set(self.mesh.axis_names), val_)
def sublift(self, tracer):
return ShardMapTracer(self, tracer.rep, tracer.val)
def process_primitive(self, prim, tracers, params):
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
with core.eval_context(), jax.disable_jit(False):
out_vals = jax.jit(f)(*in_vals)
rule = _rep_rules.get(prim, partial(_rep_rule, prim))
out_rep = rule(self.mesh, *in_rep, **params)
if prim.multiple_results:
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
return map(partial(ShardMapTracer, self), out_rep, out_vals)
return ShardMapTracer(self, out_rep, out_vals)
def process_call(self, call_primitive, fun, tracers, params):
if call_primitive is not xla.xla_call_p: raise NotImplementedError
fun, jaxpr = _grab_jaxpr_shadily(fun) # TODO remove with initial-style jit
bind = partial(call_primitive.bind, fun) # TODO caching (compat w/ jaxpr())
fake_primitive = pxla._FakePrimitive(multiple_results=True, bind=bind)
_rep_rules[fake_primitive] = lambda *_, **__: set()
out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
out_vals = [t.val for t in out_tracers_]
out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
return map(partial(ShardMapTracer, self), out_rep, out_vals)
@lu.transformation_with_aux
def _grab_jaxpr_shadily(*args):
out = yield args, {}
main = core.thread_local_state.trace_state.trace_stack.dynamic # forgive me
jaxpr, _ = main.jaxpr_stack[-1].to_jaxpr(out)
yield out, jaxpr
class ShardMapTracer(core.Tracer):
rep: Set[AxisName]
val: JaxType
def __init__(self, trace, rep, val):
self._trace = trace
self.rep = rep
self.val = val
@property
def aval(self):
aval = core.get_aval(self.val)
if (isinstance(aval, core.ConcreteArray) and
self.rep == set(self._trace.mesh.axis_names)):
with core.eval_context():
return core.get_aval(self.val[0])
else:
aval = core.raise_to_shaped(aval)
return core.mapped_aval(self._trace.mesh.size, 0, aval)
def full_lower(self) -> ShardMapTracer:
return self
def __str__(self) -> str:
with core.eval_context():
blocks = list(self.val)
mesh = self._trace.mesh
axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
return '\n'.join(
f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))
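
# Applies a single primitive eagerly: strip each block's leading singleton,
# bind the primitive per device under a nested shard_map over every mesh axis,
# then restore the leading singleton so results stay in stacked form.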
def _prim_applier(prim, params_tup, mesh, *args):
def apply(*args):
outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
return tree_map(_add_singleton, outs)
return shard_map(apply, mesh, P(mesh.axis_names), P(mesh.axis_names))(*args)
# Static replication checking
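# A replication rule maps the input `rep` sets of a primitive to output `rep`
# sets. Most primitives preserve replication along an axis only if every input
# had it (set intersection); psum and tiled all_gather add the reduced or
# gathered axes, while reduce_scatter and all_to_all remove them. `out_specs`
# may leave a mesh axis unmentioned only if the output is replicated along it,
# which `_check_reps` enforces. For example, with mesh axes ('x', 'y'), a value
# produced by psum(v, 'y') has 'y' in its rep set, so an out_specs of
# P('x', None) (leaving 'y' unmentioned) passes the check, whereas returning
# the un-summed v fails it.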
def _rep_rule(prim, mesh, *in_rep, **params):
raise NotImplementedError(f"no replication rule for {prim}")
_rep_rules: Dict[core.Primitive, Callable] = {}
register_rule = lambda prim: lambda rule: _rep_rules.setdefault(prim, rule)
register_standard = lambda prim: _rep_rules.setdefault(prim, _standard_rep_rule)
def _standard_rep_rule(_, *in_rep, **__):
return set.intersection(*in_rep)
for o in lax.__dict__.values():
if isinstance(o, core.Primitive): register_standard(o)
register_standard(ad_util.add_any_p)
register_standard(lax_parallel.ppermute_p) # doesn't change replication
@register_rule(lax_parallel.psum_p)
def _psum_rule(_, *in_rep, axes, axis_index_groups):
if axis_index_groups is not None: raise NotImplementedError
axes = (axes,) if not isinstance(axes, tuple) else axes
return [r | set(axes) for r in in_rep] # introduces replication
@register_rule(lax_parallel.all_gather_p)
def _all_gather_rule(_, in_rep, *, all_gather_dimension, axis_name, axis_size,
axis_index_groups, tiled):
if axis_index_groups is not None: raise NotImplementedError
if not tiled: raise NotImplementedError
axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
return in_rep | set(axis_name) # introduces replication
@register_rule(lax_parallel.reduce_scatter_p)
def _reduce_scatter_rule(_, in_rep, *, scatter_dimension, axis_name, axis_size,
axis_index_groups, tiled):
if axis_index_groups is not None: raise NotImplementedError
if not tiled: raise NotImplementedError
return in_rep - {axis_name} # removes replication
@register_rule(lax_parallel.all_to_all_p)
def _all_to_all_rule(_, in_rep, *, split_axis, concat_axis, axis_name,
axis_index_groups):
if axis_index_groups is not None: raise NotImplementedError
return in_rep - {axis_name} # removes replication
@register_rule(pjit.pjit_p)
def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
return _output_rep(mesh, jaxpr.jaxpr, in_rep)
# Batching
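# vmap-of-shard_map: the batched values carry an extra batch dimension at
# position `d`, so any logical dimension at index >= d in the original
# in_names/out_names shifts up by one in the new names; the mapped function
# itself is handled by batch_subtrace as usual.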
def _shard_map_batch(
trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun,
in_tracers: Sequence[batching.BatchTracer], mesh: Mesh,
in_names: Tuple[AxisNames, ...],
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
check_rep: bool) -> Sequence[batching.BatchTracer]:
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers)
if all(bdim is batching.not_mapped for bdim in in_dims):
return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names,
out_names_thunk=out_names_thunk, check_rep=check_rep)
if trace.spmd_axis_name is not None:
raise NotImplementedError # TODO add named axis to specs
if any(isinstance(d, batching.ConcatAxis) for d in in_dims):
raise NotImplementedError
fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims))
new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]  # type: ignore
for ax in names} for names, d in zip(in_names, in_dims)]
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
out_names = out_names_thunk()
return [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
for ax in names} for names, d in zip(out_names, out_dims())]
new_params = dict(mesh=mesh, in_names=new_in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
out_vals = prim.bind(fun, *in_vals, **new_params)
make_tracer = partial(batching.BatchTracer, trace,
source_info=source_info_util.current())
return map(make_tracer, out_vals, out_dims())
batching.BatchTrace.process_shard_map = _shard_map_batch
# Autodiff
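# JVP: flatten (primals, nonzero tangents) into one argument list, run the
# jvp-transformed function under a single shard_map_p bind, and give each
# nonzero tangent the same names as its primal, so tangents share their
# primal's sharding.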
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk, check_rep):
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
which_nz = [ type(t) is not ad.Zero for t in tangents]
tangents = [t if type(t) is not ad.Zero else None for t in tangents]
args, in_tree = tree_flatten((primals, tangents))
f_jvp = ad.jvp_subtrace(f, trace.main)
f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp)
tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz]
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
out_ax = out_names_thunk()
return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz))
params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
result = shard_map_p.bind(f_jvp, *args, **params)
primal_out, tangent_out = tree_unflatten(out_tree(), result)
tangent_out = [ad.Zero(ad.get_aval(p).at_least_vspace()) if t is None else t
for p, t in zip(primal_out, tangent_out)]
return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp
def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names,
out_names_thunk, check_rep):
del mesh, in_names, out_names_thunk, check_rep
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
out, treedef = tree_flatten((primals, tangents))
tangents_nz = [type(t) is not ad.Zero for t in tangents]
m = trace.main
def todo(x):
primals, tangents = tree_unflatten(treedef, x)
return map(partial(ad.JVPTracer, m.with_cur_sublevel()), primals, tangents)
def out_names_transform(out_names):
return (*out_names, *(n for n, nz in zip(out_names, tangents_nz) if nz))
return out, (todo, out_names_transform)
ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process
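
# Partial evaluation (used by linearize and remat): the known part is bound as
# its own shard_map now, with jaxpr residuals returned as extra outputs whose
# dimension 0 is sharded over all mesh axes (scalars are first broadcast to
# rank 1 by _promote_scalar_residuals so they have such a dimension). The
# unknown part is staged out as a new shard_map equation that consumes those
# residuals under the same names, with check_rep disabled for the staged-out
# computation.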
def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk, check_rep):
in_pvals = [t.pval for t in tracers]
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
f = pe.trace_to_subjaxpr_nounits(f, trace.main, False)
f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits(
f, (*in_knowns,), (*in_avals_sharded,))
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
@as_hashable_function(closure=out_names_thunk)
def known_out_names():
out_knowns, _, jaxpr, _ = aux()
_, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
assert not any(not v.aval.shape for v in jaxpr.constvars)
res_names = ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
return (*out_known_names, *res_names)
known_params = dict(mesh=mesh, in_names=(*known_in_names,),
out_names_thunk=known_out_names, check_rep=check_rep)
out = shard_map_p.bind(f_known, *in_consts, **known_params)
out_knowns, out_avals_sharded, jaxpr, env = aux()
out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)])
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
unk_in_names = (({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
+ (*unk_in_names,))
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env)
unk_arg_tracers = [t for t in tracers if not t.is_known()]
unk_params = dict(mesh=mesh, in_names=unk_in_names,
out_names=unk_out_names, jaxpr=jaxpr, check_rep=False)
out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), # type: ignore[arg-type]
out_tracers, shard_map_p, unk_params,
jaxpr.effects, source_info_util.current())
for t in out_tracers: t.recipe = eqn
return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
def _shard_map_partial_eval_post_process(
trace, tracers, mesh, in_names, out_names_thunk, check_rep):
del check_rep
unk_tracers = [t for t in tracers if not t.is_known()]
jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
out = [*consts, *res]
main = trace.main
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_ = pe.convert_constvars_jaxpr(jaxpr)
def todo(out):
trace = main.with_cur_sublevel()
out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)])
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env)
staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
out_names=(*out_names_unknown,), check_rep=False)
out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
name_stack = trace._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
shard_map_p, staged_params, jaxpr.effects, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
def out_names_transform(out_names):
nonlocal out_names_unknown
out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
out_names_unknown: Optional[list] = None
return out, (todo, out_names_transform)
pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process
@lu.transformation
def _promote_scalar_residuals(*args, **kwargs):
jaxpr, (out_pvals, out_consts, env) = yield args, kwargs
which_scalar = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
for v in jaxpr.constvars]
out_consts_ = [jax.lax.broadcast(x, (1,)) if scalar else x
for x, scalar in zip(out_consts, which_scalar)]
@lu.wrap_init
def fun(*args):
out_consts = [x.reshape(*x.shape[1:]) if scalar else x
for x, scalar in zip(out_consts_, which_scalar)]
return core.eval_jaxpr(jaxpr, out_consts, *args)
in_avals = [v.aval for v in jaxpr.invars]
jaxpr, _, out_consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
yield jaxpr, (out_pvals, out_consts, env)
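
# Transpose rule for reverse-mode AD: mesh axes left unmentioned by an
# out_names entry mean the forward output was replicated over them, so the
# incoming cotangent is divided by the product of those axis sizes; dually,
# axes unmentioned by an in_names entry mean the forward input was implicitly
# broadcast, so its cotangent is psum-ed over them. The body's cotangents are
# computed by partially evaluating the jaxpr to recover residuals and then
# running backward_pass, all bound as another shard_map.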
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
for ns, x in zip(in_names, args)]
all_args, in_tree = tree_flatten((out_cts, args))
@lu.wrap_init
def fun_trans(out_cts, args):
res, undefs = partition_list(map(ad.is_undefined_primal, args), args)
jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits(
pe.close_jaxpr(jaxpr), map(ad.is_undefined_primal, args), False)
res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res)
out = ad.backward_pass(jaxpr_unknown.jaxpr, set(), False, (),
(*res_reshaped, *undefs), out_cts)
return [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
for ns, x in zip(in_names, out)]
fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree)
new_in_names = \
[n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \
[n for n, x in zip(in_names, args) if type(x) is not ad.UndefinedPrimal]
def new_out_names_thunk():
return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz)
out_flat = shard_map_p.bind(
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
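
# Axis-name substitution: mesh axis names bound by this shard_map shadow any
# outer named axes of the same name, so substitution leaves them unchanged
# inside the body jaxpr.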
def _shard_map_axis_subst(params, subst, traverse):
if 'jaxpr' not in params:
return params
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['mesh'].shape else subst(name)
with core.extend_axis_env_nd(params['mesh'].shape.items()):
new_jaxpr = core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst)
return dict(params, jaxpr=new_jaxpr)
core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst
# TODO(mattjj): move this to _src/util.py
class HashablePartial:
def __init__(self, f, *args, **kwargs):
self.f = f
self.args = args
self.kwargs = kwargs
def __eq__(self, other):
return (type(other) is HashablePartial and
self.f.__code__ == other.f.__code__ and
self.args == other.args and self.kwargs == other.kwargs)
def __hash__(self):
return hash((self.f.__code__, self.args, tuple(self.kwargs.items())))
def __call__(self, *args, **kwargs):
return self.f(*self.args, *args, **self.kwargs, **kwargs)
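# Two HashablePartial instances wrapping the same code object with equal args
# and kwargs compare equal and hash alike, so the eager jax.jit calls above can
# hit the compilation cache instead of retracing on every application.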


@@ -542,6 +542,7 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
# this expansion.
for p in xb.expand_platform_alias(platform):
_platform_specific_lowerings[p][prim] = rule
return rule
def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x


@@ -426,7 +426,9 @@ def jaxpr_collectives(jaxpr):
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
def _xla_call_partial_eval_update_params(params, kept_inputs, num_new_inputs):
def _xla_call_partial_eval_update_params(
params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
) -> core.ParamDict:
donated_invars = params['donated_invars']
if not kept_inputs and donated_invars:
# JaxprTrace.post_process_call creates a call with no input tracers


@@ -1007,6 +1007,11 @@ jax_test(
},
)
jax_test(
name = "shard_map_test",
srcs = ["shard_map_test.py"],
)
jax_test(
name = "clear_backends_test",
srcs = ["clear_backends_test.py"],

382
tests/shard_map_test.py Normal file

@@ -0,0 +1,382 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import os
import unittest
from absl.testing import absltest
import numpy as np
import jax
from jax import lax
from jax.config import config
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec as P
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
config.parse_flags_with_absl()
# Helper for some tests.
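# Builds a 2x2x2 mesh over the first eight devices and two 8x8 operands placed
# according to the given PartitionSpecs.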
def create_inputs(a_sharding, b_sharding):
x, y, z = 2, 2, 2 # pylint: disable=invalid-name
devices = np.array(jax.devices()[:x * y * z]).reshape((x, y, z))
mesh = Mesh(devices, axis_names=('x', 'y', 'z'))
b, e, f = 8, 8, 8 # pylint: disable=invalid-name
m1 = jax.device_put(
jnp.arange(b * e).reshape((b, e)),
jax.sharding.NamedSharding(mesh, a_sharding))
m2 = jax.device_put(
jnp.arange(e * f).reshape((e, f)),
jax.sharding.NamedSharding(mesh, b_sharding))
return mesh, m1, m2
# Run all tests with 8 CPU devices.
prev_xla_flags = None
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
if len(jax.devices()) < 8:
raise unittest.SkipTest("tests require 8 devices")
if not jax.config.jax_array:
raise unittest.SkipTest("test requires jax_array")
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
class ShardMapTest(jtu.JaxTestCase):
def test_identity(self):
mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
assert a.device_buffers[0].shape == (4, 2)
def identity(x):
return x
@jax.jit
def fwd(a):
c = shard_map(
lambda x: x,
mesh,
in_specs=(P('z', ('x', 'y')),),
out_specs=P('z', ('x', 'y')))(a)
return c
c = fwd(a)
self.assertEqual(c.device_buffers[0].shape, (4, 2))
def test_all_gather(self):
mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
assert a.device_buffers[0].shape == (4, 2)
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('z', ('x', 'y')),), out_specs=P(None, ('x', 'y')))
def fwd(a):
return lax.all_gather(a, 'z', axis=0, tiled=True)
c = fwd(a)
self.assertEqual(c.device_buffers[0].shape, (8, 2))
def test_matmul_partial(self):
raise unittest.SkipTest("invalid replication asserted by out_spec?")
mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
assert a.device_buffers[0].shape == (4, 4)
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('z', 'y'), P('y', None)), out_specs=P('z', None))
def fwd(a, b):
c = jnp.matmul(a, b) # [B.z, F] {y.unreduced}
return c
c = fwd(a, b)
self.assertEqual(c.device_buffers[0].shape, (4, 8))
def test_matmul_reduce_scatter(self):
mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
assert a.device_buffers[0].shape == (4, 4)
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('z', 'y'), P('y', None)),
out_specs=P(('z', 'y'), None))
def fwd(a, b):
c = jnp.matmul(a, b) # [B.z, F] {y.unreduced}
return lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True)
c = fwd(a, b)
self.assertEqual(c.device_buffers[0].shape, (2, 8))
def test_collective_permute(self):
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('x'))
a = jax.device_put(
jnp.arange(8 * 8).reshape((8, 8)),
jax.sharding.NamedSharding(mesh, P('x', None)))
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x', None),),
out_specs=P('x', None))
def fwd(a):
axis_size = lax.psum(1, 'x')
perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
return lax.ppermute(a, 'x', perm=perm)
c = fwd(a)
self.assertAllClose(c[1, :], a[0, :])
def test_all_to_all(self):
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('x'))
a = jax.device_put(
jnp.arange(8 * 8).reshape((8, 8)),
jax.sharding.NamedSharding(mesh, P('x', None)))
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('x', None),), out_specs=P(None, 'x'))
def fwd(a):
return lax.all_to_all(a, 'x', split_axis=1, concat_axis=1, tiled=True)
c = fwd(a)
assert (c == jnp.reshape(a.T, (1, 64))).all()
def test_eager_repr(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = None
@partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
def f(x):
nonlocal s
s = str(x)
return x
_ = f(np.arange(8 * 8.).reshape(8, 8))
self.assertIsInstance(s, str)
self.assertIn('at mesh coordinates', s)
def test_jvp_basic(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
args = np.arange(4 * 4.).reshape(4, 4),
jtu.check_grads(g, args, 2, ['fwd'])
jtu.check_grads(jax.jit(g), args, 2, ['fwd'])
def test_linearize_basic(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
x = np.arange(4 * 4.).reshape(4, 4)
y, y_dot = jax.jvp(g, [x], [x])
y_, g_lin = jax.linearize(g, x)
y_dot_ = g_lin(x)
self.assertAllClose(y, y_, check_dtypes=False)
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
def test_linearize_basic_repres(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh,
in_specs=(P('x',),), out_specs=P('x',))
x = np.arange(4.)
y, y_dot = jax.jvp(g, [x], [x])
y_, g_lin = jax.linearize(g, x)
y_dot_ = g_lin(x)
self.assertAllClose(y, y_, check_dtypes=False)
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
def test_linearize_basic_repres_jit(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
in_specs=(P('x',),), out_specs=P('x',))
x = np.arange(4.)
y, y_dot = jax.jvp(g, [x], [x])
y_, g_lin = jax.linearize(g, x)
y_dot_ = g_lin(x)
self.assertAllClose(y, y_, check_dtypes=False)
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
def test_replication_checker_eager(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = np.arange(8 * 8.).reshape(8, 8)
def f(x):
return 2 * x
def g(x):
return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
with self.assertRaisesRegex(ValueError, 'statically inferred'):
g(x)
def f2(x):
return jax.lax.psum(x, 'x')
def g2(x):
return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
_ = g2(x) # doesn't crash
def test_replication_checker_jit(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = np.arange(8 * 8.).reshape(8, 8)
def f(x):
return 2 * x
def g(x):
return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
with self.assertRaisesRegex(ValueError, 'statically inferred'):
jax.jit(g)(x)
def f2(x):
return jax.lax.psum(x, 'x')
def g2(x):
return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
_ = jax.jit(g2)(x) # doesn't crash
def test_process_env_traces(self):
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
x = np.arange(8.)
def g(x):
y = (3. * x).sum()
z = shard_map(lambda x: 2 * x * y, mesh,
in_specs=(P('x'),), out_specs=P('x'))(np.arange(8.))
return z
jtu.check_grads(g, (x,), modes=['fwd'], order=2)
def test_eager_control_flow(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jnp.arange(2 * 2.).reshape(2, 2)
def f(x):
y = jax.lax.psum(x, ('x', 'y'))
if y < 0:
return x
else:
return -x
def g(x):
return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x)
y = g(x)
self.assertAllClose(y, -x, check_dtypes=False)
def test_outer_jit_detects_shard_map_mesh(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
_ = jax.jit(f)(jnp.array(2.0))  # doesn't crash
def test_vmap_basic(self):
if jax.config.jax_jit_pjit_api_merge:
raise unittest.SkipTest("pjit batcher error") # TODO(mattjj)
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jnp.arange(8 * 8.).reshape(8, 8)
def g(x):
return shard_map(lambda x: 2. * x, mesh,
in_specs=P('y'), out_specs=P('y'))(x)
y = jax.vmap(g, axis_name='x')(x)
self.assertAllClose(y, 2 * x, check_dtypes=False)
def test_tree_prefix_error(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
@partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y'))
def f(x):
return x
x = jnp.arange(8 * 8.).reshape(8, 8)
with self.assertRaisesRegex(ValueError, r'shard_map in_specs\[0\]'):
f([x, x])
def test_rank_errors(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
def foo():
return {'hi': [3.]}
with self.assertRaisesRegex(ValueError, 'which has length 1'):
shard_map(foo, mesh=mesh, in_specs=(), out_specs={'hi': P('x')})()
with self.assertRaisesRegex(ValueError, 'which has length 1'):
jax.jit(lambda: shard_map(foo, mesh=mesh,
in_specs=(), out_specs={'hi': P('x')})())()
with self.assertRaisesRegex(ValueError, 'which has rank 0'):
shard_map(foo, mesh=mesh, in_specs=({'hi': P('x')},), out_specs=())(
{'hi': [jnp.array(3.)]})
def test_reverse_mode_ad(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('x',), P(None)), out_specs=P('x',))
def f(x, y):
return jnp.sin(x) + 3 + jnp.tan(2.) * jnp.cos(x) + y
x = jnp.arange(8.) / 10.
y = jnp.arange(4.) / 10.
jtu.check_grads(f, (x, y), modes=['fwd', 'rev'], order=2)
def test_post_process(self):
# JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
def f(x):
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def g(y):
return jnp.sin(y) * jnp.sin(x).sum()
return g(jnp.arange(8.))
x = jnp.arange(8.)
_, f_lin = jax.linearize(f, x)
y_dot = f_lin(x)
y_dot_expected = jnp.sin(jnp.arange(8.)) * (jnp.cos(x) * x).sum()
self.assertAllClose(y_dot, y_dot_expected, check_dtypes=False)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())