From ff1e9b3973430055ceb558596f207b9c0ebe7232 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 4 Nov 2022 15:29:10 -0700 Subject: [PATCH] shard_map (shmap) prototype and JEP Co-authored-by: Sharad Vikram Co-authored-by: Sholto Douglas --- docs/jep/14273-shard-map.md | 640 +++++++++++++++++++++ docs/jep/index.rst | 1 + jax/BUILD | 1 + jax/_src/custom_derivatives.py | 3 - jax/_src/dispatch.py | 11 +- jax/_src/util.py | 19 +- jax/experimental/jax2tf/jax2tf.py | 1 + jax/experimental/shard_map.py | 918 ++++++++++++++++++++++++++++++ jax/interpreters/mlir.py | 1 + jax/interpreters/xla.py | 4 +- tests/BUILD | 5 + tests/shard_map_test.py | 382 +++++++++++++ 12 files changed, 1971 insertions(+), 15 deletions(-) create mode 100644 docs/jep/14273-shard-map.md create mode 100644 jax/experimental/shard_map.py create mode 100644 tests/shard_map_test.py diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md new file mode 100644 index 000000000..3db9df687 --- /dev/null +++ b/docs/jep/14273-shard-map.md @@ -0,0 +1,640 @@ +# `shmap` (`shard_map`) for simple per-device code +*sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@* + +*January 2023* + +## Motivation + +JAX supports two schools of thought for multi-device programming: +1. **Compiler, take the wheel!** Let the compiler automatically partition bulk + array functions over devices. +2. **Just let me write what I mean, damnit!** Give me per-device code and + explicit communication collectives. + +We need great APIs for both, and rather than being mutually exclusive +alternatives, they need to compose with each other. + +With `pjit` (now just `jit`) we have [a next-gen +API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +for the first school. But we haven't quite leveled-up the second school. `pmap` +follows the second school, but over time we found it has [fatal +flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws, +but it doesn't quite give us per-device shapes, and it includes several other +big ideas too. Meanwhile, new demands for per-device explicit-collectives +programming have emerged, like in [Efficiently Scaling Transformer +Inference](https://arxiv.org/abs/2211.05102). + +We can level-up the second school with `shmap`. `shmap` is: +* a simple multi-device parallelism API which lets us write per-device code with + explicit collectives, where logical shapes match per-device physical buffer + shapes and collectives correspond exactly to cross-device communication; +* a specialization of `xmap` with scaled-back features and a few tweaks; +* a fairly direct surfacing of the XLA SPMD Partitioner's 'manual' mode; +* a fun-to-say Seussian name which could stand for `shard_map`, + `shpecialized_xmap`, `sholto_map`, or `sharad_map`. + +**For `pjit` users**, `shmap` is a complementary tool. It can be used inside a +`pjit` computation to drop temporarily into a "manual collectives" mode, like an +escape hatch from the compiler's automatic partitioning. That way, users get the +convenience and familiar just-NumPy programming model of `pjit` for most of their +code, along with the ability to hand-optimize collective communication with +`shmap` wherever it's needed. It's the best of both worlds! + +**For `pmap` users**, `shmap` is a strict upgrade. It's more expressive, +performant, and composable with other JAX APIs, without making basic batch data +parallelism any harder. + +For more on practical use, you can jump to [When should you use `shmap` and when +should you use `pjit`?](#when-should-you-use-shmap-and-when-should-you-use-pjit). +If you're wondering why we need a new thing at all, or what +the problems with `pmap` are, jump to [Why don't `pmap` or `xmap` already solve +this?](#why-dont-pmap-or-xmap-already-solve-this). +Or keep reading the next section to see some `shmap` examples and the API spec. + + +## So, let's see `shmap`! + +### TL;DR example (with a more detailed explanation to follow) + +Sho shick: + +```python +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, PartitionSpec as P +from jax.experimental import mesh_utils + +devices = mesh_utils.create_device_mesh((4, 2)) +mesh = Mesh(devices, axis_names=('x', 'y')) + +a = jnp.arange( 8 * 16.).reshape(8, 16) +b = jnp.arange(16 * 32.).reshape(16, 32) + +@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), + out_specs=P('x', None)) +def matmul_basic(a_block, b_block): + # a_block: f32[2, 8] + # b_block: f32[8, 32] + z_partialsum = jnp.dot(a_block, b_block) + z_block = jax.lax.psum(z_partialsum, 'y') + return z_block + +c = matmul_basic(a, b) # c: f32[8, 32] +``` + +Notice: +* no nesting needed (or `axis_index_groups`) for multiple axes of parallelism, + unlike `pmap`; +* no reshapes in the caller, unlike `pmap` and hard-`xmap`, and logical shapes + correspond to per-device physical shapes, unlike (non-hard) `xmap`; +* precise device placement control by using `mesh`, unlike `pmap`; +* there's only one set of axis names for logical and physical, unlike `xmap`; +* the result is a `jax.Array` which could be efficiently passed to a `pjit`, + unlike `pmap`; +* this same code works efficiently inside a `pjit`/`jit`, unlike `pmap`; +* this code works eagerly, so we can `pdb` in the middle and print values, + unlike `xmap`'s current implementation (though by design `xmap` without the + sequential schedule can in principle work eagerly too). + +Here's another matmul variant with a fully sharded result: + +```python +@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), + out_specs=P('x', 'y')) +def matmul_reduce_scatter(a_block, b_block): + # c_partialsum: f32[8/X, 32] + c_partialsum = jnp.matmul(a_block, b_block) + # c_block: f32[8/X, 32/Y] + c_block = lax.psum_scatter(c_partialsum, 'y', scatter_dimension=1, tiled=True) + return c_block + +c = matmul_reduce_scatter(a, b) +``` + +### Slow down, start with the basics! + +#### Rank-reducing vs rank-preserving maps over array axes + +We can think of `pmap` (and `vmap` and `xmap`) as unstacking each array input +along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body +function to each piece, and stacking the results back together, at least when +collectives aren't involved: + +```python +pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs]) +``` + +For example, if `xs` had shape `f32[8,5]` then each `x` has shape `f32[5]`, and +if each `f(x)` has shape `f32[3,7]` then the final stacked result `pmap(f)(xs)` +has shape `f32[8,3,7]`. That is, each application of the body function `f` takes +as argument inputs with one fewer axis than the corresponding argument to +`pmap(f)`. We can say these are *rank-reducing maps* with unstacking/stacking of +inputs/outputs. + +The number of logical applications of `f` is determined by the size of the input +axis being mapped over: for example, if we map over an input axis of size 8, +semantically we get 8 logical applications of the function, which for pmap +always correspond to 8 devices physically computing them. + +In contrast, `shmap` does not have this rank-reducing behavior. Instead, we can +think of it as slicing (or "unconcatenating") along input axes into blocks, +applying the body function, and concatenating the results back together (again +when collectives aren't involved): + +```python +devices = np.array(jax.devices()[:4]) +m = Mesh(devices, ('i',)) # mesh.shape['i'] = 4 + +shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y) +== +jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)]) +``` + +Recall that `jnp.split` slices its input into equally-sized blocks with the same +rank, so that if in the above example `y` has shape `f32[8,5]` then each `y_blk` +has shape `f32[2,5]`, and if each `f(y_blk)` has shape `f32[3,7]` then the final +concatenated result `shmap(f, ...)(y)` has shape `f32[12,7]`. So `shmap` +(`shard_map`) maps over shards, or blocks, of its inputs. We can say it's a +*rank-preserving ma*p with unconcatenating/concatenating of its inputs/outputs. + +The number of logical applications of `f` is determined by the mesh size, not by +any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 +devices) then semantically we get 4 logical applications of the function, +corresponding to the 4 devices physically computing them. + +#### Controlling how each input is split (unconcatenated) and tiled with `in_specs` + +Each of the `in_specs` identifies some of the corresponding input array's axes +with mesh axes by name using `PartitionSpec`s, representing how to split (or +unconcatenate) that input into the blocks to which the body function is applied. +That identification determines the shard sizes; when an input axis is identified +with a mesh axis, the input is split (unconcatenated) along that logical axis +into a number of pieces equal to the corresponding mesh axis size. (It's an +error if the corresponding mesh axis size does not evenly divide the input array +axis size.) If an input's pspec does not mention a mesh axis name, then there's +no splitting over that mesh axis. For example: + +```python +devices = np.array(jax.devices()) +m = Mesh(devices.reshape(4, 2), ('i', 'j')) + +@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j')) +def f1(x_block): + print(x_block.shape) + return x_block + +x1 = np.arange(12 * 12).reshape(12, 12) +y = f1(x1) # prints (3,12) +``` + +Here, because the input pspec did not mention the mesh axis name `'j'`, no input +array axis is split over that mesh axis; similarly, because the second axis of +the input array is not identified with (and hence split over) any mesh axis, +application of `f1` gets a full view of the input along that axis. + +When a mesh axis is not mentioned in an input pspec, we can always rewrite to a +less efficient program where all mesh axes are mentioned but the caller performs +a `jnp.tile`, for example: + +```python +@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j')) +def f2(x_block): + print(x_block.shape) + return x_block + +x = np.arange(12 * 12).reshape(12, 12) +x_ = jnp.tile(x, (1, mesh.axis_size['j'])) # x_ has shape (12, 24) +y = f2(x_) # prints (3,12), and f1(x) == f2(x_) +``` + +In other words, because each input pspec can mention each mesh axis name zero or +one times, rather than having to mention each name exactly once, we can say that +in addition to the `jnp.split` built into its input, `shard_map` also has a +`jnp.tile` built into its input, at least logically (though the tiling may not +need to be carried out physically, depending on the arguments' physical sharding +layout). The tiling to use is not unique; we could also have tiled along the +first axis, and used the pspec `P(('j', 'i'), None)`. + +Physical data movement is possible on inputs, as each device needs to have a +copy of the appropriate data. + +#### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs` + +Analogously to the input side, each of the `out_specs` identifies some of the +corresponding output array's axes with mesh axes by name, representing how the +output blocks (one for each application of the body function, or equivalently +one for each physical device) should be assembled back together to form the +final output value. For example, in both the `f1` and `f2` examples above the +`out_specs` indicate we should form the final output by concatenating together +the block results along both axes, resulting in both cases an array `y` of shape +`(12,24)`. (It's an error if an output shape of the body function, i.e. an +output block shape, has a rank too small for the concatenation described by the +corresponding output pspec.) + +When a mesh axis name is not mentioned in an output pspec, it represents an +*un-tiling*: when the user writes an output pspec which does not mention one of +the mesh axis names, they promise that the output blocks are equal along that +mesh axis, and so only one block along that axis is used in the output (rather +than concatenating all the blocks together along that mesh axis). For example, +using the same mesh as above: + +```python +x = jnp.array([[3.]]) + +z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))() +print(z) # prints the same as jnp.tile(x, (4, 2)) + +z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))() +print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,)) + +z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))() +print(z) # prints the same as jnp.tile(x, (1, 1)), or just x +``` + +Notice that the body function closing over an array value is equivalent to +passing it as an augment with a corresponding input pspec of `P(None, None)`. As +another example, following more closely to the other examples above: + +```python +@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None)) +def f3(x_block): + return jax.lax.psum(x_block, 'j') + +x = np.arange(12 * 12).reshape(12, 12) +y3 = f3(x) +print(y3.shape) # (12,6) +``` + +Notice that the result has a second axis size of 6, half the size of the input's +second axis. In this case, the un-tile expressed by not mentioning the mesh axis +name `'j'` in the output pspec was safe because of the collective `psum`, which +ensures each output block is equal along the corresponding mesh axis. Here are +two more examples where we vary which mesh axes are mentioned in the output +pspec: + +```python +@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +def f4(x_block): + return jax.lax.psum(x_block, 'i') + +x = np.arange(12 * 12).reshape(12, 12) +y4 = f4(x) +print(y4.shape) # (3,12) + + +@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None)) +def f5(x_block): + return jax.lax.psum(x_block, ('i', 'j')) + +y5 = f5(x) +print(y5.shape) # (3,6) +``` +On the physical side, not mentioning a mesh axis name in an output pspec +assembles an `Array` from the output device buffers with replicated layout along +that mesh axis. + +There is no runtime check that the output blocks are actually equal along a mesh +axis to be un-tiled along, or equivalently that the corresponding physical +buffers have equal values and thus can be interpreted as a replicated layout for +a single logical array. But we can provide a static check mechanism which raises +an error on all potentially-incorrect programs. + +Because the `out_specs` can mention mesh axis names zero or one times, and +because they can be mentioned in any order, we can say that in addition to the +`jnp.concatenate` built into its output, `shard_map` also has both an untile and +a block transpose built into its output. + +Physical data movement is not possible on outputs, no matter the output pspec. +Instead, `out_specs` just encodes how to assemble the block outputs into +`Array`s, or physically how to interpret the buffers across devices as the +physical layout of a single logical `Array`. + +### API Specification + + +```python +from jax.sharding import Mesh +Specs = PyTree[PartitionSpec] + +def shmap(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs + ) -> Callable: + ... +``` +where: +* `mesh` encodes devices arranged in an array and with associated axis names, + just like it does for `xmap` and for `sharding.NamedSharding`; +* `in_specs` and `out_specs` are `PartitionSpec`s which can + [affinely](https://en.wikipedia.org/wiki/Substructural_type_system) mention + axis names from `mesh` (not separate logical names as in `xmap`) to express + slicing/unconcatenation and concatenation of inputs and outputs, respectively + (not unstacking and stacking like `pmap` and `xmap` do), with unmentioned + names corresponding to replication and untiling + (assert-replicated-so-give-me-one-copy), respectively; +* the shapes of the arguments passed to `f` have the same ranks as the arguments + passed to `shard_map`-of-`f` (unlike `pmap` and `xmap` where the ranks are + reduced), and the shape of an argument to `f` is computed from the shape + `shape` of the corresponding argument to `shard_map`-of-`f` and the + corresponding `PartitionSpec` spec as roughly +`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`; +* the body of `f` can apply collectives using names from `mesh`. + +`shmap` is eager by default, meaning that we dispatch computations +primitive-by-primitive, so that the user can employ Python control flow on fully +replicated values and interactive `pdb` debugging to print any values. To stage +out and end-to-end compile a `shmap`ped function, just put a `jit` around it. A +consequence is that `shmap` doesn't have its own dispatch and compilation paths +like `xmap` and `pmap` currently do; it's just the `jit` path. + +When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to MHLO +is trivial: it just involves switching into 'manual SPMD mode' on the inputs, +and switching back on the outputs. (We don't currently plan to support +partially-manual-partially-automatic modes.) + +The interaction with effects is the same as with `pmap`. + +The interaction with autodiff is also just like `pmap` (rather than attempting +the new semantics that `xmap` did, corresponding to having unmapped +intermediates and hence `grad`'s `reduce_axes` as well as making `psum` +transpose to `pbroadcast` rather than `psum`). But it thus inherits an unsolved +problem from `pmap`: in some cases, instead of transposing `psum` to `psum`, and +thus performing a backward pass `psum` corresponding to the forward pass `psum`, +it can be beneficial to move the backward pass `psum` to elsewhere in the +backward pass, exploiting linearity. Many advanced `pmap` users addressed this +challenge by using `custom_vjp` to implement `psum_idrev` and `id_psumrev` +functions, but since it's easy to accidentally leave those imbalanced, that +technique is a foot-cannon. We have some ideas on how to provide this +functionality in a safer way. + +## When should you use `shmap` and when should you use `pjit`? + +One philosophy is: it is almost always simpler to write a program in `jit==pjit` +— but if a given part of the program is less optimized by the compiler than it +could be, drop into `shmap`! + +### A realistic transformer example + +In fact, we can implement a simple version of the ["collective +matmul"](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959) algorithm +recently introduced in XLA to overlap communication and computation using `shmap` +and 30 lines of Python. The basic idea of the algorithm can be grasped with a +simple example. + +Suppose we want to compute `C = A @ B` where `A` is sharded by a 1D mesh on the +0-th dimension while `B` and `C` are replicated. + +```python +M, K, N = 4096, 2048, 1024 +A = jnp.arange(np.prod((M, K))).reshape((M, K)) +B = jnp.arange(np.prod((K, N))).reshape((K, N)) + +mesh = Mesh(np.array(jax.devices()), axis_names=('x')) +A_x = jax.device_put(A, NamedSharding(mesh, P('x', None))) + +@jax.jit +def f(lhs, rhs): + return lhs @ rhs + +C = f(A_x, B) +``` + +A profile shows the blocking all-gather across 8 devices before the matmul can +start. This is suboptimal because `A` is sharded on a non-contracting dimension, +and each shard of `A` can be matmul'ed with `B` independently and this chunked +computation can be overlapped with fetching of the next shard of `A` from +another device. + +image + +This overlap can be implemented using `shmap` and explicit collectives. + +```python +def collective_matmul_allgather_lhs_non_contracting(lhs, rhs): + # lhs is the looped operand; rhs is the local operand + axis_size = jax.lax.psum(1, axis_name='x') + axis_index = jax.lax.axis_index(axis_name='x') + chunk_size = lhs.shape[0] + + def f(i, carrys): + accum, lhs = carrys + # matmul for a chunk + update = lhs @ rhs + # circular shift to the left + lhs = jax.lax.ppermute( + lhs, + axis_name='x', + perm=[(j, (j - 1) % axis_size) for j in range(axis_size)] + ) + # device 0 computes chunks 0, 1, ... + # device 1 computes chunks 1, 2, ... + update_index = (((axis_index + i) % axis_size) * chunk_size, 0) + accum = jax.lax.dynamic_update_slice(accum, update, update_index) + return accum, lhs + + accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype) + # fori_loop cause a crash: hlo_sharding.cc:817 Check failed: !IsManual() + # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs)) + for i in range(0, axis_size - 1): + accum, lhs = f(i, (accum, lhs)) + + # compute the last chunk, without the ppermute + update = lhs @ rhs + i = axis_size - 1 + update_index = (((axis_index + i) % axis_size) * chunk_size, 0) + accum = jax.lax.dynamic_update_slice(accum, update, update_index) + + return accum +``` + +``` +jit_sharded_f = jax.jit(shard_map( + collective_matmul_allgather_lhs_non_contracting, mesh, + in_specs=(P('x', None), P()), out_specs=P())) +C = jit_sharded_f(A_x, B) +``` + +A profile shows that the all-gather is gone, and replaced with overlapped matmul +with async collective permute. This profile matches very closely with the +collective matmul paper result. + +image + +This collective matmul technique can be used to speed up feedforward blocks in +transformer layers. This typically consists of two matrix multiplications +followed by a `ReduceScatter` (to resolve partial sums from a parallelized +matrix multiplication) and preceded by an `AllGather` (to collect the sharded +dimensions along some axes and allow partial sum computation). Together, the +`ReduceScatter` from one layer and the `AllGather` for the next amount to an +`AllReduce`. + +In a typical profile, the two matmuls will be followed by an `AllReduce`, and +they will not be overlapped. Collective matmul can be used to achieve the +overlap, but is difficult to trigger, has a minimum slice size and does not yet +cover all topologies, tensor shapes and variants of collective matmul (i.e +latency and throughput optimized variants). [In a recent +paper](https://arxiv.org/abs/2211.05102), we found a ~40% gain in many +circumstances from manually implementing collective matmul variants in `shmap` +style. + +But it isn’t always more complex! We expect this to be a much more natural way +to think about pipelined computation, and plan to do some demos of that soon! + +### Another realistic example + +Here's how `shmap` might look in a transformer layer pass with a 2D weight +gathered pattern ([paper](https://arxiv.org/abs/2211.05102), Sec 3.2.3 on p. 5): + +```python +def matmul_2D_wg_manual(xnorm, q_wi, layer): + '''Calls a custom manual implementation of matmul_reducescatter''' + # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head] + # -> (matmul) + # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced} + # -> (reducescatter over x into X heads, B batches) + # -> [batch, maxlen, heads.YZX, q_wi_per_head] + with jax.named_scope('q_wi'): + xnorm = intermediate_dtype(xnorm) + q_wi = matmul_reducescatter( + 'bte,hed->bthd', + xnorm, + params.q_wi, + scatter_dimension=(0, 2), + axis_name='x', + layer=layer) + return q_wi + + +import partitioning.logical_to_physical as l2phys + +def pjit_transformer_layer( + hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray, + cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache], + x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Forward pass through a single layer, returning output, K, V.""" + + def my_layer(t, axis=0): + """Gets the parameters corresponding to a given layer.""" + return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False) + + # 2D: [batch.Z, time, embed.XY] + x = _with_sharding_constraint( + x, ('residual_batch', 'residual_time', 'residual_embed')) + xnorm = _layernorm(x) + # 2D: [batch, time, embed.X] + xnorm = _with_sharding_constraint( + xnorm, ('post_norm_batch', 'time', 'post_norm_embed')) + # jump into manual mode where you want to optimise + if manual: + q_wi = shard_map(matmul_2D_wg_manual, mesh + in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'), + l2phys('layers', 'heads', 'embed', 'q_wi_per_head')), + out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer) + else: + q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi)) + # 2D: [batch, time, heads.YZX, None] + q_wi = _with_sharding_constraint(q_wi, + ('post_norm_batch', 'time', 'heads', 'qkv')) + q = q_wi[:, :, :, :hparams.qkv] + q = _rope(sin, cos, q) + # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements + # swiGLU with full d_ff dimension, rather than 2/3 scaled + wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)] + wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):] + kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv)) + k = kv[:, :, 0, :hparams.qkv] + v = kv[:, :, 0, hparams.qkv:] + k = _rope(sin, cos, k) + + y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer)) + + y_mlp = special2.swish2(wi0) * wi1 + # 2D: [batch, time, heads.YZX, None] + y_mlp = _with_sharding_constraint(y_mlp, + ('post_norm_batch', 'time', 'heads', None)) + + y_fused = jnp.concatenate([y_att, y_mlp], axis=-1) + # do the second half of the mlp and the self-attn projection in parallel + y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo)) + # 2D: [batch.Z, time, embed.XY] + y_out = _with_sharding_constraint( + y_out, ('residual_batch', 'residual_time', 'residual_embed')) + z = y_out + x + z = _with_sharding_constraint( + z, ('residual_batch', 'residual_time', 'residual_embed')) + return z, k, v +``` + +In the profile below, both the first and second matmul were replaced by manually +lowered versions, where the compute (fusions) are fully overlapped with the +communication (ppermute)! One fun hint that we are using a latency optimised +variant is that the ppmerute pixels are jittered — because there are two +overlapping ppermutes using opposite ICI axes at the same time! + +All-to-all is much harder to overlap, so was left on the table. + +image + +## Why don't `pmap` or `xmap` already solve this? + +`pmap` was our first multi-device parallelism API. It follows the +per-device-code-and-explicit-collectives school. But it had major shortcomings +which make it unsuitable for today's programs: +* **Mapping multiple axes required nested `pmap`s.** Not only are nested `pmap`s + cumbersome to write, but also they make it difficult to control (or even + predict) the device placement of data and computation, and difficult to + preserve data sharding (see the next two bullets). Today's programs require + multiple axes of parallelism. +* **Controlling device placement was impossible.** Especially with multiple axes + of parallelism, programmers need to control how those axes are aligned with + hardware resources and their communication topologies. But (nested) `pmap` + doesn't offer control over how mapped program instances are placed on + hardware; there's just an automatic device order which the user can't control. + ([Gopher](https://arxiv.org/abs/2112.11446)'s use of `axis_index_groups` and a + single un-nested `pmap` was essentially a hack to get around this by + flattening multiple axes of parallelism down to one.) +* **`jit`/`pjit` composability.** `jit`-of-`pmap` is a performance footgun, as + is nesting `pmap`s, as is e.g. `scan`-of-`pmap`, because sharding is not + preserved when returning from an inner `pmap`. To preserve sharding we would + need pattern matching on jaxprs to ensure we're working with perfectly nested + pmaps, or a pmap just inside a `jit`. Moreover, `pjit` was no help here + because `pmap` targets XLA replicas while `pjit` targets the XLA SPMD + Partitioner, and composing those two is hard. +* **`jax.Array` compatibility (and hence `pjit` compatibility).** Because the + sharding of `pmap` outputs can't be expressed as `Shardings` / `OpShardings`, + due to `pmap`'s stacking rather than concatenative semantics, the output of a + `pmap` computation can't currently be passed to a `pjit` computation without + bouncing to host (or dispatching a reshaping computation). +* **Multi-controller semantics (and hence `pjit` compatibility).** + Multi-controller `pmap` concatenates values across controllers, which works well + but differs from single-controller `pmap`'s stacking semantics. More + practically, it precludes the use of non-fully-addressable `jax.Array` inputs + and outputs as we use with multi-controller `pjit`. +* **Eager mode.** We didn't make `pmap` eager-first, and though we eventually + (after 4+ years!) added eager operation with `disable_jit()`, the fact that + `pmap` has `jit` fused into it means it has its own compilation and dispatch + path (actually two dispatch paths: in Python for handling `Tracer`s, and in + C++ for performance on raw `Array` inputs!), a heavy implementation burden. +* **Reshapes needed in the caller.** A typical use case with `pmap` on 8 devices + might look like starting with a batch axis of size 128, reshaping it to split + into two axes with sizes (8, 16), and then `pmap`ping over the first. These + reshapes are awkward and the compiler often interprets them as copies instead + of view — increasing memory and time usage. + +These shortcomings aren't so bad when only doing batch data parallelism. But +when more parallelism is involved, `pmap` just can't cut it! + +`xmap` paved the way as a next-gen evolution of `pmap` and solved (almost) all these +issues. `shmap` follows in `xmap`'s footsteps and solves these problems in +essentially the same ways; indeed, `shmap` is like a specialized subset of `xmap` +(what some call the "hard `xmap`" subset), with a few tweaks. + +For the initial prototype, we chose to implement `shmap` as a separate primitive +from `xmap`, because limiting the set of features it supports makes it easier to +focus on the core functionality. For example, `shmap` doesn't allow unmapped +intermediates, making it easier not to worry about the interactions between +named axes and autodiff. Furthermore, not having to reason about interactions of +all pairs of features makes it easier to add capabilities beyond what's +implemented in `xmap` today, such as support for eager mode. + +Both `shmap` and `xmap` share significant portions of the lowering code. We +could consider merging both in the future, or even focusing solely on `shmap`, +depending on how the usage will evolve. diff --git a/docs/jep/index.rst b/docs/jep/index.rst index 417d2dbf5..79f3c43d6 100644 --- a/docs/jep/index.rst +++ b/docs/jep/index.rst @@ -46,6 +46,7 @@ Then create a pull request that adds a file named 10657: Sequencing side-effects in JAX <10657-sequencing-effects> 11830: `jax.remat` / `jax.checkpoint` new implementation <11830-new-remat-checkpoint> 12049: Type Annotation Roadmap for JAX <12049-type-annotations> + 14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map> Several early JEPs were converted in hindsight from other documentation, diff --git a/jax/BUILD b/jax/BUILD index 13c27fe7b..dfec1ca2b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -100,6 +100,7 @@ py_library_providing_imports_info( "experimental/pjit.py", "experimental/global_device_array.py", "experimental/multihost_utils.py", + "experimental/shard_map.py", # until checkify is moved out of experimental "experimental/checkify.py", # to avoid circular dependencies diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 81baf7b08..c06091fa9 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -63,9 +63,6 @@ def _initial_style_jaxpr(fun, in_avals): def _close_jaxpr(jaxpr): return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) -def _initial_style_staging() -> bool: - return core.thread_local_state.trace_state.initial_style - def _sum_tangents(_, x, *xs): return reduce(ad.add_tangents, xs, x) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 594fc14f8..ff04e90a8 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -49,7 +49,8 @@ from jax._src import profiler from jax._src import stages from jax._src import traceback_util from jax._src.sharding import (PmapSharding, SingleDeviceSharding, - OpShardingSharding, Sharding) + OpShardingSharding, NamedSharding, PartitionSpec, + Sharding) from jax._src.abstract_arrays import array_types from jax._src.config import config, flags from jax._src.lib.mlir import ir @@ -563,7 +564,7 @@ def jaxpr_has_primitive(jaxpr, prim_name: str): def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]: - from jax.experimental import pjit + from jax.experimental import pjit, shard_map for eqn in jaxpr.eqns: if eqn.primitive is pjit.sharding_constraint_p: @@ -571,6 +572,12 @@ def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]: elif eqn.primitive is pjit.pjit_p: yield from eqn.params['in_shardings'] yield from eqn.params['out_shardings'] + elif eqn.primitive is shard_map.shard_map_p: + def _names_to_pspec(names): + ndmin = max(names) + 1 if names else 0 + return PartitionSpec(*(names.get(i) for i in range(ndmin))) + yield from (NamedSharding(eqn.params['mesh'], _names_to_pspec(names)) + for names in eqn.params['in_names']) for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_shardings(subjaxpr) diff --git a/jax/_src/util.py b/jax/_src/util.py index 7bf8db9e9..3f3b54602 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -403,20 +403,21 @@ def tuple_delete(t, idx): class HashableFunction: """Decouples function equality and hash from its identity. - Local lambdas and functiond defs are reallocated on each function call, making + Local lambdas and function defs are reallocated on each function call, making the functions created on different calls compare as unequal. This breaks our caching logic, which should really only care about comparing the semantics and not actual identity. This class makes it possible to compare different functions based on their - semantics. The parts that are taken into account are: the bytecode of - the wrapped function (which is cached by the CPython interpreter and is stable - across the invocations of the surrounding function), and `closure` which should - contain all values in scope that affect the function semantics. In particular - `closure` should contain all elements of the function closure, or it should be - possible to derive the relevant elements of the true function closure based - solely on the contents of the `closure` argument (e.g. in case some closed-over - values are not hashable, but are entirely determined by hashable locals). + semantics. The parts that are taken into account are: the bytecode of the + wrapped function (which is cached by the CPython interpreter and is stable + across the invocations of the surrounding function), and `closure` which + should contain all values in scope that affect the function semantics. In + particular `closure` should contain all elements of the function closure, or + it should be possible to derive the relevant elements of the true function + closure based solely on the contents of the `closure` argument (e.g. in case + some closed-over values are not hashable, but are entirely determined by + hashable locals). """ def __init__(self, f, closure): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 3f315e356..d75752cdf 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1316,6 +1316,7 @@ tf_not_yet_impl = [ "for", "inspect_sharding", "io_callback", + "shard_map", # Not high priority? "after_all", diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py new file mode 100644 index 000000000..812844dd6 --- /dev/null +++ b/jax/experimental/shard_map.py @@ -0,0 +1,918 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import enum +from functools import partial, lru_cache +import inspect +import operator as op +from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence, + Set, Tuple, TypeVar, Union, Protocol) + +import numpy as np + +import jax +from jax import core +from jax.core import Tracer +from jax.sharding import NamedSharding, PartitionSpec, Mesh +from jax._src import ad_util +from jax._src import linear_util as lu +from jax._src import pjit +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src import util +from jax._src.lax import lax, parallel as lax_parallel +from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function, + memoize, partition_list, merge_lists) +from jax.api_util import flatten_fun_nokwargs, shaped_abstractify +from jax.experimental import maps +from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import xla +from jax.interpreters import pxla +from jax.interpreters import ad +from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, + tree_structure, tree_leaves) +from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef, + _generate_key_paths, KeyPath) + +P = PartitionSpec + +map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip +traceback_util.register_exclusion(__file__) + +# API + +Specs = Any # PyTree[PartitionSpec] + +@traceback_util.api_boundary +def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, + check_rep: bool = True): + if not callable(f): + raise TypeError("shard_map requires a callable for its first argument, " + f"but got {f} of type {type(f)}.") + if not isinstance(mesh, Mesh): + raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its " + f"second argument, but got {mesh} of type {type(mesh)}.") + _check_specs(SpecErrorType.input, in_specs) + _check_specs(SpecErrorType.out, out_specs) + + @traceback_util.api_boundary + def wrapped(*args): + fun = lu.wrap_init(f) + args_flat, in_tree = tree_flatten(args) + try: in_specs_flat = broadcast_prefix(in_specs, args) + except ValueError: + e, *_ = prefix_errors(in_specs, args) + raise e('shard_map in_specs') from None + _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat) + in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) + flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + + @memoize + def out_names_thunk(): + dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) + try: out_specs_flat = broadcast_prefix(out_specs, dummy) + except ValueError: + e, *_ = prefix_errors(out_specs, dummy) + raise e('shard_map out_specs') from None + return tuple(map(_canonicalize_spec, out_specs_flat)) + try: + out_flat = shard_map_p.bind( + flat_fun, *args_flat, mesh=mesh, in_names=in_names_flat, + out_names_thunk=out_names_thunk, check_rep=check_rep) + except _SpecError as e: + fails, = e.args + msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails) + if any(fail is not no_fail and not fail.shape for fail in fails): + msg += (" In particular, for rank 0 outputs which are not constant " + "over the mesh, add at least one (singleton) axis to them so " + "that they can be concatenated using out_specs.") + raise ValueError(msg) from None + except _RepError as e: + fails, = e.args + msg = _rep_error(f, mesh, out_tree(), out_specs, fails) + raise ValueError(msg) from None + return tree_unflatten(out_tree(), out_flat) + return wrapped + +# Internally use AxisNames = Dict[int, Tuple[AxisName, ...]], not PartitionSpecs +AxisName = Hashable +AxisNames = Dict[int, Tuple[AxisName, ...]] # TODO(mattjj): make it hashable +def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: + if isinstance(spec, PartitionSpec): + return {i: names if isinstance(names, tuple) else (names,) + for i, names in enumerate(spec) if names is not None} + else: + return spec + +# Error checking and messages + +SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) + +def _check_specs(error_type: SpecErrorType, specs: Any) -> None: + if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return + prefix = 'in' if error_type == SpecErrorType.input else 'out' + msgs = [f" {prefix}_specs{key.pprint()} is {x} of type {type(x).__name__}, " + for key, x in _generate_key_paths(specs) if not isinstance(x, P)] + raise TypeError( + f"shard_map {prefix}_specs argument must be a pytree of " + f"`jax.sharding.PartitionSpec` instances, but:\n\n" + + '\n\n'.join(msgs) + '\n\n' + f"Check the {prefix}_specs values passed to shard_map.") + +class NoFail: pass +no_fail = NoFail() + +def _check_specs_vs_args( + f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs, + in_specs_flat: List[P], xs: List) -> None: + in_avals = map(shaped_abstractify, xs) + fail = [a if not len(p) <= a.ndim else no_fail + for p, a in zip(in_specs_flat, in_avals)] + if any(f is not no_fail for f in fail): + msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) + raise ValueError(msg) + in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) + fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) + for d, ns in names.items()) else no_fail + for a, names in zip(in_avals, in_names_flat)] + if any(f is not no_fail for f in fail): + msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) + raise ValueError(msg) + +def _spec_rank_error( + error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, + fails: List[Union[core.ShapedArray, NoFail]]) -> str: + if error_type == SpecErrorType.input: + prefix, base = 'in', 'args' + ba = _try_infer_args(f, tree) + else: + prefix, base = 'out', f'{f.__name__}(*args)' + msgs = [] + for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): + if error_type == SpecErrorType.input and ba is not None: + arg_key, *_ = fail_key.keys + extra = (f", where {base}[{arg_key.key}] is bound to {f.__name__}'s " + f"parameter '{list(ba.arguments.keys())[arg_key.key]}',") + else: + extra = "" + msgs.append( + f"{prefix}_specs{spec_key.pprint()} is {spec} which has length " + f"{len(spec)}, but " + f"{base}{fail_key.pprint()}{extra} has shape {aval.str_short()}, " + f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") + assert msgs + msg = (f"shard_map applied to the function '{f.__name__}' was given an " + f"{prefix}_specs entry which is too long to be compatible with the " + f"corresponding {prefix}put value from the function:\n\n" + + '\n\n'.join(msgs) + '\n\n' + + f"Entries in {prefix}_specs must be of length no greater than the " + f"number of axes in the corresponding {prefix}put value.\n\n" + f"Either revise the spec to be shorter, or modify '{f.__name__}' so " + f"that its {prefix}puts have sufficient rank.") + return msg + +def _spec_divisibility_error( + f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, + fails: List[Union[core.ShapedArray, NoFail]]) -> str: + ba = _try_infer_args(f, tree) + msgs = [] + for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): + if ba is not None: + arg_key, *_ = fail_key.keys + extra = (f", where args[{arg_key.key}] is bound to {f.__name__}'s " + f"parameter '{list(ba.arguments.keys())[arg_key.key]}',") + names = _canonicalize_spec(spec) + for d, ns in names.items(): + if aval.shape[d] % prod(mesh.shape[n] for n in ns): + axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" + total = 'total ' if len(ns) > 1 else '' + sz = prod(mesh.shape[n] for n in ns) + msgs.append( + f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} " + f"corresponds to in_specs{spec_key.pprint()} of value {spec}, " + f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " + f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " + f"{aval.shape[d]}") + assert msgs + msg = (f"shard_map applied to the function '{f.__name__}' was given argument " + f"arrays with axis sizes that are not evenly divisible by the " + f"corresponding mesh axis sizes:\n\n" + f"The mesh given has shape {mesh.device_ids.shape} with corresponding " + f"axis names {mesh.axis_names}.\n\n" + + '\n\n'.join(msgs) + '\n\n' + + f"Array arguments' axis sizes must be evenly divisible by the mesh " + f"axis or axes indicated by the corresponding elements of the " + f"argument's in_specs entry. Consider checking that in_specs are " + f"correct, and if so consider changing the mesh axis sizes or else " + f"padding the input and adapting '{f.__name__}' appropriately.") + return msg + +def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, + fails: List[Union[Set, NoFail]]) -> str: + msgs = [] + for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails): + dst = _canonicalize_spec(spec) + unmentioned = _unmentioned(mesh, dst) + if len(unmentioned) > 1: + need_rep = ','.join(map(str, unmentioned)) + got_rep = ','.join(map(str, rep)) + diff = ','.join(map(str, unmentioned - rep)) + msgs.append( + f"out_specs{spec_key.pprint()} is {spec} which implies that the " + f"corresponding output value is replicated across mesh axes " + f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, " + f"which is missing the required axes {diff}") + else: + need_rep_, = unmentioned + msgs.append( + f"out_specs{spec_key.pprint()} is {spec} which implies that the " + f"corresponding output value is replicated across mesh axis " + f"'{need_rep_}', but could not infer replication over any axes") + assert msgs + msg = (f"shard_map applied to the function '{f.__name__}' was given " + f"out_specs which require replication which can't be statically " + f"inferred given the mesh:\n\n" + f"The mesh given has shape {mesh.device_ids.shape} with corresponding " + f"axis names {mesh.axis_names}.\n\n" + + '\n\n'.join(msgs) + '\n\n' + + "Check if these output values are meant to be replicated over those " + "mesh axes. If not, consider revising the corresponding out_specs " + "entries. If so, consider ddisabling the check by passing the " + "check_rep=False argument to shard_map.") + return msg + +def _unmentioned(mesh: Mesh, names: AxisNames) -> Set[AxisName]: + return set(mesh.axis_names) - {n for ns in names.values() for n in ns} + +def _try_infer_args(f, tree): + dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) + try: + return inspect.signature(f).bind(*dummy_args) + except (TypeError, ValueError): + return None + +T = TypeVar('T') +def _iter_paths(tree: PyTreeDef, specs: Specs, fails: List[Union[T, NoFail]] + ) -> List[Tuple[Tuple[KeyPath, P], Tuple[KeyPath, T]]]: + failures = tree_unflatten(tree, fails) + failures_aug = _generate_key_paths(failures) + specs_ = tree_unflatten(tree_structure(specs), _generate_key_paths(specs)) + leaf = lambda x: type(x) is tuple and len(x) == 2 and type(x[1]) is P + specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf) + return [((spec_key, spec), (fail_key, fail_data)) + for (spec_key, spec), (fail_key, fail_data) + in zip(specs_aug, failures_aug) if fail_data is not no_fail] + +# Primitive + +JaxType = Any +MaybeTracer = Union[JaxType, Tracer] + +class ShardMapPrimitive(core.Primitive): + multiple_results = True + + def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh, + in_names: Tuple[AxisNames, ...], + out_names_thunk: Callable[[], Tuple[AxisNames, ...]], + check_rep: bool) -> Sequence[MaybeTracer]: + top_trace = core.find_top_trace(args) + fun, env_todo = process_env_traces(fun, top_trace.level, mesh, + in_names, out_names_thunk, check_rep) + + @as_hashable_function(closure=out_names_thunk) + def new_out_names_thunk(): + out_names = out_names_thunk() + _, xforms = env_todo() + for t in xforms: + out_names = t(out_names) + return out_names + + tracers = map(top_trace.full_raise, args) + outs = top_trace.process_shard_map( # pytype: disable=attribute-error + shard_map_p, fun, tracers, mesh=mesh, in_names=in_names, + out_names_thunk=new_out_names_thunk, check_rep=check_rep) + todos, _ = env_todo() + return map(core.full_lower, core.apply_todos(todos, outs)) + + def get_bind_params(self, params): + """Goes from jaxpr form to python traceable form.""" + new_params = dict(params) + jaxpr = new_params.pop('jaxpr') + subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ()) + axes = new_params.pop('out_names') + new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes) + return [subfun], new_params + +shard_map_p = ShardMapPrimitive('shard_map') + +@lu.transformation_with_aux +def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep, + *args: Any): + outs = yield args, {} + todos, out_names_transforms = [], [] + while True: + tracers = [x for x in outs if isinstance(x, core.Tracer) + and (level is None or x._trace.level > level)] + if tracers: + ans = max(tracers, key=op.attrgetter('_trace.level')) + else: + break + trace = ans._trace.main.with_cur_sublevel() + outs = map(trace.full_raise, outs) + outs, (todo, xform) = trace.post_process_shard_map( + outs, mesh, in_names, out_names_thunk, check_rep) + todos.append(todo) + out_names_transforms.append(xform) + yield outs, (tuple(todos), tuple(out_names_transforms)) + +# Staging + +def _shard_map_staging( + trace: pe.DynamicJaxprTrace, prim: core.Primitive, fun: lu.WrappedFun, + in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh, + in_names: Tuple[AxisNames, ...], + out_names_thunk: Callable[[], Tuple[AxisNames, ...]], + check_rep: bool, + ) -> Sequence[pe.DynamicJaxprTracer]: + in_avals = [t.aval for t in in_tracers] + in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) + with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): + jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic( + fun, trace.main, in_avals_) + _check_names(out_names_thunk(), out_avals_) + if check_rep: + in_rep = map(partial(_in_names_to_rep, mesh), in_names) + out_rep = _output_rep(mesh, jaxpr, in_rep) + _check_reps(mesh, out_names_thunk(), out_rep) + out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_) + source_info = source_info_util.current() + out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] + invars = map(trace.getvar, in_tracers) + constvars = map(trace.getvar, map(trace.instantiate_const, consts)) + outvars = map(trace.makevar, out_tracers) + in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr = pe.convert_constvars_jaxpr(jaxpr) + params = dict(mesh=mesh, in_names=in_names_staged, + out_names=out_names_thunk(), jaxpr=jaxpr, check_rep=check_rep) + eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, + jaxpr.effects, source_info) + trace.frame.add_eqn(eqn) + return out_tracers +pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging + +def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue + ) -> core.AbstractValue: + if isinstance(aval, core.ShapedArray): + return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape))) + else: + raise NotImplementedError # TODO(mattjj): add table with handlers + +def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue + ) -> core.AbstractValue: + if isinstance(aval, core.ShapedArray): + return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)), + named_shape={k: v for k, v in aval.named_shape.items() + if k not in mesh.shape}) + else: + raise NotImplementedError # TODO(mattjj): add table with handlers + +# Type-checking + +def _shard_map_typecheck(*in_atoms, jaxpr, mesh, in_names, out_names, + check_rep): + for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): + if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)): + raise core.JaxprTypeError("shard_map argument avals not compatible with " + "jaxpr binder avals and in_names") + with core.extend_axis_env_nd(tuple(mesh.shape.items())): + core.check_jaxpr(jaxpr) + if check_rep: + in_rep = map(partial(_in_names_to_rep, mesh), in_names) + out_rep = _output_rep(mesh, jaxpr, in_rep) + for rep, dst in zip(out_rep, out_names): + if not _valid_repeats(mesh, rep, dst): + raise core.JaxprTypeError("shard_map can't prove output is sufficiently " + "replicated") + out_avals_sharded = [x.aval for x in jaxpr.outvars] + out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded) + return out_avals, jaxpr.effects +core.custom_typechecks[shard_map_p] = _shard_map_typecheck + +def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> Set[AxisName]: + return set(mesh.axis_names) - set(n for ns in names.values() for n in ns) + +def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]], + ) -> Sequence[Set[AxisName]]: + env: Dict[core.Var, Set[AxisName]] = {} + + def read(x: core.Atom) -> Set[AxisName]: + return env[x] if type(x) is core.Var else set(mesh.axis_names) + + def write(v: core.Var, val: Set[AxisName]) -> None: + env[v] = val + + map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars)) + map(write, jaxpr.invars, in_rep) + for e in jaxpr.eqns: + rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive)) + out_rep = rule(mesh, *map(read, e.invars), **e.params) + if e.primitive.multiple_results: + out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep + map(write, e.outvars, out_rep) + else: + write(e.outvars[0], out_rep) + return map(read, jaxpr.outvars) + +def _valid_repeats(mesh: Mesh, rep: Set[AxisName], dst: AxisNames) -> bool: + return _unmentioned(mesh, dst).issubset(rep) + +# Lowering + +def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, + check_rep): + del check_rep + sharded_avals = [v.aval for v in jaxpr.invars] + in_nodes_ = map(partial(_xla_shard, mesh), in_names, ctx.avals_in, + sharded_avals, in_nodes) + new_axis_context = mlir.SPMDAxisContext(mesh, frozenset(mesh.axis_names)) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + with core.extend_axis_env_nd(tuple(mesh.shape.items())): + out_nodes_, _ = mlir.jaxpr_subcomp(sub_ctx, jaxpr, mlir.TokenSet(), + (), *in_nodes_, + dim_var_values=ctx.dim_var_values) + sharded_avals = [v.aval for v in jaxpr.outvars] + return map(partial(_xla_unshard, mesh), out_names, sharded_avals, + ctx.avals_out, out_nodes_) +mlir.register_lowering(shard_map_p, _shard_map_lowering) + +def _xla_shard(mesh, names, aval_in, aval_out, x): + manual_proto = pxla._manual_proto(aval_in, frozenset(mesh.axis_names), mesh) + result_type, = mlir.aval_to_ir_types(aval_out) + axes = {name: i for i, ns in names.items() for name in ns} + sharding_proto = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)( + aval_in, axes).sharding_proto() + sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=set()) + return [mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, set())] + +def _xla_unshard(mesh, names, aval_in, aval_out, xs): + x, = xs + manual_proto = pxla._manual_proto(aval_in, frozenset(mesh.axis_names), mesh) + result_type, = mlir.aval_to_ir_types(aval_out) + sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set()) + axes = {name: i for i, ns in names.items() for name in ns} + sharding_proto = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)( + aval_out, axes).sharding_proto() + return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, set()) + +# Eager evaluation + +def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, + check_rep): + del prim + args = map(partial(_unmatch_spec, mesh), in_names, args) + in_rep = map(partial(_in_names_to_rep, mesh), in_names) + with core.new_base_main(ShardMapTrace, mesh=mesh) as main: + with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): + t = main.with_cur_sublevel() + in_tracers = map(partial(ShardMapTracer, t), in_rep, args) + ans = fun.call_wrapped(*in_tracers) + out_tracers = map(t.full_raise, ans) + outs_, out_rep = unzip2((t.val, t.rep) for t in out_tracers) + del main, t, in_tracers, ans, out_tracers + out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_] + _check_names(out_names_thunk(), out_avals) + if check_rep: _check_reps(mesh, out_names_thunk(), out_rep) + return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_) +core.EvalTrace.process_shard_map = _shard_map_impl + +def _names_to_pspec(names: AxisNames) -> PartitionSpec: + ndmin = max(names) + 1 if names else 0 + return PartitionSpec(*(names.get(i) for i in range(ndmin))) + +def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType: + with core.eval_context(): + return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x) + +def _unmatch(mesh, src_tup, x): + src = _names_to_pspec(dict(src_tup)) + dst = P(mesh.axis_names) + return shard_map(_add_singleton, mesh, (src,), dst)(x) + +def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] + ) -> None: + fail = [a if not max(n, default=0) < a.ndim else no_fail + for n, a in zip(names, avals)] + if any(f is not no_fail for f in fail): raise _SpecError(fail) +class _SpecError(Exception): pass + +def _check_reps(mesh, names, reps): + fail = [r if not _valid_repeats(mesh, r, n) else no_fail + for n, r in zip(names, reps)] + if any(f is not no_fail for f in fail): raise _RepError(fail) +class _RepError(Exception): pass + +def _match_spec(mesh: Mesh, rep: Set[AxisName], dst: AxisNames, x: JaxType + ) -> JaxType: + with core.eval_context(): + return jax.jit(HashablePartial(_match, mesh, tuple(dst.items())))(x) + +def _match(mesh, dst_tup, x): + src = P(mesh.axis_names) + dst = _names_to_pspec(dict(dst_tup)) + return shard_map(_rem_singleton, mesh, (src,), dst, check_rep=False)(x) + +def _rem_singleton(x): return x.reshape(x.shape[1:]) +def _add_singleton(x): return x.reshape(1, *x.shape) + +class ShardMapTrace(core.Trace): + mesh: Mesh + + def __init__(self, *args, mesh): + super().__init__(*args) + self.mesh = mesh + + def pure(self, val): + val_ = _unmatch_spec(self.mesh, {}, val) + return ShardMapTracer(self, set(self.mesh.axis_names), val_) + + def sublift(self, tracer): + return ShardMapTracer(self, tracer.rep, tracer.val) + + def process_primitive(self, prim, tracers, params): + in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) + f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) + with core.eval_context(), jax.disable_jit(False): + out_vals = jax.jit(f)(*in_vals) + rule = _rep_rules.get(prim, partial(_rep_rule, prim)) + out_rep = rule(self.mesh, *in_rep, **params) + if prim.multiple_results: + out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep + return map(partial(ShardMapTracer, self), out_rep, out_vals) + return ShardMapTracer(self, out_rep, out_vals) + + def process_call(self, call_primitive, fun, tracers, params): + if call_primitive is not xla.xla_call_p: raise NotImplementedError + fun, jaxpr = _grab_jaxpr_shadily(fun) # TODO remove with initial-style jit + bind = partial(call_primitive.bind, fun) # TODO caching (compat w/ jaxpr()) + fake_primitive = pxla._FakePrimitive(multiple_results=True, bind=bind) + _rep_rules[fake_primitive] = lambda *_, **__: set() + out_tracers_ = self.process_primitive(fake_primitive, tracers, params) + out_vals = [t.val for t in out_tracers_] + out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers]) + return map(partial(ShardMapTracer, self), out_rep, out_vals) + +@lu.transformation_with_aux +def _grab_jaxpr_shadily(*args): + out = yield args, {} + main = core.thread_local_state.trace_state.trace_stack.dynamic # forgive me + jaxpr, _ = main.jaxpr_stack[-1].to_jaxpr(out) + yield out, jaxpr + +class ShardMapTracer(core.Tracer): + rep: Set[AxisName] + val: JaxType + + def __init__(self, trace, rep, val): + self._trace = trace + self.rep = rep + self.val = val + + @property + def aval(self): + aval = core.get_aval(self.val) + if (isinstance(aval, core.ConcreteArray) and + self.rep == set(self._trace.mesh.axis_names)): + with core.eval_context(): + return core.get_aval(self.val[0]) + else: + aval = core.raise_to_shaped(aval) + return core.mapped_aval(self._trace.mesh.size, 0, aval) + + def full_lower(self) -> ShardMapTracer: + return self + + def __str__(self) -> str: + with core.eval_context(): + blocks = list(self.val) + mesh = self._trace.mesh + axis_names = f"({', '.join(map(str, mesh.axis_names))},)" + return '\n'.join( + f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" + for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) + +def _prim_applier(prim, params_tup, mesh, *args): + def apply(*args): + outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) + return tree_map(_add_singleton, outs) + return shard_map(apply, mesh, P(mesh.axis_names), P(mesh.axis_names))(*args) + +# Static replication checking + +def _rep_rule(prim, mesh, *in_rep, **params): + raise NotImplementedError(f"no replication rule for {prim}") + +_rep_rules: Dict[core.Primitive, Callable] = {} +register_rule = lambda prim: lambda rule: _rep_rules.setdefault(prim, rule) +register_standard = lambda prim: _rep_rules.setdefault(prim, _standard_rep_rule) + +def _standard_rep_rule(_, *in_rep, **__): + return set.intersection(*in_rep) + +for o in lax.__dict__.values(): + if isinstance(o, core.Primitive): register_standard(o) +register_standard(ad_util.add_any_p) + +register_standard(lax_parallel.ppermute_p) # doesn't change replication + +@register_rule(lax_parallel.psum_p) +def _psum_rule(_, *in_rep, axes, axis_index_groups): + if axis_index_groups is not None: raise NotImplementedError + axes = (axes,) if not isinstance(axes, tuple) else axes + return [r | set(axes) for r in in_rep] # introduces replication + +@register_rule(lax_parallel.all_gather_p) +def _all_gather_rule(_, in_rep, *, all_gather_dimension, axis_name, axis_size, + axis_index_groups, tiled): + if axis_index_groups is not None: raise NotImplementedError + if not tiled: raise NotImplementedError + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + return in_rep | set(axis_name) # introduces replication + +@register_rule(lax_parallel.reduce_scatter_p) +def _reduce_scatter_rule(_, in_rep, *, scatter_dimension, axis_name, axis_size, + axis_index_groups, tiled): + if axis_index_groups is not None: raise NotImplementedError + if not tiled: raise NotImplementedError + return in_rep - {axis_name} # removes replication + +@register_rule(lax_parallel.all_to_all_p) +def _all_to_all_rule(_, in_rep, *, split_axis, concat_axis, axis_name, + axis_index_groups): + if axis_index_groups is not None: raise NotImplementedError + return in_rep - {axis_name} # removes replication + +@register_rule(pjit.pjit_p) +def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs): + return _output_rep(mesh, jaxpr.jaxpr, in_rep) + +# Batching + +def _shard_map_batch( + trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, + in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, + in_names: Tuple[AxisNames, ...], + out_names_thunk: Callable[[], Tuple[AxisNames, ...]], + check_rep: bool) -> Sequence[batching.BatchTracer]: + in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers) + if all(bdim is batching.not_mapped for bdim in in_dims): + return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names, + out_names_thunk=out_names_thunk, check_rep=check_rep) + if trace.spmd_axis_name is not None: + raise NotImplementedError # TODO add named axis to specs + if any(isinstance(d, batching.ConcatAxis) for d in in_dims): + raise NotImplementedError + fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims)) + new_in_names = [{ax + (d is not batching.not_mapped and ax <= d): names[ax] # type: ignore + for ax in names} for names, d in zip(in_names, in_dims)] + @as_hashable_function(closure=out_names_thunk) + def new_out_names_thunk(): + out_names = out_names_thunk() + return [{ax + (d is not batching.not_mapped and ax <= d): names[ax] + for ax in names} for names, d in zip(out_names, out_dims())] + new_params = dict(mesh=mesh, in_names=new_in_names, + out_names_thunk=new_out_names_thunk, check_rep=check_rep) + out_vals = prim.bind(fun, *in_vals, **new_params) + make_tracer = partial(batching.BatchTracer, trace, + source_info=source_info_util.current()) + return map(make_tracer, out_vals, out_dims()) +batching.BatchTrace.process_shard_map = _shard_map_batch + +# Autodiff + +def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, + out_names_thunk, check_rep): + primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) + which_nz = [ type(t) is not ad.Zero for t in tangents] + tangents = [t if type(t) is not ad.Zero else None for t in tangents] + args, in_tree = tree_flatten((primals, tangents)) + f_jvp = ad.jvp_subtrace(f, trace.main) + f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) + tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] + + @as_hashable_function(closure=out_names_thunk) + def new_out_names_thunk(): + out_ax = out_names_thunk() + return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) + params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), + out_names_thunk=new_out_names_thunk, check_rep=check_rep) + f_jvp, out_tree = ad.traceable(f_jvp, in_tree) + result = shard_map_p.bind(f_jvp, *args, **params) + primal_out, tangent_out = tree_unflatten(out_tree(), result) + tangent_out = [ad.Zero(ad.get_aval(p).at_least_vspace()) if t is None else t + for p, t in zip(primal_out, tangent_out)] + return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] +ad.JVPTrace.process_shard_map = _shard_map_jvp + +def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names, + out_names_thunk, check_rep): + del mesh, in_names, out_names_thunk, check_rep + primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) + out, treedef = tree_flatten((primals, tangents)) + tangents_nz = [type(t) is not ad.Zero for t in tangents] + m = trace.main + def todo(x): + primals, tangents = tree_unflatten(treedef, x) + return map(partial(ad.JVPTracer, m.with_cur_sublevel()), primals, tangents) + def out_names_transform(out_names): + return (*out_names, *(n for n, nz in zip(out_names, tangents_nz) if nz)) + return out, (todo, out_names_transform) +ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process + +def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, + out_names_thunk, check_rep): + in_pvals = [t.pval for t in tracers] + in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) + unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) + in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals) + f = pe.trace_to_subjaxpr_nounits(f, trace.main, False) + f = _promote_scalar_residuals(f) + f_known, aux = pe.partial_eval_wrapper_nounits( + f, (*in_knowns,), (*in_avals_sharded,)) + unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) + + @as_hashable_function(closure=out_names_thunk) + def known_out_names(): + out_knowns, _, jaxpr, _ = aux() + _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) + assert not any(not v.aval.shape for v in jaxpr.constvars) + res_names = ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars) + return (*out_known_names, *res_names) + + known_params = dict(mesh=mesh, in_names=(*known_in_names,), + out_names_thunk=known_out_names, check_rep=check_rep) + out = shard_map_p.bind(f_known, *in_consts, **known_params) + out_knowns, out_avals_sharded, jaxpr, env = aux() + out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)]) + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr = pe.convert_constvars_jaxpr(jaxpr) + unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk()) + unk_in_names = (({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env) + + (*unk_in_names,)) + const_tracers = map(trace.new_instantiated_const, res) + env_tracers = map(trace.full_raise, env) + unk_arg_tracers = [t for t in tracers if not t.is_known()] + unk_params = dict(mesh=mesh, in_names=unk_in_names, + out_names=unk_out_names, jaxpr=jaxpr, check_rep=False) + out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded) + out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) + for a in out_avals] + eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), # type: ignore[arg-type] + out_tracers, shard_map_p, unk_params, + jaxpr.effects, source_info_util.current()) + for t in out_tracers: t.recipe = eqn + return pe.merge_lists(out_knowns, out_tracers, out_consts) +pe.JaxprTrace.process_shard_map = _shard_map_partial_eval + +def _shard_map_partial_eval_post_process( + trace, tracers, mesh, in_names, out_names_thunk, check_rep): + del check_rep + unk_tracers = [t for t in tracers if not t.is_known()] + jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers) + out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers]) + out = [*consts, *res] + main = trace.main + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr_ = pe.convert_constvars_jaxpr(jaxpr) + + def todo(out): + trace = main.with_cur_sublevel() + out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)]) + const_tracers = map(trace.new_instantiated_const, res) + env_tracers = map(trace.full_raise, env) + + staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env) + staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names, + out_names=(*out_names_unknown,), check_rep=False) + + out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_) + out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) + for a in out_avals] + name_stack = trace._current_truncated_name_stack() + source = source_info_util.current().replace(name_stack=name_stack) + eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers, + shard_map_p, staged_params, jaxpr.effects, source) + for t in out_tracers: t.recipe = eqn + return merge_lists(out_knowns, out_tracers, out_consts) + + def out_names_transform(out_names): + nonlocal out_names_unknown + out_names_unknown, out_names_known = partition_list(out_knowns, out_names) + return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars) + out_names_unknown: Optional[list] = None + + return out, (todo, out_names_transform) +pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process + +@lu.transformation +def _promote_scalar_residuals(*args, **kwargs): + jaxpr, (out_pvals, out_consts, env) = yield args, kwargs + which_scalar = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape + for v in jaxpr.constvars] + out_consts_ = [jax.lax.broadcast(x, (1,)) if scalar else x + for x, scalar in zip(out_consts, which_scalar)] + + @lu.wrap_init + def fun(*args): + out_consts = [x.reshape(*x.shape[1:]) if scalar else x + for x, scalar in zip(out_consts_, which_scalar)] + return core.eval_jaxpr(jaxpr, out_consts, *args) + in_avals = [v.aval for v in jaxpr.invars] + jaxpr, _, out_consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) + yield jaxpr, (out_pvals, out_consts, env) + +def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, + check_rep): + mb_div = lambda x, y: x / y if y != 1 else x + out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero + else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns)))) + for ns, x in zip(out_names, out_cts)] + args = [x if type(x) is not ad.UndefinedPrimal else + ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) + for ns, x in zip(in_names, args)] + all_args, in_tree = tree_flatten((out_cts, args)) + + @lu.wrap_init + def fun_trans(out_cts, args): + res, undefs = partition_list(map(ad.is_undefined_primal, args), args) + jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits( + pe.close_jaxpr(jaxpr), map(ad.is_undefined_primal, args), False) + res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res) + out = ad.backward_pass(jaxpr_unknown.jaxpr, set(), False, (), + (*res_reshaped, *undefs), out_cts) + return [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero + else jax.lax.psum(x, tuple(_unmentioned(mesh, ns))) + for ns, x in zip(in_names, out)] + + fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) + fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree) + + new_in_names = \ + [n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \ + [n for n, x in zip(in_names, args) if type(x) is not ad.UndefinedPrimal] + + def new_out_names_thunk(): + return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz) + + out_flat = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), + out_names_thunk=new_out_names_thunk, check_rep=check_rep) + return tree_unflatten(out_tree(), out_flat) +ad.primitive_transposes[shard_map_p] = _shard_map_transpose + +def _shard_map_axis_subst(params, subst, traverse): + if 'jaxpr' not in params: + return params + if not traverse: + return params + def shadowed_subst(name): + return (name,) if name in params['mesh'].shape else subst(name) + with core.extend_axis_env_nd(params['mesh'].shape.items()): + new_jaxpr = core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst) + return dict(params, jaxpr=new_jaxpr) +core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst + +# TODO(mattjj): move this to _src/util.py +class HashablePartial: + def __init__(self, f, *args, **kwargs): + self.f = f + self.args = args + self.kwargs = kwargs + + def __eq__(self, other): + return (type(other) is HashablePartial and + self.f.__code__ == other.f.__code__ and + self.args == other.args and self.kwargs == other.kwargs) + + def __hash__(self): + return hash((self.f.__code__, self.args, tuple(self.kwargs.items()))) + + def __call__(self, *args, **kwargs): + return self.f(*self.args, *args, **self.kwargs, **kwargs) diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index a38346b11..2fee8a5f5 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -542,6 +542,7 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule, # this expansion. for p in xb.expand_platform_alias(platform): _platform_specific_lowerings[p][prim] = rule + return rule def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index d81bc830f..36c7f7c97 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -426,7 +426,9 @@ def jaxpr_collectives(jaxpr): xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call') xla_call = xla_call_p.bind -def _xla_call_partial_eval_update_params(params, kept_inputs, num_new_inputs): +def _xla_call_partial_eval_update_params( + params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int + ) -> core.ParamDict: donated_invars = params['donated_invars'] if not kept_inputs and donated_invars: # JaxprTrace.post_process_call creates a call with no input tracers diff --git a/tests/BUILD b/tests/BUILD index a56df29ea..c443e66e2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1007,6 +1007,11 @@ jax_test( }, ) +jax_test( + name = "shard_map_test", + srcs = ["shard_map_test.py"], +) + jax_test( name = "clear_backends_test", srcs = ["clear_backends_test.py"], diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py new file mode 100644 index 000000000..8607940b5 --- /dev/null +++ b/tests/shard_map_test.py @@ -0,0 +1,382 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +import os +import unittest + +from absl.testing import absltest +import numpy as np + +import jax +from jax import lax +from jax.config import config +from jax.experimental.maps import Mesh +from jax.experimental.pjit import PartitionSpec as P +from jax._src import test_util as jtu +from jax._src.lib import xla_bridge +import jax.numpy as jnp + +from jax.experimental.shard_map import shard_map + +config.parse_flags_with_absl() + +# Helper for some tests. +def create_inputs(a_sharding, b_sharding): + x, y, z = 2, 2, 2 # pylint: disable=invalid-name + devices = np.array(jax.devices()[:x * y * z]).reshape((x, y, z)) + mesh = Mesh(devices, axis_names=('x', 'y', 'z')) + b, e, f = 8, 8, 8 # pylint: disable=invalid-name + m1 = jax.device_put( + jnp.arange(b * e).reshape((b, e)), + jax.sharding.NamedSharding(mesh, a_sharding)) + m2 = jax.device_put( + jnp.arange(e * f).reshape((e, f)), + jax.sharding.NamedSharding(mesh, b_sharding)) + return mesh, m1, m2 + +# Run all tests with 8 CPU devices. +prev_xla_flags = None + +# Run all tests with 8 CPU devices. +def setUpModule(): + global prev_xla_flags + prev_xla_flags = os.getenv("XLA_FLAGS") + flags_str = prev_xla_flags or "" + # Don't override user-specified device count, or other XLA flags. + if "xla_force_host_platform_device_count" not in flags_str: + os.environ["XLA_FLAGS"] = (flags_str + + " --xla_force_host_platform_device_count=8") + # Clear any cached backends so new CPU backend will pick up the env var. + xla_bridge.get_backend.cache_clear() + + if len(jax.devices()) < 8: + raise unittest.SkipTest("tests require 8 devices") + if not jax.config.jax_array: + raise unittest.SkipTest("test requires jax_array") + +# Reset to previous configuration in case other test modules will be run. +def tearDownModule(): + if prev_xla_flags is None: + del os.environ["XLA_FLAGS"] + else: + os.environ["XLA_FLAGS"] = prev_xla_flags + xla_bridge.get_backend.cache_clear() + + +class ShardMapTest(jtu.JaxTestCase): + + def test_identity(self): + mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None)) + assert a.device_buffers[0].shape == (4, 2) + + def identity(x): + return x + + @jax.jit + def fwd(a): + c = shard_map( + lambda x: x, + mesh, + in_specs=(P('z', ('x', 'y')),), + out_specs=P('z', ('x', 'y')))(a) + return c + + c = fwd(a) + self.assertEqual(c.device_buffers[0].shape, (4, 2)) + + def test_all_gather(self): + mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None)) + assert a.device_buffers[0].shape == (4, 2) + + @jax.jit + @partial(shard_map, mesh=mesh, + in_specs=(P('z', ('x', 'y')),), out_specs=P(None, ('x', 'y'))) + def fwd(a): + return lax.all_gather(a, 'z', axis=0, tiled=True) + + c = fwd(a) + self.assertEqual(c.device_buffers[0].shape, (8, 2)) + + def test_matmul_partial(self): + raise unittest.SkipTest("invalid replication asserted by out_spec?") + + mesh, a, b = create_inputs(P('z', 'y'), P('y', None)) + assert a.device_buffers[0].shape == (4, 4) + + @jax.jit + @partial(shard_map, mesh=mesh, + in_specs=(P('z', 'y'), P('y', None)), out_specs=P('z', None)) + def fwd(a): + c = jnp.matmul(a, b) # [B.z, F] {y.unreduced} + return c + + c = fwd(a) + self.assertEqual(c.device_buffers[0].shape, (4, 8)) + + def test_matmul_reduce_scatter(self): + mesh, a, b = create_inputs(P('z', 'y'), P('y', None)) + assert a.device_buffers[0].shape == (4, 4) + + @jax.jit + @partial(shard_map, mesh=mesh, + in_specs=(P('z', 'y'), P('y', None)), + out_specs=P(('z', 'y'), None)) + def fwd(a, b): + c = jnp.matmul(a, b) # [B.z, F] {y.unreduced} + return lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True) + + c = fwd(a, b) + self.assertEqual(c.device_buffers[0].shape, (2, 8)) + + def test_collective_permute(self): + devices = np.array(jax.devices()) + mesh = Mesh(devices, axis_names=('x')) + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None))) + + @jax.jit + @partial(shard_map, mesh=mesh, in_specs=(P('x', None),), + out_specs=P('x', None)) + def fwd(a): + axis_size = lax.psum(1, 'x') + perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] + return lax.ppermute(a, 'x', perm=perm) + + c = fwd(a) + self.assertAllClose(c[1, :], a[0, :]) + + def test_all_to_all(self): + devices = np.array(jax.devices()) + mesh = Mesh(devices, axis_names=('x')) + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None))) + + @jax.jit + @partial(shard_map, mesh=mesh, + in_specs=(P('x', None),), out_specs=P(None, 'x')) + def fwd(a): + return lax.all_to_all(a, 'x', split_axis=1, concat_axis=1, tiled=True) + + c = fwd(a) + assert (c == jnp.reshape(a.T, (1, 64))).all() + + def test_eager_repr(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + s = None + + @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y')) + def f(x): + nonlocal s + s = str(x) + return x + _ = f(np.arange(8 * 8.).reshape(8, 8)) + + self.assertIsInstance(s, str) + self.assertIn('at mesh coordinates', s) + + def test_jvp_basic(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) + args = np.arange(4 * 4.).reshape(4, 4), + jtu.check_grads(g, args, 2, ['fwd']) + jtu.check_grads(jax.jit(g), args, 2, ['fwd']) + + def test_linearize_basic(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) + x = np.arange(4 * 4.).reshape(4, 4) + + y, y_dot = jax.jvp(g, [x], [x]) + + y_, g_lin = jax.linearize(g, x) + y_dot_ = g_lin(x) + + self.assertAllClose(y, y_, check_dtypes=False) + self.assertAllClose(y_dot, y_dot_, check_dtypes=False) + + def test_linearize_basic_repres(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh, + in_specs=(P('x',),), out_specs=P('x',)) + x = np.arange(4.) + + y, y_dot = jax.jvp(g, [x], [x]) + + y_, g_lin = jax.linearize(g, x) + y_dot_ = g_lin(x) + + self.assertAllClose(y, y_, check_dtypes=False) + self.assertAllClose(y_dot, y_dot_, check_dtypes=False) + + def test_linearize_basic_repres_jit(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + in_specs=(P('x',),), out_specs=P('x',)) + x = np.arange(4.) + + y, y_dot = jax.jvp(g, [x], [x]) + + y_, g_lin = jax.linearize(g, x) + y_dot_ = g_lin(x) + + self.assertAllClose(y, y_, check_dtypes=False) + self.assertAllClose(y_dot, y_dot_, check_dtypes=False) + + def test_replication_checker_eager(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + x = np.arange(8 * 8.).reshape(8, 8) + + def f(x): + return 2 * x + def g(x): + return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + + with self.assertRaisesRegex(ValueError, 'statically inferred'): + g(x) + + def f2(x): + return jax.lax.psum(x, 'x') + def g2(x): + return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + _ = g2(x) # doesn't crash + + def test_replication_checker_jit(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + x = np.arange(8 * 8.).reshape(8, 8) + + def f(x): + return 2 * x + def g(x): + return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + + with self.assertRaisesRegex(ValueError, 'statically inferred'): + jax.jit(g)(x) + + def f2(x): + return jax.lax.psum(x, 'x') + def g2(x): + return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + _ = jax.jit(g2)(x) # doesn't crash + + def test_process_env_traces(self): + mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) + x = np.arange(8.) + + def g(x): + y = (3. * x).sum() + z = shard_map(lambda x: 2 * x * y, mesh, + in_specs=(P('x'),), out_specs=P('x'))(np.arange(8.)) + return z + + jtu.check_grads(g, (x,), modes=['fwd'], order=2) + + def test_eager_control_flow(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + x = jnp.arange(2 * 2.).reshape(2, 2) + + def f(x): + y = jax.lax.psum(x, ('x', 'y')) + if y < 0: + return x + else: + return -x + + def g(x): + return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x) + y = g(x) + self.assertAllClose(y, -x, check_dtypes=False) + + def test_outer_jit_detects_shard_map_mesh(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x')) + _ = jax.jit(f)(jnp.array(2.0)) # doesnt crash + + def test_vmap_basic(self): + if jax.config.jax_jit_pjit_api_merge: + raise unittest.SkipTest("pjit batcher error") # TODO(mattjj) + + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + x = jnp.arange(8 * 8.).reshape(8, 8) + + def g(x): + return shard_map(lambda x: 2. * x, mesh, + in_specs=P('y'), out_specs=P('y'))(x) + y = jax.vmap(g, axis_name='x')(x) + self.assertAllClose(y, 2 * x, check_dtypes=False) + + def test_tree_prefix_error(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + + @partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y')) + def f(x): + return x + + x = jnp.arange(8 * 8.).reshape(8, 8) + with self.assertRaisesRegex(ValueError, r'shard_map in_specs\[0\]'): + f([x, x]) + + def test_rank_errors(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + + def foo(): + return {'hi': [3.]} + + with self.assertRaisesRegex(ValueError, 'which has length 1'): + shard_map(foo, mesh=mesh, in_specs=(), out_specs={'hi': P('x')})() + + with self.assertRaisesRegex(ValueError, 'which has length 1'): + jax.jit(lambda: shard_map(foo, mesh=mesh, + in_specs=(), out_specs={'hi': P('x')})())() + + with self.assertRaisesRegex(ValueError, 'which has rank 0'): + shard_map(foo, mesh=mesh, in_specs=({'hi': P('x')},), out_specs=())( + {'hi': [jnp.array(3.)]}) + + def test_reverse_mode_ad(self): + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + + @jax.jit + @partial(shard_map, mesh=mesh, + in_specs=(P('x',), P(None)), out_specs=P('x',)) + def f(x, y): + return jnp.sin(x) + 3 + jnp.tan(2.) * jnp.cos(x) + y + + x = jnp.arange(8.) / 10. + y = jnp.arange(4.) / 10. + jtu.check_grads(f, (x, y), modes=['fwd', 'rev'], order=2) + + def test_post_process(self): + # JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + + def f(x): + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def g(y): + return jnp.sin(y) * jnp.sin(x).sum() + return g(jnp.arange(8.)) + + x = jnp.arange(8.) + _, f_lin = jax.linearize(f, x) + y_dot = f_lin(x) + + y_dot_expected = jnp.sin(jnp.arange(8.)) * (jnp.cos(x) * x).sum() + self.assertAllClose(y_dot, y_dot_expected, check_dtypes=False) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader())