shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
parent fcb9dfb080
commit ff1e9b3973
docs/jep/14273-shard-map.md (new file, 640 lines)
# `shmap` (`shard_map`) for simple per-device code

*sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@*

*January 2023*

## Motivation

JAX supports two schools of thought for multi-device programming:
1. **Compiler, take the wheel!** Let the compiler automatically partition bulk
   array functions over devices.
2. **Just let me write what I mean, damnit!** Give me per-device code and
   explicit communication collectives.

We need great APIs for both, and rather than being mutually exclusive
alternatives, they need to compose with each other.

With `pjit` (now just `jit`) we have [a next-gen
API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
for the first school. But we haven't quite leveled-up the second school. `pmap`
follows the second school, but over time we found it has [fatal
flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws,
but it doesn't quite give us per-device shapes, and it includes several other
big ideas too. Meanwhile, new demands for per-device explicit-collectives
programming have emerged, like in [Efficiently Scaling Transformer
Inference](https://arxiv.org/abs/2211.05102).

We can level-up the second school with `shmap`. `shmap` is:
* a simple multi-device parallelism API which lets us write per-device code with
  explicit collectives, where logical shapes match per-device physical buffer
  shapes and collectives correspond exactly to cross-device communication;
* a specialization of `xmap` with scaled-back features and a few tweaks;
* a fairly direct surfacing of the XLA SPMD Partitioner's 'manual' mode;
* a fun-to-say Seussian name which could stand for `shard_map`,
  `shpecialized_xmap`, `sholto_map`, or `sharad_map`.

**For `pjit` users**, `shmap` is a complementary tool. It can be used inside a
`pjit` computation to drop temporarily into a "manual collectives" mode, like an
escape hatch from the compiler's automatic partitioning. That way, users get the
convenience and familiar just-NumPy programming model of `pjit` for most of their
code, along with the ability to hand-optimize collective communication with
`shmap` wherever it's needed. It's the best of both worlds!

**For `pmap` users**, `shmap` is a strict upgrade. It's more expressive,
performant, and composable with other JAX APIs, without making basic batch data
parallelism any harder.

For more on practical use, you can jump to [When should you use `shmap` and when
should you use `pjit`?](#when-should-you-use-shmap-and-when-should-you-use-pjit).
If you're wondering why we need a new thing at all, or what
the problems with `pmap` are, jump to [Why don't `pmap` or `xmap` already solve
this?](#why-dont-pmap-or-xmap-already-solve-this).
Or keep reading the next section to see some `shmap` examples and the API spec.

## So, let's see `shmap`!

### TL;DR example (with a more detailed explanation to follow)

Sho shick:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map as shmap

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'y')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]
```

Notice:
* no nesting needed (or `axis_index_groups`) for multiple axes of parallelism,
  unlike `pmap`;
* no reshapes in the caller, unlike `pmap` and hard-`xmap`, and logical shapes
  correspond to per-device physical shapes, unlike (non-hard) `xmap`;
* precise device placement control by using `mesh`, unlike `pmap`;
* there's only one set of axis names for logical and physical, unlike `xmap`;
* the result is a `jax.Array` which could be efficiently passed to a `pjit`,
  unlike `pmap`;
* this same code works efficiently inside a `pjit`/`jit`, unlike `pmap` (see the
  sketch just after this list);
* this code works eagerly, so we can `pdb` in the middle and print values,
  unlike `xmap`'s current implementation (though by design `xmap` without the
  sequential schedule can in principle work eagerly too).
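
For example, the same `matmul_basic` can be dropped unchanged into a larger
jitted computation; a minimal sketch reusing the arrays defined above:

```python
# shmap composes with jit: ordinary jnp code surrounds the manual-collectives
# region, and the whole thing is staged out and compiled together.
@jax.jit
def step(a, b):
  c = matmul_basic(a, b)   # per-device code with explicit collectives
  return jnp.tanh(c)       # back to the usual NumPy-level programming model

out = step(a, b)           # out: f32[8, 32]
```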

Here's another matmul variant with a fully sharded result:

```python
@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', 'y'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
  c_block = jax.lax.psum_scatter(c_partialsum, 'y', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)
```

### Slow down, start with the basics!

#### Rank-reducing vs rank-preserving maps over array axes

We can think of `pmap` (and `vmap` and `xmap`) as unstacking each array input
along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body
function to each piece, and stacking the results back together, at least when
collectives aren't involved:

```python
pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])
```

For example, if `xs` had shape `f32[8,5]` then each `x` has shape `f32[5]`, and
if each `f(x)` has shape `f32[3,7]` then the final stacked result `pmap(f)(xs)`
has shape `f32[8,3,7]`. That is, each application of the body function `f` takes
as argument inputs with one fewer axis than the corresponding argument to
`pmap(f)`. We can say these are *rank-reducing maps* with unstacking/stacking of
inputs/outputs.

The number of logical applications of `f` is determined by the size of the input
axis being mapped over: for example, if we map over an input axis of size 8,
semantically we get 8 logical applications of the function, which for `pmap`
always correspond to 8 devices physically computing them.
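
To make that shape bookkeeping concrete, here is a minimal sketch using
`jax.vmap`, which shares the same unstack/apply/stack semantics but runs on a
single device:

```python
import jax
import jax.numpy as jnp

def f(x):                                # x: f32[5], one fewer axis than xs
  return jnp.outer(x[:3], jnp.ones(7))   # f32[3,7]

xs = jnp.arange(8 * 5.).reshape(8, 5)    # f32[8,5]
ys = jax.vmap(f)(xs)                     # 8 logical applications of f
print(ys.shape)                          # (8, 3, 7)
```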

In contrast, `shmap` does not have this rank-reducing behavior. Instead, we can
think of it as slicing (or "unconcatenating") along input axes into blocks,
applying the body function, and concatenating the results back together (again
when collectives aren't involved):

```python
devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
```

Recall that `jnp.split` slices its input into equally-sized blocks with the same
rank, so that if in the above example `y` has shape `f32[8,5]` then each `y_blk`
has shape `f32[2,5]`, and if each `f(y_blk)` has shape `f32[3,7]` then the final
concatenated result `shmap(f, ...)(y)` has shape `f32[12,7]`. So `shmap`
(`shard_map`) maps over shards, or blocks, of its inputs. We can say it's a
*rank-preserving map* with unconcatenating/concatenating of its inputs/outputs.

The number of logical applications of `f` is determined by the mesh size, not by
any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4
devices) then semantically we get 4 logical applications of the function,
corresponding to the 4 devices physically computing them.

#### Controlling how each input is split (unconcatenated) and tiled with `in_specs`

Each of the `in_specs` identifies some of the corresponding input array's axes
with mesh axes by name using `PartitionSpec`s, representing how to split (or
unconcatenate) that input into the blocks to which the body function is applied.
That identification determines the shard sizes; when an input axis is identified
with a mesh axis, the input is split (unconcatenated) along that logical axis
into a number of pieces equal to the corresponding mesh axis size. (It's an
error if the corresponding mesh axis size does not evenly divide the input array
axis size.) If an input's pspec does not mention a mesh axis name, then there's
no splitting over that mesh axis. For example:

```python
devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))

@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)
  return x_block

x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1)  # prints (3,12)
```

Here, because the input pspec did not mention the mesh axis name `'j'`, no input
array axis is split over that mesh axis; similarly, because the second axis of
the input array is not identified with (and hence split over) any mesh axis,
application of `f1` gets a full view of the input along that axis.

When a mesh axis is not mentioned in an input pspec, we can always rewrite to a
less efficient program where all mesh axes are mentioned but the caller performs
a `jnp.tile`, for example:

```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, m.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
```

In other words, because each input pspec can mention each mesh axis name zero or
one times, rather than having to mention each name exactly once, we can say that
in addition to the `jnp.split` built into its input, `shard_map` also has a
`jnp.tile` built into its input, at least logically (though the tiling may not
need to be carried out physically, depending on the arguments' physical sharding
layout). The tiling to use is not unique; we could also have tiled along the
first axis, and used the pspec `P(('j', 'i'), None)`.
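
To make that alternative concrete, here is a minimal sketch (assuming the same
`m` and `x` as above): tile along the first axis instead, and split that axis
over both mesh axes.

```python
@partial(shard_map, mesh=m, in_specs=P(('j', 'i'), None),
         out_specs=P(('j', 'i'), None))
def f2_alt(x_block):
  print(x_block.shape)  # (3, 12): each application sees the same block as in f1
  return x_block

x_alt = jnp.tile(x, (m.shape['j'], 1))  # x_alt has shape (24, 12)
y_alt = f2_alt(x_alt)
```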

Physical data movement is possible on inputs, as each device needs to have a
copy of the appropriate data.

#### Controlling how each output is assembled by concatenation, block transposition, and untiling using `out_specs`

Analogously to the input side, each of the `out_specs` identifies some of the
corresponding output array's axes with mesh axes by name, representing how the
output blocks (one for each application of the body function, or equivalently
one for each physical device) should be assembled back together to form the
final output value. For example, in both the `f1` and `f2` examples above the
`out_specs` indicate we should form the final output by concatenating together
the block results along both axes, resulting in both cases an array `y` of shape
`(12,24)`. (It's an error if an output shape of the body function, i.e. an
output block shape, has a rank too small for the concatenation described by the
corresponding output pspec.)

When a mesh axis name is not mentioned in an output pspec, it represents an
*un-tiling*: when the user writes an output pspec which does not mention one of
the mesh axis names, they promise that the output blocks are equal along that
mesh axis, and so only one block along that axis is used in the output (rather
than concatenating all the blocks together along that mesh axis). For example,
using the same mesh as above:

```python
x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x
```

Notice that the body function closing over an array value is equivalent to
passing it as an argument with a corresponding input pspec of `P(None, None)`. As
another example, following more closely to the other examples above:

```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)  # (12,6)
```

Notice that the result has a second axis size of 6, half the size of the input's
second axis. In this case, the un-tile expressed by not mentioning the mesh axis
name `'j'` in the output pspec was safe because of the collective `psum`, which
ensures each output block is equal along the corresponding mesh axis. Here are
two more examples where we vary which mesh axes are mentioned in the output
pspec:

```python
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
```

On the physical side, not mentioning a mesh axis name in an output pspec
assembles an `Array` from the output device buffers with replicated layout along
that mesh axis.

There is no runtime check that the output blocks are actually equal along a mesh
axis to be un-tiled along, or equivalently that the corresponding physical
buffers have equal values and thus can be interpreted as a replicated layout for
a single logical array. But we can provide a static check mechanism which raises
an error on all potentially-incorrect programs.
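
In the prototype in this commit, that static check is controlled by the
`check_rep` argument of `shard_map` (on by default). A minimal sketch of opting
out, reusing the mesh `m` from above:

```python
# The body does no reduction over 'j', so the un-tile implied by
# out_specs=P('i', None) can't be statically verified; check_rep=False skips
# the static check and leaves the promise of replication to the user.
f_unchecked = shard_map(lambda x_block: x_block, mesh=m,
                        in_specs=P('i', 'j'), out_specs=P('i', None),
                        check_rep=False)
```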

Because the `out_specs` can mention mesh axis names zero or one times, and
because they can be mentioned in any order, we can say that in addition to the
`jnp.concatenate` built into its output, `shard_map` also has both an untile and
a block transpose built into its output.
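
A minimal sketch of that block transposition, again with the 4x2 mesh `m`:

```python
# in_specs splits axis 0 over 'i' and axis 1 over 'j'; out_specs reassembles the
# (3, 6) blocks with the mesh axes swapped, concatenating axis 0 over 'j' and
# axis 1 over 'i'.
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('j', 'i'))
def f6(x_block):
  return x_block

y6 = f6(np.arange(12 * 12).reshape(12, 12))
print(y6.shape)  # (6, 24)
```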

Physical data movement is not possible on outputs, no matter the output pspec.
Instead, `out_specs` just encodes how to assemble the block outputs into
`Array`s, or physically how to interpret the buffers across devices as the
physical layout of a single logical `Array`.

### API Specification

```python
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shmap(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
          ) -> Callable:
  ...
```
where:
* `mesh` encodes devices arranged in an array and with associated axis names,
  just like it does for `xmap` and for `sharding.NamedSharding`;
* `in_specs` and `out_specs` are `PartitionSpec`s which can
  [affinely](https://en.wikipedia.org/wiki/Substructural_type_system) mention
  axis names from `mesh` (not separate logical names as in `xmap`) to express
  slicing/unconcatenation and concatenation of inputs and outputs, respectively
  (not unstacking and stacking like `pmap` and `xmap` do), with unmentioned
  names corresponding to replication and untiling
  (assert-replicated-so-give-me-one-copy), respectively;
* the shapes of the arguments passed to `f` have the same ranks as the arguments
  passed to `shard_map`-of-`f` (unlike `pmap` and `xmap` where the ranks are
  reduced), and the shape of an argument to `f` is computed from the shape
  `shape` of the corresponding argument to `shard_map`-of-`f` and the
  corresponding `PartitionSpec` spec as roughly
  `tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`
  (see the worked example just after this list);
* the body of `f` can apply collectives using names from `mesh`.
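
As a quick worked instance of that shape rule (plain Python, using the axis
sizes of the 4x2 mesh from the TL;DR example):

```python
mesh_shape = {'x': 4, 'y': 2}        # mesh axis sizes
shape, spec = (8, 16), ('x', 'y')    # argument shape and its pspec entries
block_shape = tuple(sz // (1 if n is None else mesh_shape[n])
                    for sz, n in zip(shape, spec))
print(block_shape)                   # (2, 8), matching a_block in the TL;DR
```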

`shmap` is eager by default, meaning that we dispatch computations
primitive-by-primitive, so that the user can employ Python control flow on fully
replicated values and interactive `pdb` debugging to print any values. To stage
out and end-to-end compile a `shmap`ped function, just put a `jit` around it. A
consequence is that `shmap` doesn't have its own dispatch and compilation paths
like `xmap` and `pmap` currently do; it's just the `jit` path.
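
A small sketch of the two modes, reusing the mesh `m` from earlier:

```python
def body(x_block):
  print('block shape:', x_block.shape)   # (3, 6)
  return x_block * 2

f_eager = shard_map(body, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
y = f_eager(jnp.ones((12, 12)))          # dispatched primitive-by-primitive

f_staged = jax.jit(f_eager)              # staged out and compiled end to end
y = f_staged(jnp.ones((12, 12)))
```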

When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to MHLO
is trivial: it just involves switching into 'manual SPMD mode' on the inputs,
and switching back on the outputs. (We don't currently plan to support
partially-manual-partially-automatic modes.)

The interaction with effects is the same as with `pmap`.

The interaction with autodiff is also just like `pmap` (rather than attempting
the new semantics that `xmap` did, corresponding to having unmapped
intermediates and hence `grad`'s `reduce_axes` as well as making `psum`
transpose to `pbroadcast` rather than `psum`). But it thus inherits an unsolved
problem from `pmap`: in some cases, instead of transposing `psum` to `psum`, and
thus performing a backward pass `psum` corresponding to the forward pass `psum`,
it can be beneficial to move the backward pass `psum` to elsewhere in the
backward pass, exploiting linearity. Many advanced `pmap` users addressed this
challenge by using `custom_vjp` to implement `psum_idrev` and `id_psumrev`
functions, but since it's easy to accidentally leave those imbalanced, that
technique is a foot-cannon. We have some ideas on how to provide this
functionality in a safer way.
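
For reference, the `psum_idrev` half of that pattern looks roughly like the
following sketch (a hypothetical helper, not a JAX API): it sums on the forward
pass but is the identity under transposition, and must be paired with a matching
`id_psumrev` elsewhere or gradients are silently wrong.

```python
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def psum_idrev(axis_name, x):
  # to be called inside pmap/shard_map where axis_name is bound
  return jax.lax.psum(x, axis_name)

def _psum_idrev_fwd(axis_name, x):
  return jax.lax.psum(x, axis_name), None

def _psum_idrev_bwd(axis_name, _, g):
  return (g,)   # identity on the backward pass: no psum here

psum_idrev.defvjp(_psum_idrev_fwd, _psum_idrev_bwd)
```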

## When should you use `shmap` and when should you use `pjit`?

One philosophy is: it is almost always simpler to write a program in `jit==pjit`
— but if a given part of the program is less optimized by the compiler than it
could be, drop into `shmap`!

### A realistic transformer example

In fact, we can implement a simple version of the ["collective
matmul"](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959) algorithm
recently introduced in XLA to overlap communication and computation using `shmap`
and 30 lines of Python. The basic idea of the algorithm can be grasped with a
simple example.

Suppose we want to compute `C = A @ B` where `A` is sharded by a 1D mesh on the
0-th dimension while `B` and `C` are replicated.

```python
from jax.sharding import NamedSharding

M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
A_x = jax.device_put(A, NamedSharding(mesh, P('x', None)))

@jax.jit
def f(lhs, rhs):
  return lhs @ rhs

C = f(A_x, B)
```

A profile shows the blocking all-gather across 8 devices before the matmul can
start. This is suboptimal because `A` is sharded on a non-contracting dimension,
and each shard of `A` can be matmul'ed with `B` independently and this chunked
computation can be overlapped with fetching of the next shard of `A` from
another device.

<img width="1147" alt="image" src="https://user-images.githubusercontent.com/1458824/216507011-e854fb11-43d5-484d-993b-19a3349ed4b9.png">

This overlap can be implemented using `shmap` and explicit collectives.

```python
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
  axis_size = jax.lax.psum(1, axis_name='x')
  axis_index = jax.lax.axis_index(axis_name='x')
  chunk_size = lhs.shape[0]

  def f(i, carrys):
    accum, lhs = carrys
    # matmul for a chunk
    update = lhs @ rhs
    # circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='x',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    # device 0 computes chunks 0, 1, ...
    # device 1 computes chunks 1, 2, ...
    update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
    accum = jax.lax.dynamic_update_slice(accum, update, update_index)
    return accum, lhs

  accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
  # fori_loop causes a crash: hlo_sharding.cc:817 Check failed: !IsManual()
  # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
  for i in range(0, axis_size - 1):
    accum, lhs = f(i, (accum, lhs))

  # compute the last chunk, without the ppermute
  update = lhs @ rhs
  i = axis_size - 1
  update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
  accum = jax.lax.dynamic_update_slice(accum, update, update_index)

  return accum
```

```python
jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting, mesh,
  in_specs=(P('x', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
```

A profile shows that the all-gather is gone, and replaced with overlapped matmul
with async collective permute. This profile matches very closely with the
collective matmul paper result.

<img width="1147" alt="image" src="https://user-images.githubusercontent.com/1458824/216507064-139f032c-d869-4b67-9e11-1587d4fd2de9.png">

This collective matmul technique can be used to speed up feedforward blocks in
transformer layers. This typically consists of two matrix multiplications
followed by a `ReduceScatter` (to resolve partial sums from a parallelized
matrix multiplication) and preceded by an `AllGather` (to collect the sharded
dimensions along some axes and allow partial sum computation). Together, the
`ReduceScatter` from one layer and the `AllGather` for the next amount to an
`AllReduce`.

In a typical profile, the two matmuls will be followed by an `AllReduce`, and
they will not be overlapped. Collective matmul can be used to achieve the
overlap, but is difficult to trigger, has a minimum slice size and does not yet
cover all topologies, tensor shapes and variants of collective matmul (i.e.
latency- and throughput-optimized variants). [In a recent
paper](https://arxiv.org/abs/2211.05102), we found a ~40% gain in many
circumstances from manually implementing collective matmul variants in `shmap`
style.

But it isn’t always more complex! We expect this to be a much more natural way
to think about pipelined computation, and plan to do some demos of that soon!

### Another realistic example

Here's how `shmap` might look in a transformer layer pass with a 2D weight
gathered pattern ([paper](https://arxiv.org/abs/2211.05102), Sec 3.2.3 on p. 5):

```python
def matmul_2D_wg_manual(xnorm, q_wi, layer):
  '''Calls a custom manual implementation of matmul_reducescatter'''
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_dimension=(0, 2),
        axis_name='x',
        layer=layer)
    return q_wi


import partitioning.logical_to_physical as l2phys

def pjit_transformer_layer(
    hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
    cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
    x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Forward pass through a single layer, returning output, K, V."""

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # 2D: [batch.Z, time, embed.XY]
  x = _with_sharding_constraint(
      x, ('residual_batch', 'residual_time', 'residual_embed'))
  xnorm = _layernorm(x)
  # 2D: [batch, time, embed.X]
  xnorm = _with_sharding_constraint(
      xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
  # jump into manual mode where you want to optimise
  if manual:
    q_wi = shard_map(matmul_2D_wg_manual, mesh,
                     in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
                               l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
                     out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
  else:
    q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
    # 2D: [batch, time, heads.YZX, None]
    q_wi = _with_sharding_constraint(q_wi,
                                     ('post_norm_batch', 'time', 'heads', 'qkv'))
  q = q_wi[:, :, :, :hparams.qkv]
  q = _rope(sin, cos, q)
  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
  kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
  k = kv[:, :, 0, :hparams.qkv]
  v = kv[:, :, 0, hparams.qkv:]
  k = _rope(sin, cos, k)

  y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  y_mlp = special2.swish2(wi0) * wi1
  # 2D: [batch, time, heads.YZX, None]
  y_mlp = _with_sharding_constraint(y_mlp,
                                    ('post_norm_batch', 'time', 'heads', None))

  y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
  # do the second half of the mlp and the self-attn projection in parallel
  y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
  # 2D: [batch.Z, time, embed.XY]
  y_out = _with_sharding_constraint(
      y_out, ('residual_batch', 'residual_time', 'residual_embed'))
  z = y_out + x
  z = _with_sharding_constraint(
      z, ('residual_batch', 'residual_time', 'residual_embed'))
  return z, k, v
```

In the profile below, both the first and second matmul were replaced by manually
lowered versions, where the compute (fusions) are fully overlapped with the
communication (ppermute)! One fun hint that we are using a latency optimised
variant is that the ppermute pixels are jittered — because there are two
overlapping ppermutes using opposite ICI axes at the same time!

All-to-all is much harder to overlap, so was left on the table.

<img width="1085" alt="image" src="https://user-images.githubusercontent.com/1458824/216507137-adc35a1f-a76c-4704-a62d-389b42771090.png">

## Why don't `pmap` or `xmap` already solve this?

`pmap` was our first multi-device parallelism API. It follows the
per-device-code-and-explicit-collectives school. But it had major shortcomings
which make it unsuitable for today's programs:
* **Mapping multiple axes required nested `pmap`s.** Not only are nested `pmap`s
  cumbersome to write, but also they make it difficult to control (or even
  predict) the device placement of data and computation, and difficult to
  preserve data sharding (see the next two bullets). Today's programs require
  multiple axes of parallelism.
* **Controlling device placement was impossible.** Especially with multiple axes
  of parallelism, programmers need to control how those axes are aligned with
  hardware resources and their communication topologies. But (nested) `pmap`
  doesn't offer control over how mapped program instances are placed on
  hardware; there's just an automatic device order which the user can't control.
  ([Gopher](https://arxiv.org/abs/2112.11446)'s use of `axis_index_groups` and a
  single un-nested `pmap` was essentially a hack to get around this by
  flattening multiple axes of parallelism down to one.)
* **`jit`/`pjit` composability.** `jit`-of-`pmap` is a performance footgun, as
  is nesting `pmap`s, as is e.g. `scan`-of-`pmap`, because sharding is not
  preserved when returning from an inner `pmap`. To preserve sharding we would
  need pattern matching on jaxprs to ensure we're working with perfectly nested
  pmaps, or a pmap just inside a `jit`. Moreover, `pjit` was no help here
  because `pmap` targets XLA replicas while `pjit` targets the XLA SPMD
  Partitioner, and composing those two is hard.
* **`jax.Array` compatibility (and hence `pjit` compatibility).** Because the
  sharding of `pmap` outputs can't be expressed as `Shardings` / `OpShardings`,
  due to `pmap`'s stacking rather than concatenative semantics, the output of a
  `pmap` computation can't currently be passed to a `pjit` computation without
  bouncing to host (or dispatching a reshaping computation).
* **Multi-controller semantics (and hence `pjit` compatibility).**
  Multi-controller `pmap` concatenates values across controllers, which works well
  but differs from single-controller `pmap`'s stacking semantics. More
  practically, it precludes the use of non-fully-addressable `jax.Array` inputs
  and outputs as we use with multi-controller `pjit`.
* **Eager mode.** We didn't make `pmap` eager-first, and though we eventually
  (after 4+ years!) added eager operation with `disable_jit()`, the fact that
  `pmap` has `jit` fused into it means it has its own compilation and dispatch
  path (actually two dispatch paths: in Python for handling `Tracer`s, and in
  C++ for performance on raw `Array` inputs!), a heavy implementation burden.
* **Reshapes needed in the caller.** A typical use case with `pmap` on 8 devices
  might look like starting with a batch axis of size 128, reshaping it to split
  into two axes with sizes (8, 16), and then `pmap`ping over the first. These
  reshapes are awkward and the compiler often interprets them as copies instead
  of views — increasing memory and time usage (see the sketch just after this
  list).
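
For instance, the reshape dance in that last bullet typically looks something
like the following sketch (assuming 8 available devices; names are
illustrative):

```python
batch = jnp.zeros((128, 512))
batch_2d = batch.reshape(8, 16, 512)                 # split 128 -> (8, 16) in the caller
out = jax.pmap(lambda b: b.sum(axis=0))(batch_2d)    # map over the leading axis of 8
# with shard_map, the (128, 512) array is passed directly, given a mesh and in_specs
```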

These shortcomings aren't so bad when only doing batch data parallelism. But
when more parallelism is involved, `pmap` just can't cut it!

`xmap` paved the way as a next-gen evolution of `pmap` and solved (almost) all these
issues. `shmap` follows in `xmap`'s footsteps and solves these problems in
essentially the same ways; indeed, `shmap` is like a specialized subset of `xmap`
(what some call the "hard `xmap`" subset), with a few tweaks.

For the initial prototype, we chose to implement `shmap` as a separate primitive
from `xmap`, because limiting the set of features it supports makes it easier to
focus on the core functionality. For example, `shmap` doesn't allow unmapped
intermediates, making it easier not to worry about the interactions between
named axes and autodiff. Furthermore, not having to reason about interactions of
all pairs of features makes it easier to add capabilities beyond what's
implemented in `xmap` today, such as support for eager mode.

Both `shmap` and `xmap` share significant portions of the lowering code. We
could consider merging both in the future, or even focusing solely on `shmap`,
depending on how the usage will evolve.

```diff
@@ -46,6 +46,7 @@ Then create a pull request that adds a file named
 10657: Sequencing side-effects in JAX <10657-sequencing-effects>
 11830: `jax.remat` / `jax.checkpoint` new implementation <11830-new-remat-checkpoint>
 12049: Type Annotation Roadmap for JAX <12049-type-annotations>
+14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map>


 Several early JEPs were converted in hindsight from other documentation,
```

```diff
@@ -100,6 +100,7 @@ py_library_providing_imports_info(
         "experimental/pjit.py",
         "experimental/global_device_array.py",
         "experimental/multihost_utils.py",
+        "experimental/shard_map.py",
         # until checkify is moved out of experimental
         "experimental/checkify.py",
         # to avoid circular dependencies
```

```diff
@@ -63,9 +63,6 @@ def _initial_style_jaxpr(fun, in_avals):
 def _close_jaxpr(jaxpr):
   return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

-def _initial_style_staging() -> bool:
-  return core.thread_local_state.trace_state.initial_style
-
 def _sum_tangents(_, x, *xs):
   return reduce(ad.add_tangents, xs, x)

```

```diff
@@ -49,7 +49,8 @@ from jax._src import profiler
 from jax._src import stages
 from jax._src import traceback_util
 from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
-                               OpShardingSharding, Sharding)
+                               OpShardingSharding, NamedSharding, PartitionSpec,
+                               Sharding)
 from jax._src.abstract_arrays import array_types
 from jax._src.config import config, flags
 from jax._src.lib.mlir import ir
```

```diff
@@ -563,7 +564,7 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):


 def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]:
-  from jax.experimental import pjit
+  from jax.experimental import pjit, shard_map

   for eqn in jaxpr.eqns:
     if eqn.primitive is pjit.sharding_constraint_p:
```

```diff
@@ -571,6 +572,12 @@ def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]:
     elif eqn.primitive is pjit.pjit_p:
       yield from eqn.params['in_shardings']
       yield from eqn.params['out_shardings']
+    elif eqn.primitive is shard_map.shard_map_p:
+      def _names_to_pspec(names):
+        ndmin = max(names) + 1 if names else 0
+        return PartitionSpec(*(names.get(i) for i in range(ndmin)))
+      yield from (NamedSharding(eqn.params['mesh'], _names_to_pspec(names))
+                  for names in eqn.params['in_names'])
   for subjaxpr in core.subjaxprs(jaxpr):
     yield from jaxpr_shardings(subjaxpr)

```

```diff
@@ -403,20 +403,21 @@ def tuple_delete(t, idx):
 class HashableFunction:
   """Decouples function equality and hash from its identity.

-  Local lambdas and functiond defs are reallocated on each function call, making
+  Local lambdas and function defs are reallocated on each function call, making
   the functions created on different calls compare as unequal. This breaks our
   caching logic, which should really only care about comparing the semantics and
   not actual identity.

   This class makes it possible to compare different functions based on their
-  semantics. The parts that are taken into account are: the bytecode of
-  the wrapped function (which is cached by the CPython interpreter and is stable
-  across the invocations of the surrounding function), and `closure` which should
-  contain all values in scope that affect the function semantics. In particular
-  `closure` should contain all elements of the function closure, or it should be
-  possible to derive the relevant elements of the true function closure based
-  solely on the contents of the `closure` argument (e.g. in case some closed-over
-  values are not hashable, but are entirely determined by hashable locals).
+  semantics. The parts that are taken into account are: the bytecode of the
+  wrapped function (which is cached by the CPython interpreter and is stable
+  across the invocations of the surrounding function), and `closure` which
+  should contain all values in scope that affect the function semantics. In
+  particular `closure` should contain all elements of the function closure, or
+  it should be possible to derive the relevant elements of the true function
+  closure based solely on the contents of the `closure` argument (e.g. in case
+  some closed-over values are not hashable, but are entirely determined by
+  hashable locals).
   """

   def __init__(self, f, closure):
```

```diff
@@ -1316,6 +1316,7 @@ tf_not_yet_impl = [
     "for",
     "inspect_sharding",
     "io_callback",
+    "shard_map",

     # Not high priority?
     "after_all",
```

jax/experimental/shard_map.py (new file, 918 lines)
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from functools import partial, lru_cache
|
||||
import inspect
|
||||
import operator as op
|
||||
from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
|
||||
Set, Tuple, TypeVar, Union, Protocol)
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax.core import Tracer
|
||||
from jax.sharding import NamedSharding, PartitionSpec, Mesh
|
||||
from jax._src import ad_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax, parallel as lax_parallel
|
||||
from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function,
|
||||
memoize, partition_list, merge_lists)
|
||||
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
|
||||
from jax.experimental import maps
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
tree_structure, tree_leaves)
|
||||
from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
|
||||
_generate_key_paths, KeyPath)
|
||||
|
||||
P = PartitionSpec
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
# API
|
||||
|
||||
Specs = Any # PyTree[PartitionSpec]
|
||||
|
||||
@traceback_util.api_boundary
|
||||
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
|
||||
check_rep: bool = True):
|
||||
if not callable(f):
|
||||
raise TypeError("shard_map requires a callable for its first argument, "
|
||||
f"but got {f} of type {type(f)}.")
|
||||
if not isinstance(mesh, Mesh):
|
||||
raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its "
|
||||
f"second argument, but got {mesh} of type {type(mesh)}.")
|
||||
_check_specs(SpecErrorType.input, in_specs)
|
||||
_check_specs(SpecErrorType.out, out_specs)
|
||||
|
||||
@traceback_util.api_boundary
|
||||
def wrapped(*args):
|
||||
fun = lu.wrap_init(f)
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
try: in_specs_flat = broadcast_prefix(in_specs, args)
|
||||
except ValueError:
|
||||
e, *_ = prefix_errors(in_specs, args)
|
||||
raise e('shard_map in_specs') from None
|
||||
_check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat)
|
||||
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
|
||||
|
||||
@memoize
|
||||
def out_names_thunk():
|
||||
dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves)
|
||||
try: out_specs_flat = broadcast_prefix(out_specs, dummy)
|
||||
except ValueError:
|
||||
e, *_ = prefix_errors(out_specs, dummy)
|
||||
raise e('shard_map out_specs') from None
|
||||
return tuple(map(_canonicalize_spec, out_specs_flat))
|
||||
try:
|
||||
out_flat = shard_map_p.bind(
|
||||
flat_fun, *args_flat, mesh=mesh, in_names=in_names_flat,
|
||||
out_names_thunk=out_names_thunk, check_rep=check_rep)
|
||||
except _SpecError as e:
|
||||
fails, = e.args
|
||||
msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails)
|
||||
if any(fail is not no_fail and not fail.shape for fail in fails):
|
||||
msg += (" In particular, for rank 0 outputs which are not constant "
|
||||
"over the mesh, add at least one (singleton) axis to them so "
|
||||
"that they can be concatenated using out_specs.")
|
||||
raise ValueError(msg) from None
|
||||
except _RepError as e:
|
||||
fails, = e.args
|
||||
msg = _rep_error(f, mesh, out_tree(), out_specs, fails)
|
||||
raise ValueError(msg) from None
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
return wrapped
|
||||
|
||||
# Internally use AxisNames = Dict[int, Tuple[AxisName, ...]], not PartitionSpecs
|
||||
AxisName = Hashable
|
||||
AxisNames = Dict[int, Tuple[AxisName, ...]] # TODO(mattjj): make it hashable
|
||||
def _canonicalize_spec(spec: PartitionSpec) -> AxisNames:
|
||||
if isinstance(spec, PartitionSpec):
|
||||
return {i: names if isinstance(names, tuple) else (names,)
|
||||
for i, names in enumerate(spec) if names is not None}
|
||||
else:
|
||||
return spec
|
||||
|
||||
# Error checking and messages
|
||||
|
||||
SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
|
||||
|
||||
def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
|
||||
if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return
|
||||
prefix = 'in' if error_type == SpecErrorType.input else 'out'
|
||||
msgs = [f" {prefix}_specs{key.pprint()} is {x} of type {type(x).__name__}, "
|
||||
for key, x in _generate_key_paths(specs) if not isinstance(x, P)]
|
||||
raise TypeError(
|
||||
f"shard_map {prefix}_specs argument must be a pytree of "
|
||||
f"`jax.sharding.PartitionSpec` instances, but:\n\n"
|
||||
+ '\n\n'.join(msgs) + '\n\n'
|
||||
f"Check the {prefix}_specs values passed to shard_map.")
|
||||
|
||||
class NoFail: pass
|
||||
no_fail = NoFail()
|
||||
|
||||
def _check_specs_vs_args(
|
||||
f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs,
|
||||
in_specs_flat: List[P], xs: List) -> None:
|
||||
in_avals = map(shaped_abstractify, xs)
|
||||
fail = [a if not len(p) <= a.ndim else no_fail
|
||||
for p, a in zip(in_specs_flat, in_avals)]
|
||||
if any(f is not no_fail for f in fail):
|
||||
msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
|
||||
raise ValueError(msg)
|
||||
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
|
||||
fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
|
||||
for d, ns in names.items()) else no_fail
|
||||
for a, names in zip(in_avals, in_names_flat)]
|
||||
if any(f is not no_fail for f in fail):
|
||||
msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
|
||||
raise ValueError(msg)
|
||||
|
||||
def _spec_rank_error(
|
||||
error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
|
||||
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
|
||||
if error_type == SpecErrorType.input:
|
||||
prefix, base = 'in', 'args'
|
||||
ba = _try_infer_args(f, tree)
|
||||
else:
|
||||
prefix, base = 'out', f'{f.__name__}(*args)'
|
||||
msgs = []
|
||||
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
|
||||
if error_type == SpecErrorType.input and ba is not None:
|
||||
arg_key, *_ = fail_key.keys
|
||||
extra = (f", where {base}[{arg_key.key}] is bound to {f.__name__}'s "
|
||||
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
|
||||
else:
|
||||
extra = ""
|
||||
msgs.append(
|
||||
f"{prefix}_specs{spec_key.pprint()} is {spec} which has length "
|
||||
f"{len(spec)}, but "
|
||||
f"{base}{fail_key.pprint()}{extra} has shape {aval.str_short()}, "
|
||||
f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
|
||||
assert msgs
|
||||
msg = (f"shard_map applied to the function '{f.__name__}' was given an "
|
||||
f"{prefix}_specs entry which is too long to be compatible with the "
|
||||
f"corresponding {prefix}put value from the function:\n\n"
|
||||
+ '\n\n'.join(msgs) + '\n\n' +
|
||||
f"Entries in {prefix}_specs must be of length no greater than the "
|
||||
f"number of axes in the corresponding {prefix}put value.\n\n"
|
||||
f"Either revise the spec to be shorter, or modify '{f.__name__}' so "
|
||||
f"that its {prefix}puts have sufficient rank.")
|
||||
return msg
|
||||
|
||||
def _spec_divisibility_error(
|
||||
f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
|
||||
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
|
||||
ba = _try_infer_args(f, tree)
|
||||
msgs = []
|
||||
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
|
||||
if ba is not None:
|
||||
arg_key, *_ = fail_key.keys
|
||||
extra = (f", where args[{arg_key.key}] is bound to {f.__name__}'s "
|
||||
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
|
||||
names = _canonicalize_spec(spec)
|
||||
for d, ns in names.items():
|
||||
if aval.shape[d] % prod(mesh.shape[n] for n in ns):
|
||||
axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
|
||||
total = 'total ' if len(ns) > 1 else ''
|
||||
sz = prod(mesh.shape[n] for n in ns)
|
||||
msgs.append(
|
||||
f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} "
|
||||
f"corresponds to in_specs{spec_key.pprint()} of value {spec}, "
|
||||
f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
|
||||
f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
|
||||
f"{aval.shape[d]}")
|
||||
assert msgs
|
||||
msg = (f"shard_map applied to the function '{f.__name__}' was given argument "
|
||||
f"arrays with axis sizes that are not evenly divisible by the "
|
||||
f"corresponding mesh axis sizes:\n\n"
|
||||
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
|
||||
f"axis names {mesh.axis_names}.\n\n"
|
||||
+ '\n\n'.join(msgs) + '\n\n' +
|
||||
f"Array arguments' axis sizes must be evenly divisible by the mesh "
|
||||
f"axis or axes indicated by the corresponding elements of the "
|
||||
f"argument's in_specs entry. Consider checking that in_specs are "
|
||||
f"correct, and if so consider changing the mesh axis sizes or else "
|
||||
f"padding the input and adapting '{f.__name__}' appropriately.")
|
||||
return msg
|
||||
|
||||
def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
|
||||
fails: List[Union[Set, NoFail]]) -> str:
|
||||
msgs = []
|
||||
for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails):
|
||||
dst = _canonicalize_spec(spec)
|
||||
unmentioned = _unmentioned(mesh, dst)
|
||||
if len(unmentioned) > 1:
|
||||
need_rep = ','.join(map(str, unmentioned))
|
||||
got_rep = ','.join(map(str, rep))
|
||||
diff = ','.join(map(str, unmentioned - rep))
|
||||
msgs.append(
|
||||
f"out_specs{spec_key.pprint()} is {spec} which implies that the "
|
||||
f"corresponding output value is replicated across mesh axes "
|
||||
f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, "
|
||||
f"which is missing the required axes {diff}")
|
||||
else:
|
||||
need_rep_, = unmentioned
|
||||
msgs.append(
|
||||
f"out_specs{spec_key.pprint()} is {spec} which implies that the "
|
||||
f"corresponding output value is replicated across mesh axis "
|
||||
f"'{need_rep_}', but could not infer replication over any axes")
|
||||
assert msgs
|
||||
msg = (f"shard_map applied to the function '{f.__name__}' was given "
|
||||
f"out_specs which require replication which can't be statically "
|
||||
f"inferred given the mesh:\n\n"
|
||||
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
|
||||
f"axis names {mesh.axis_names}.\n\n"
|
||||
+ '\n\n'.join(msgs) + '\n\n' +
|
||||
"Check if these output values are meant to be replicated over those "
|
||||
"mesh axes. If not, consider revising the corresponding out_specs "
|
||||
"entries. If so, consider ddisabling the check by passing the "
|
||||
"check_rep=False argument to shard_map.")
|
||||
return msg
|
||||
|
||||
def _unmentioned(mesh: Mesh, names: AxisNames) -> Set[AxisName]:
|
||||
return set(mesh.axis_names) - {n for ns in names.values() for n in ns}
|
||||
|
||||
def _try_infer_args(f, tree):
|
||||
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
|
||||
try:
|
||||
return inspect.signature(f).bind(*dummy_args)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
T = TypeVar('T')
|
||||
def _iter_paths(tree: PyTreeDef, specs: Specs, fails: List[Union[T, NoFail]]
|
||||
) -> List[Tuple[Tuple[KeyPath, P], Tuple[KeyPath, T]]]:
|
||||
failures = tree_unflatten(tree, fails)
|
||||
failures_aug = _generate_key_paths(failures)
|
||||
specs_ = tree_unflatten(tree_structure(specs), _generate_key_paths(specs))
|
||||
leaf = lambda x: type(x) is tuple and len(x) == 2 and type(x[1]) is P
|
||||
specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf)
|
||||
return [((spec_key, spec), (fail_key, fail_data))
|
||||
for (spec_key, spec), (fail_key, fail_data)
|
||||
in zip(specs_aug, failures_aug) if fail_data is not no_fail]
|
||||
|
||||
# Primitive
|
||||
|
||||
JaxType = Any
|
||||
MaybeTracer = Union[JaxType, Tracer]
|
||||
|
||||
class ShardMapPrimitive(core.Primitive):
|
||||
multiple_results = True
|
||||
|
||||
def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh,
|
||||
in_names: Tuple[AxisNames, ...],
|
||||
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
|
||||
check_rep: bool) -> Sequence[MaybeTracer]:
|
||||
top_trace = core.find_top_trace(args)
|
||||
fun, env_todo = process_env_traces(fun, top_trace.level, mesh,
|
||||
in_names, out_names_thunk, check_rep)
|
||||
|
||||
@as_hashable_function(closure=out_names_thunk)
|
||||
def new_out_names_thunk():
|
||||
out_names = out_names_thunk()
|
||||
_, xforms = env_todo()
|
||||
for t in xforms:
|
||||
out_names = t(out_names)
|
||||
return out_names
|
||||
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = top_trace.process_shard_map( # pytype: disable=attribute-error
|
||||
shard_map_p, fun, tracers, mesh=mesh, in_names=in_names,
|
||||
out_names_thunk=new_out_names_thunk, check_rep=check_rep)
|
||||
todos, _ = env_todo()
|
||||
return map(core.full_lower, core.apply_todos(todos, outs))
|
||||
|
||||
def get_bind_params(self, params):
|
||||
"""Goes from jaxpr form to python traceable form."""
|
||||
new_params = dict(params)
|
||||
jaxpr = new_params.pop('jaxpr')
|
||||
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ())
|
||||
axes = new_params.pop('out_names')
|
||||
new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes)
|
||||
return [subfun], new_params
|
||||
|
||||
shard_map_p = ShardMapPrimitive('shard_map')
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
|
||||
*args: Any):
|
||||
outs = yield args, {}
|
||||
todos, out_names_transforms = [], []
|
||||
while True:
|
||||
tracers = [x for x in outs if isinstance(x, core.Tracer)
|
||||
and (level is None or x._trace.level > level)]
|
||||
if tracers:
|
||||
ans = max(tracers, key=op.attrgetter('_trace.level'))
|
||||
else:
|
||||
break
|
||||
trace = ans._trace.main.with_cur_sublevel()
|
||||
outs = map(trace.full_raise, outs)
|
||||
outs, (todo, xform) = trace.post_process_shard_map(
|
||||
outs, mesh, in_names, out_names_thunk, check_rep)
|
||||
todos.append(todo)
|
||||
out_names_transforms.append(xform)
|
||||
yield outs, (tuple(todos), tuple(out_names_transforms))
|
||||
|
||||
# Staging
|
||||
|
||||
def _shard_map_staging(
|
||||
trace: pe.DynamicJaxprTrace, prim: core.Primitive, fun: lu.WrappedFun,
|
||||
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
|
||||
in_names: Tuple[AxisNames, ...],
|
||||
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
|
||||
check_rep: bool,
|
||||
) -> Sequence[pe.DynamicJaxprTracer]:
|
||||
in_avals = [t.aval for t in in_tracers]
|
||||
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
|
||||
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(
|
||||
fun, trace.main, in_avals_)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
if check_rep:
|
||||
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
|
||||
out_rep = _output_rep(mesh, jaxpr, in_rep)
|
||||
_check_reps(mesh, out_names_thunk(), out_rep)
|
||||
out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_)
|
||||
source_info = source_info_util.current()
|
||||
out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals]
|
||||
invars = map(trace.getvar, in_tracers)
|
||||
constvars = map(trace.getvar, map(trace.instantiate_const, consts))
|
||||
outvars = map(trace.makevar, out_tracers)
|
||||
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
|
||||
params = dict(mesh=mesh, in_names=in_names_staged,
|
||||
out_names=out_names_thunk(), jaxpr=jaxpr, check_rep=check_rep)
|
||||
eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
|
||||
jaxpr.effects, source_info)
|
||||
trace.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
|
||||
|
||||
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
                ) -> core.AbstractValue:
  if isinstance(aval, core.ShapedArray):
    return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
                             for i, sz in enumerate(aval.shape)))
  else:
    raise NotImplementedError  # TODO(mattjj): add table with handlers

def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
                  ) -> core.AbstractValue:
  if isinstance(aval, core.ShapedArray):
    return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
                             for i, sz in enumerate(aval.shape)),
                       named_shape={k: v for k, v in aval.named_shape.items()
                                    if k not in mesh.shape})
  else:
    raise NotImplementedError  # TODO(mattjj): add table with handlers

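# Illustrative note (editorial commentary, not part of the original commit):
# these two helpers convert between "global" and per-device-shard avals by
# dividing or multiplying each dimension by the product of the mesh axis sizes
# mapped to it. For example, assuming a mesh with shape {'x': 2, 'y': 2}, an
# f32[8,8] aval with names {0: ('x', 'y')} shards to f32[2,8], and unsharding
# inverts that.
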
# Type-checking

def _shard_map_typecheck(*in_atoms, jaxpr, mesh, in_names, out_names,
                         check_rep):
  for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names):
    if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)):
      raise core.JaxprTypeError("shard_map argument avals not compatible with "
                                "jaxpr binder avals and in_names")
  with core.extend_axis_env_nd(tuple(mesh.shape.items())):
    core.check_jaxpr(jaxpr)
  if check_rep:
    in_rep = map(partial(_in_names_to_rep, mesh), in_names)
    out_rep = _output_rep(mesh, jaxpr, in_rep)
    for rep, dst in zip(out_rep, out_names):
      if not _valid_repeats(mesh, rep, dst):
        raise core.JaxprTypeError("shard_map can't prove output is sufficiently "
                                  "replicated")
  out_avals_sharded = [x.aval for x in jaxpr.outvars]
  out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded)
  return out_avals, jaxpr.effects
core.custom_typechecks[shard_map_p] = _shard_map_typecheck

def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> Set[AxisName]:
  return set(mesh.axis_names) - set(n for ns in names.values() for n in ns)

def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[Set[AxisName]],
                ) -> Sequence[Set[AxisName]]:
  env: Dict[core.Var, Set[AxisName]] = {}

  def read(x: core.Atom) -> Set[AxisName]:
    return env[x] if type(x) is core.Var else set(mesh.axis_names)

  def write(v: core.Var, val: Set[AxisName]) -> None:
    env[v] = val

  map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
  map(write, jaxpr.invars, in_rep)
  for e in jaxpr.eqns:
    rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
    out_rep = rule(mesh, *map(read, e.invars), **e.params)
    if e.primitive.multiple_results:
      out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
      map(write, e.outvars, out_rep)
    else:
      write(e.outvars[0], out_rep)
  return map(read, jaxpr.outvars)

def _valid_repeats(mesh: Mesh, rep: Set[AxisName], dst: AxisNames) -> bool:
  return _unmentioned(mesh, dst).issubset(rep)

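# Illustrative note (editorial commentary, not part of the original commit):
# `_in_names_to_rep` computes the set of mesh axes over which an input is
# treated as replicated, i.e. the axes its spec does not mention. For a mesh
# with axes ('x', 'y') and in_names {0: ('x',)}, that set is {'y'}.
# `_output_rep` then propagates these sets through the jaxpr one equation at a
# time using `_rep_rules`, and `_valid_repeats` checks that every mesh axis an
# out spec leaves unmentioned is in the inferred replication set.
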
# Lowering

def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
                        check_rep):
  del check_rep
  sharded_avals = [v.aval for v in jaxpr.invars]
  in_nodes_ = map(partial(_xla_shard, mesh), in_names, ctx.avals_in,
                  sharded_avals, in_nodes)
  new_axis_context = mlir.SPMDAxisContext(mesh, frozenset(mesh.axis_names))
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
  with core.extend_axis_env_nd(tuple(mesh.shape.items())):
    out_nodes_, _ = mlir.jaxpr_subcomp(sub_ctx, jaxpr, mlir.TokenSet(),
                                       (), *in_nodes_,
                                       dim_var_values=ctx.dim_var_values)
  sharded_avals = [v.aval for v in jaxpr.outvars]
  return map(partial(_xla_unshard, mesh), out_names, sharded_avals,
             ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)

def _xla_shard(mesh, names, aval_in, aval_out, x):
  manual_proto = pxla._manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
  result_type, = mlir.aval_to_ir_types(aval_out)
  axes = {name: i for i, ns in names.items() for name in ns}
  sharding_proto = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
      aval_in, axes).sharding_proto()
  sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=set())
  return [mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, set())]

def _xla_unshard(mesh, names, aval_in, aval_out, xs):
  x, = xs
  manual_proto = pxla._manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
  result_type, = mlir.aval_to_ir_types(aval_out)
  sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set())
  axes = {name: i for i, ns in names.items() for name in ns}
  sharding_proto = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
      aval_out, axes).sharding_proto()
  return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, set())

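# Illustrative note (editorial commentary, not part of the original commit):
# as sketched here, the lowering wraps each operand in a sharding annotation
# plus a full_to_shard op so the body is compiled in the SPMD partitioner's
# 'manual' mode over all mesh axes (per-device shapes, no further automatic
# partitioning), and the shard_to_full ops on the results re-enter
# automatically partitioned code with the shardings implied by out_names.
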
# Eager evaluation

def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
                    check_rep):
  del prim
  args = map(partial(_unmatch_spec, mesh), in_names, args)
  in_rep = map(partial(_in_names_to_rep, mesh), in_names)
  with core.new_base_main(ShardMapTrace, mesh=mesh) as main:
    with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
      t = main.with_cur_sublevel()
      in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
      ans = fun.call_wrapped(*in_tracers)
      out_tracers = map(t.full_raise, ans)
      outs_, out_rep = unzip2((t.val, t.rep) for t in out_tracers)
      del main, t, in_tracers, ans, out_tracers
  out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
  _check_names(out_names_thunk(), out_avals)
  if check_rep: _check_reps(mesh, out_names_thunk(), out_rep)
  return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_)
core.EvalTrace.process_shard_map = _shard_map_impl

def _names_to_pspec(names: AxisNames) -> PartitionSpec:
  ndmin = max(names) + 1 if names else 0
  return PartitionSpec(*(names.get(i) for i in range(ndmin)))

def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType:
  with core.eval_context():
    return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x)

def _unmatch(mesh, src_tup, x):
  src = _names_to_pspec(dict(src_tup))
  dst = P(mesh.axis_names)
  return shard_map(_add_singleton, mesh, (src,), dst)(x)

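# Illustrative example (editorial commentary, not part of the original commit):
#   _names_to_pspec({0: ('x',), 2: ('y',)}) == PartitionSpec(('x',), None, ('y',))
# i.e. a names dict {dim: axis-name-tuple} is turned back into a positional
# PartitionSpec, with unmentioned dims up to the largest mentioned one filled
# in as None.
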
def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray]
                 ) -> None:
  fail = [a if not max(n, default=0) < a.ndim else no_fail
          for n, a in zip(names, avals)]
  if any(f is not no_fail for f in fail): raise _SpecError(fail)
class _SpecError(Exception): pass

def _check_reps(mesh, names, reps):
  fail = [r if not _valid_repeats(mesh, r, n) else no_fail
          for n, r in zip(names, reps)]
  if any(f is not no_fail for f in fail): raise _RepError(fail)
class _RepError(Exception): pass

def _match_spec(mesh: Mesh, rep: Set[AxisName], dst: AxisNames, x: JaxType
                ) -> JaxType:
  with core.eval_context():
    return jax.jit(HashablePartial(_match, mesh, tuple(dst.items())))(x)

def _match(mesh, dst_tup, x):
  src = P(mesh.axis_names)
  dst = _names_to_pspec(dict(dst_tup))
  return shard_map(_rem_singleton, mesh, (src,), dst, check_rep=False)(x)

def _rem_singleton(x): return x.reshape(x.shape[1:])
def _add_singleton(x): return x.reshape(1, *x.shape)

class ShardMapTrace(core.Trace):
  mesh: Mesh

  def __init__(self, *args, mesh):
    super().__init__(*args)
    self.mesh = mesh

  def pure(self, val):
    val_ = _unmatch_spec(self.mesh, {}, val)
    return ShardMapTracer(self, set(self.mesh.axis_names), val_)

  def sublift(self, tracer):
    return ShardMapTracer(self, tracer.rep, tracer.val)

  def process_primitive(self, prim, tracers, params):
    in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
    f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
    with core.eval_context(), jax.disable_jit(False):
      out_vals = jax.jit(f)(*in_vals)
    rule = _rep_rules.get(prim, partial(_rep_rule, prim))
    out_rep = rule(self.mesh, *in_rep, **params)
    if prim.multiple_results:
      out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
      return map(partial(ShardMapTracer, self), out_rep, out_vals)
    return ShardMapTracer(self, out_rep, out_vals)

  def process_call(self, call_primitive, fun, tracers, params):
    if call_primitive is not xla.xla_call_p: raise NotImplementedError
    fun, jaxpr = _grab_jaxpr_shadily(fun)  # TODO remove with initial-style jit
    bind = partial(call_primitive.bind, fun)  # TODO caching (compat w/ jaxpr())
    fake_primitive = pxla._FakePrimitive(multiple_results=True, bind=bind)
    _rep_rules[fake_primitive] = lambda *_, **__: set()
    out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
    out_vals = [t.val for t in out_tracers_]
    out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
    return map(partial(ShardMapTracer, self), out_rep, out_vals)

@lu.transformation_with_aux
def _grab_jaxpr_shadily(*args):
  out = yield args, {}
  main = core.thread_local_state.trace_state.trace_stack.dynamic  # forgive me
  jaxpr, _ = main.jaxpr_stack[-1].to_jaxpr(out)
  yield out, jaxpr

class ShardMapTracer(core.Tracer):
  rep: Set[AxisName]
  val: JaxType

  def __init__(self, trace, rep, val):
    self._trace = trace
    self.rep = rep
    self.val = val

  @property
  def aval(self):
    aval = core.get_aval(self.val)
    if (isinstance(aval, core.ConcreteArray) and
        self.rep == set(self._trace.mesh.axis_names)):
      with core.eval_context():
        return core.get_aval(self.val[0])
    else:
      aval = core.raise_to_shaped(aval)
      return core.mapped_aval(self._trace.mesh.size, 0, aval)

  def full_lower(self) -> ShardMapTracer:
    return self

  def __str__(self) -> str:
    with core.eval_context():
      blocks = list(self.val)
    mesh = self._trace.mesh
    axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
    return '\n'.join(
        f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
        for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))

def _prim_applier(prim, params_tup, mesh, *args):
  def apply(*args):
    outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
    return tree_map(_add_singleton, outs)
  return shard_map(apply, mesh, P(mesh.axis_names), P(mesh.axis_names))(*args)

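# Illustrative note (editorial commentary, not part of the original commit): in
# eager mode each primitive a ShardMapTracer encounters is executed by
# `_prim_applier`, which re-wraps it in a jitted `shard_map` over all mesh
# axes; `_rem_singleton` / `_add_singleton` strip and restore the leading
# per-device block dimension that eager values carry (see `__str__`, which
# prints one block per mesh coordinate).
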
# Static replication checking

def _rep_rule(prim, mesh, *in_rep, **params):
  raise NotImplementedError(f"no replication rule for {prim}")

_rep_rules: Dict[core.Primitive, Callable] = {}
register_rule = lambda prim: lambda rule: _rep_rules.setdefault(prim, rule)
register_standard = lambda prim: _rep_rules.setdefault(prim, _standard_rep_rule)

def _standard_rep_rule(_, *in_rep, **__):
  return set.intersection(*in_rep)

for o in lax.__dict__.values():
  if isinstance(o, core.Primitive): register_standard(o)
register_standard(ad_util.add_any_p)

register_standard(lax_parallel.ppermute_p)  # doesn't change replication

@register_rule(lax_parallel.psum_p)
def _psum_rule(_, *in_rep, axes, axis_index_groups):
  if axis_index_groups is not None: raise NotImplementedError
  axes = (axes,) if not isinstance(axes, tuple) else axes
  return [r | set(axes) for r in in_rep]  # introduces replication

@register_rule(lax_parallel.all_gather_p)
def _all_gather_rule(_, in_rep, *, all_gather_dimension, axis_name, axis_size,
                     axis_index_groups, tiled):
  if axis_index_groups is not None: raise NotImplementedError
  if not tiled: raise NotImplementedError
  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
  return in_rep | set(axis_name)  # introduces replication

@register_rule(lax_parallel.reduce_scatter_p)
def _reduce_scatter_rule(_, in_rep, *, scatter_dimension, axis_name, axis_size,
                         axis_index_groups, tiled):
  if axis_index_groups is not None: raise NotImplementedError
  if not tiled: raise NotImplementedError
  return in_rep - {axis_name}  # removes replication

@register_rule(lax_parallel.all_to_all_p)
def _all_to_all_rule(_, in_rep, *, split_axis, concat_axis, axis_name,
                     axis_index_groups):
  if axis_index_groups is not None: raise NotImplementedError
  return in_rep - {axis_name}  # removes replication

@register_rule(pjit.pjit_p)
def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
  return _output_rep(mesh, jaxpr.jaxpr, in_rep)

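# Illustrative example (editorial commentary, not part of the original commit):
# most primitives use the standard rule, which intersects the input replication
# sets; collectives adjust them. E.g. on a mesh with axes ('x', 'y'), psum over
# 'y' with input replication {'x'} yields {'x', 'y'} (the reduced axis becomes
# replicated), while reduce_scatter or all_to_all over 'y' instead removes 'y'.
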
# Batching

def _shard_map_batch(
    trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun,
    in_tracers: Sequence[batching.BatchTracer], mesh: Mesh,
    in_names: Tuple[AxisNames, ...],
    out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
    check_rep: bool) -> Sequence[batching.BatchTracer]:
  in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers)
  if all(bdim is batching.not_mapped for bdim in in_dims):
    return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names,
                     out_names_thunk=out_names_thunk, check_rep=check_rep)
  if trace.spmd_axis_name is not None:
    raise NotImplementedError  # TODO add named axis to specs
  if any(isinstance(d, batching.ConcatAxis) for d in in_dims):
    raise NotImplementedError
  fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims))
  new_in_names = [{ax + (d is not batching.not_mapped and ax <= d): names[ax]  # type: ignore
                   for ax in names} for names, d in zip(in_names, in_dims)]
  @as_hashable_function(closure=out_names_thunk)
  def new_out_names_thunk():
    out_names = out_names_thunk()
    return [{ax + (d is not batching.not_mapped and ax <= d): names[ax]
             for ax in names} for names, d in zip(out_names, out_dims())]
  new_params = dict(mesh=mesh, in_names=new_in_names,
                    out_names_thunk=new_out_names_thunk, check_rep=check_rep)
  out_vals = prim.bind(fun, *in_vals, **new_params)
  make_tracer = partial(batching.BatchTracer, trace,
                        source_info=source_info_util.current())
  return map(make_tracer, out_vals, out_dims())
batching.BatchTrace.process_shard_map = _shard_map_batch

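# Illustrative example (editorial commentary, not part of the original commit):
# vmapping a shard_map shifts the dims mentioned in the names dicts to account
# for the inserted batch dimension, e.g. in_names {0: ('x',)} with a leading
# batch dim becomes {1: ('x',)}, and the out_names thunk is adjusted the same
# way via out_dims().
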
# Autodiff

def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
                   out_names_thunk, check_rep):
  primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
  which_nz = [type(t) is not ad.Zero for t in tangents]
  tangents = [t if type(t) is not ad.Zero else None for t in tangents]
  args, in_tree = tree_flatten((primals, tangents))
  f_jvp = ad.jvp_subtrace(f, trace.main)
  f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp)
  tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz]

  @as_hashable_function(closure=out_names_thunk)
  def new_out_names_thunk():
    out_ax = out_names_thunk()
    return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz))
  params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names),
                out_names_thunk=new_out_names_thunk, check_rep=check_rep)
  f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
  result = shard_map_p.bind(f_jvp, *args, **params)
  primal_out, tangent_out = tree_unflatten(out_tree(), result)
  tangent_out = [ad.Zero(ad.get_aval(p).at_least_vspace()) if t is None else t
                 for p, t in zip(primal_out, tangent_out)]
  return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp

def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names,
                                out_names_thunk, check_rep):
  del mesh, in_names, out_names_thunk, check_rep
  primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
  out, treedef = tree_flatten((primals, tangents))
  tangents_nz = [type(t) is not ad.Zero for t in tangents]
  m = trace.main
  def todo(x):
    primals, tangents = tree_unflatten(treedef, x)
    return map(partial(ad.JVPTracer, m.with_cur_sublevel()), primals, tangents)
  def out_names_transform(out_names):
    return (*out_names, *(n for n, nz in zip(out_names, tangents_nz) if nz))
  return out, (todo, out_names_transform)
ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process

def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
                            out_names_thunk, check_rep):
  in_pvals = [t.pval for t in tracers]
  in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
  unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
  in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
  f = pe.trace_to_subjaxpr_nounits(f, trace.main, False)
  f = _promote_scalar_residuals(f)
  f_known, aux = pe.partial_eval_wrapper_nounits(
      f, (*in_knowns,), (*in_avals_sharded,))
  unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)

  @as_hashable_function(closure=out_names_thunk)
  def known_out_names():
    out_knowns, _, jaxpr, _ = aux()
    _, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
    assert not any(not v.aval.shape for v in jaxpr.constvars)
    res_names = ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
    return (*out_known_names, *res_names)

  known_params = dict(mesh=mesh, in_names=(*known_in_names,),
                      out_names_thunk=known_out_names, check_rep=check_rep)
  out = shard_map_p.bind(f_known, *in_consts, **known_params)
  out_knowns, out_avals_sharded, jaxpr, env = aux()
  out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)])
  with core.extend_axis_env_nd(mesh.shape.items()):
    jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
  unk_in_names = (({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
                  + (*unk_in_names,))
  const_tracers = map(trace.new_instantiated_const, res)
  env_tracers = map(trace.full_raise, env)
  unk_arg_tracers = [t for t in tracers if not t.is_known()]
  unk_params = dict(mesh=mesh, in_names=unk_in_names,
                    out_names=unk_out_names, jaxpr=jaxpr, check_rep=False)
  out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in out_avals]
  eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers),  # type: ignore[arg-type]
                          out_tracers, shard_map_p, unk_params,
                          jaxpr.effects, source_info_util.current())
  for t in out_tracers: t.recipe = eqn
  return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval

def _shard_map_partial_eval_post_process(
    trace, tracers, mesh, in_names, out_names_thunk, check_rep):
  del check_rep
  unk_tracers = [t for t in tracers if not t.is_known()]
  jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
  out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers])
  out = [*consts, *res]
  main = trace.main
  with core.extend_axis_env_nd(mesh.shape.items()):
    jaxpr_ = pe.convert_constvars_jaxpr(jaxpr)

  def todo(out):
    trace = main.with_cur_sublevel()
    out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)])
    const_tracers = map(trace.new_instantiated_const, res)
    env_tracers = map(trace.full_raise, env)

    staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
    staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
                         out_names=(*out_names_unknown,), check_rep=False)

    out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_)
    out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                   for a in out_avals]
    name_stack = trace._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
                            shard_map_p, staged_params, jaxpr.effects, source)
    for t in out_tracers: t.recipe = eqn
    return merge_lists(out_knowns, out_tracers, out_consts)

  def out_names_transform(out_names):
    nonlocal out_names_unknown
    out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
    return (*out_names_known,) + ({0: (*mesh.axis_names,)},) * len(jaxpr.constvars)
  out_names_unknown: Optional[list] = None

  return out, (todo, out_names_transform)
pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process

@lu.transformation
def _promote_scalar_residuals(*args, **kwargs):
  jaxpr, (out_pvals, out_consts, env) = yield args, kwargs
  which_scalar = [isinstance(v.aval, core.ShapedArray) and not v.aval.shape
                  for v in jaxpr.constvars]
  out_consts_ = [jax.lax.broadcast(x, (1,)) if scalar else x
                 for x, scalar in zip(out_consts, which_scalar)]

  @lu.wrap_init
  def fun(*args):
    out_consts = [x.reshape(*x.shape[1:]) if scalar else x
                  for x, scalar in zip(out_consts_, which_scalar)]
    return core.eval_jaxpr(jaxpr, out_consts, *args)
  in_avals = [v.aval for v in jaxpr.invars]
  jaxpr, _, out_consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  yield jaxpr, (out_pvals, out_consts, env)

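# Illustrative note (editorial commentary, not part of the original commit):
# scalar residuals can't carry the {0: (*mesh.axis_names,)} names that partial
# eval assigns to residual outputs, so this wrapper broadcasts any rank-0
# residual to shape (1,) on the way out of the known computation and reshapes
# it back before the unknown computation consumes it.
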
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
                         check_rep):
  mb_div = lambda x, y: x / y if y != 1 else x
  out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
             else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
             for ns, x in zip(out_names, out_cts)]
  args = [x if type(x) is not ad.UndefinedPrimal else
          ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
          for ns, x in zip(in_names, args)]
  all_args, in_tree = tree_flatten((out_cts, args))

  @lu.wrap_init
  def fun_trans(out_cts, args):
    res, undefs = partition_list(map(ad.is_undefined_primal, args), args)
    jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits(
        pe.close_jaxpr(jaxpr), map(ad.is_undefined_primal, args), False)
    res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res)
    out = ad.backward_pass(jaxpr_unknown.jaxpr, set(), False, (),
                           (*res_reshaped, *undefs), out_cts)
    return [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
            else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
            for ns, x in zip(in_names, out)]

  fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
  fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree)

  new_in_names = \
      [n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \
      [n for n, x in zip(in_names, args) if type(x) is not ad.UndefinedPrimal]

  def new_out_names_thunk():
    return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz)

  out_flat = shard_map_p.bind(
      fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
      out_names_thunk=new_out_names_thunk, check_rep=check_rep)
  return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose

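# Illustrative note (editorial commentary, not part of the original commit):
# the transpose rule divides each nonzero output cotangent by the total size
# of the mesh axes its out spec leaves unmentioned, runs the backward pass on
# per-device values, and then psums each input cotangent over the mesh axes
# unmentioned by the corresponding in spec, so that unmentioned-axis
# replication on the forward pass transposes to a sum on the backward pass.
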
def _shard_map_axis_subst(params, subst, traverse):
  if 'jaxpr' not in params:
    return params
  if not traverse:
    return params
  def shadowed_subst(name):
    return (name,) if name in params['mesh'].shape else subst(name)
  with core.extend_axis_env_nd(params['mesh'].shape.items()):
    new_jaxpr = core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst)
  return dict(params, jaxpr=new_jaxpr)
core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst

# TODO(mattjj): move this to _src/util.py
class HashablePartial:
  def __init__(self, f, *args, **kwargs):
    self.f = f
    self.args = args
    self.kwargs = kwargs

  def __eq__(self, other):
    return (type(other) is HashablePartial and
            self.f.__code__ == other.f.__code__ and
            self.args == other.args and self.kwargs == other.kwargs)

  def __hash__(self):
    return hash((self.f.__code__, self.args, tuple(self.kwargs.items())))

  def __call__(self, *args, **kwargs):
    return self.f(*self.args, *args, **self.kwargs, **kwargs)

@@ -542,6 +542,7 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
  # this expansion.
  for p in xb.expand_platform_alias(platform):
    _platform_specific_lowerings[p][prim] = rule
  return rule


def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x

@@ -426,7 +426,9 @@ def jaxpr_collectives(jaxpr):
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind

def _xla_call_partial_eval_update_params(params, kept_inputs, num_new_inputs):
def _xla_call_partial_eval_update_params(
    params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
  ) -> core.ParamDict:
  donated_invars = params['donated_invars']
  if not kept_inputs and donated_invars:
    # JaxprTrace.post_process_call creates a call with no input tracers

@@ -1007,6 +1007,11 @@ jax_test(
    },
)

jax_test(
    name = "shard_map_test",
    srcs = ["shard_map_test.py"],
)

jax_test(
    name = "clear_backends_test",
    srcs = ["clear_backends_test.py"],

382
tests/shard_map_test.py
Normal file
@@ -0,0 +1,382 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import os
import unittest

from absl.testing import absltest
import numpy as np

import jax
from jax import lax
from jax.config import config
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec as P
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge
import jax.numpy as jnp

from jax.experimental.shard_map import shard_map

config.parse_flags_with_absl()

# Helper for some tests.
def create_inputs(a_sharding, b_sharding):
  x, y, z = 2, 2, 2  # pylint: disable=invalid-name
  devices = np.array(jax.devices()[:x * y * z]).reshape((x, y, z))
  mesh = Mesh(devices, axis_names=('x', 'y', 'z'))
  b, e, f = 8, 8, 8  # pylint: disable=invalid-name
  m1 = jax.device_put(
      jnp.arange(b * e).reshape((b, e)),
      jax.sharding.NamedSharding(mesh, a_sharding))
  m2 = jax.device_put(
      jnp.arange(e * f).reshape((e, f)),
      jax.sharding.NamedSharding(mesh, b_sharding))
  return mesh, m1, m2

prev_xla_flags = None

# Run all tests with 8 CPU devices.
def setUpModule():
  global prev_xla_flags
  prev_xla_flags = os.getenv("XLA_FLAGS")
  flags_str = prev_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  if "xla_force_host_platform_device_count" not in flags_str:
    os.environ["XLA_FLAGS"] = (flags_str +
                               " --xla_force_host_platform_device_count=8")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()

  if len(jax.devices()) < 8:
    raise unittest.SkipTest("tests require 8 devices")
  if not jax.config.jax_array:
    raise unittest.SkipTest("test requires jax_array")

# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  if prev_xla_flags is None:
    del os.environ["XLA_FLAGS"]
  else:
    os.environ["XLA_FLAGS"] = prev_xla_flags
  xla_bridge.get_backend.cache_clear()


class ShardMapTest(jtu.JaxTestCase):

  def test_identity(self):
    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.device_buffers[0].shape == (4, 2)

    def identity(x):
      return x

    @jax.jit
    def fwd(a):
      c = shard_map(
          lambda x: x,
          mesh,
          in_specs=(P('z', ('x', 'y')),),
          out_specs=P('z', ('x', 'y')))(a)
      return c

    c = fwd(a)
    self.assertEqual(c.device_buffers[0].shape, (4, 2))

  def test_all_gather(self):
    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.device_buffers[0].shape == (4, 2)

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', ('x', 'y')),), out_specs=P(None, ('x', 'y')))
    def fwd(a):
      return lax.all_gather(a, 'z', axis=0, tiled=True)

    c = fwd(a)
    self.assertEqual(c.device_buffers[0].shape, (8, 2))

  def test_matmul_partial(self):
    raise unittest.SkipTest("invalid replication asserted by out_spec?")

    mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
    assert a.device_buffers[0].shape == (4, 4)

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', 'y'), P('y', None)), out_specs=P('z', None))
    def fwd(a):
      c = jnp.matmul(a, b)  # [B.z, F] {y.unreduced}
      return c

    c = fwd(a)
    self.assertEqual(c.device_buffers[0].shape, (4, 8))

  def test_matmul_reduce_scatter(self):
    mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
    assert a.device_buffers[0].shape == (4, 4)

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', 'y'), P('y', None)),
             out_specs=P(('z', 'y'), None))
    def fwd(a, b):
      c = jnp.matmul(a, b)  # [B.z, F] {y.unreduced}
      return lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True)

    c = fwd(a, b)
    self.assertEqual(c.device_buffers[0].shape, (2, 8))

  def test_collective_permute(self):
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=('x'))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)))

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=(P('x', None),),
             out_specs=P('x', None))
    def fwd(a):
      axis_size = lax.psum(1, 'x')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(a, 'x', perm=perm)

    c = fwd(a)
    self.assertAllClose(c[1, :], a[0, :])

  def test_all_to_all(self):
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=('x'))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)))

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('x', None),), out_specs=P(None, 'x'))
    def fwd(a):
      return lax.all_to_all(a, 'x', split_axis=1, concat_axis=1, tiled=True)

    c = fwd(a)
    assert (c == jnp.reshape(a.T, (1, 64))).all()

  def test_eager_repr(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    s = None

    @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
    def f(x):
      nonlocal s
      s = str(x)
      return x
    _ = f(np.arange(8 * 8.).reshape(8, 8))

    self.assertIsInstance(s, str)
    self.assertIn('at mesh coordinates', s)

  def test_jvp_basic(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
    args = np.arange(4 * 4.).reshape(4, 4),
    jtu.check_grads(g, args, 2, ['fwd'])
    jtu.check_grads(jax.jit(g), args, 2, ['fwd'])

  def test_linearize_basic(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
    x = np.arange(4 * 4.).reshape(4, 4)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_linearize_basic_repres(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh,
                  in_specs=(P('x',),), out_specs=P('x',))
    x = np.arange(4.)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_linearize_basic_repres_jit(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x',),), out_specs=P('x',))
    x = np.arange(4.)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_replication_checker_eager(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    x = np.arange(8 * 8.).reshape(8, 8)

    def f(x):
      return 2 * x
    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)

    with self.assertRaisesRegex(ValueError, 'statically inferred'):
      g(x)

    def f2(x):
      return jax.lax.psum(x, 'x')
    def g2(x):
      return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
    _ = g2(x)  # doesn't crash

  def test_replication_checker_jit(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    x = np.arange(8 * 8.).reshape(8, 8)

    def f(x):
      return 2 * x
    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)

    with self.assertRaisesRegex(ValueError, 'statically inferred'):
      jax.jit(g)(x)

    def f2(x):
      return jax.lax.psum(x, 'x')
    def g2(x):
      return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
    _ = jax.jit(g2)(x)  # doesn't crash

  def test_process_env_traces(self):
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
    x = np.arange(8.)

    def g(x):
      y = (3. * x).sum()
      z = shard_map(lambda x: 2 * x * y, mesh,
                    in_specs=(P('x'),), out_specs=P('x'))(np.arange(8.))
      return z

    jtu.check_grads(g, (x,), modes=['fwd'], order=2)

  def test_eager_control_flow(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    x = jnp.arange(2 * 2.).reshape(2, 2)

    def f(x):
      y = jax.lax.psum(x, ('x', 'y'))
      if y < 0:
        return x
      else:
        return -x

    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x)
    y = g(x)
    self.assertAllClose(y, -x, check_dtypes=False)

  def test_outer_jit_detects_shard_map_mesh(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
    _ = jax.jit(f)(jnp.array(2.0))  # doesn't crash

  def test_vmap_basic(self):
    if jax.config.jax_jit_pjit_api_merge:
      raise unittest.SkipTest("pjit batcher error")  # TODO(mattjj)

    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    x = jnp.arange(8 * 8.).reshape(8, 8)

    def g(x):
      return shard_map(lambda x: 2. * x, mesh,
                       in_specs=P('y'), out_specs=P('y'))(x)
    y = jax.vmap(g, axis_name='x')(x)
    self.assertAllClose(y, 2 * x, check_dtypes=False)

  def test_tree_prefix_error(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

    @partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y'))
    def f(x):
      return x

    x = jnp.arange(8 * 8.).reshape(8, 8)
    with self.assertRaisesRegex(ValueError, r'shard_map in_specs\[0\]'):
      f([x, x])

  def test_rank_errors(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

    def foo():
      return {'hi': [3.]}

    with self.assertRaisesRegex(ValueError, 'which has length 1'):
      shard_map(foo, mesh=mesh, in_specs=(), out_specs={'hi': P('x')})()

    with self.assertRaisesRegex(ValueError, 'which has length 1'):
      jax.jit(lambda: shard_map(foo, mesh=mesh,
                                in_specs=(), out_specs={'hi': P('x')})())()

    with self.assertRaisesRegex(ValueError, 'which has rank 0'):
      shard_map(foo, mesh=mesh, in_specs=({'hi': P('x')},), out_specs=())(
          {'hi': [jnp.array(3.)]})

  def test_reverse_mode_ad(self):
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('x',), P(None)), out_specs=P('x',))
    def f(x, y):
      return jnp.sin(x) + 3 + jnp.tan(2.) * jnp.cos(x) + y

    x = jnp.arange(8.) / 10.
    y = jnp.arange(4.) / 10.
    jtu.check_grads(f, (x, y), modes=['fwd', 'rev'], order=2)

  def test_post_process(self):
    # JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

    def f(x):
      @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
      def g(y):
        return jnp.sin(y) * jnp.sin(x).sum()
      return g(jnp.arange(8.))

    x = jnp.arange(8.)
    _, f_lin = jax.linearize(f, x)
    y_dot = f_lin(x)

    y_dot_expected = jnp.sin(jnp.arange(8.)) * (jnp.cos(x) * x).sum()
    self.assertAllClose(y_dot, y_dot_expected, check_dtypes=False)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())