remove busted example from shmap jep

Matthew Johnson 2024-11-01 16:37:46 +00:00
parent 2a41c04fef
commit 26f70c9c16


@ -3,6 +3,9 @@
*January 2023*
**This was the design doc proposing `shard_map`. You may instead want
[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).**
## Motivation
JAX supports two schools of thought for multi-device programming:
@ -374,114 +377,8 @@ One philosophy is: it is almost always simpler to write a program in `jit==pjit`
— but if a given part of the program is less optimized by the compiler than it
could be, drop into `shmap`!
### A realistic transformer example
In fact, using `shmap` and 30 lines of Python, we can implement a simple
version of the ["collective
matmul"](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959) algorithm
recently introduced in XLA to overlap communication and computation. The basic
idea of the algorithm can be grasped with a simple example.
Suppose we want to compute `C = A @ B` where `A` is sharded by a 1D mesh on the
0-th dimension while `B` and `C` are replicated.
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))

mesh = Mesh(np.array(jax.devices()), axis_names=('i',))
A_x = jax.device_put(A, NamedSharding(mesh, P('i', None)))

@jax.jit
def f(lhs, rhs):
  return lhs @ rhs

C = f(A_x, B)
```
A profile shows the blocking all-gather across 8 devices before the matmul can
start. This is suboptimal because `A` is sharded on a non-contracting
dimension: each shard of `A` can be matmul'ed with `B` independently, and this
chunked computation can be overlapped with fetching the next shard of `A` from
another device.
<img width="1147" alt="image" src="https://user-images.githubusercontent.com/1458824/216507011-e854fb11-43d5-484d-993b-19a3349ed4b9.png">
This overlap can be implemented using `shmap` and explicit collectives.
```python
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
  axis_size = jax.lax.psum(1, axis_name='i')
  axis_index = jax.lax.axis_index(axis_name='i')
  chunk_size = lhs.shape[0]

  def f(i, carrys):
    accum, lhs = carrys
    # matmul for a chunk
    update = lhs @ rhs
    # circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='i',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    # device 0 computes chunks 0, 1, ...
    # device 1 computes chunks 1, 2, ...
    update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
    accum = jax.lax.dynamic_update_slice(accum, update, update_index)
    return accum, lhs

  accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
  # fori_loop causes a crash: hlo_sharding.cc:817 Check failed: !IsManual()
  # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
  for i in range(0, axis_size - 1):
    accum, lhs = f(i, (accum, lhs))

  # compute the last chunk, without the ppermute
  update = lhs @ rhs
  i = axis_size - 1
  update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
  accum = jax.lax.dynamic_update_slice(accum, update, update_index)

  return accum
```
```python
jit_sharded_f = jax.jit(shard_map(
    collective_matmul_allgather_lhs_non_contracting, mesh,
    in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
```
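As a quick sanity check (our addition, assuming the arrays and the jit'd `f`
defined above), the chunked, overlapped implementation should reproduce the
reference matmul exactly, since every chunk is computed with the same local
contraction:

```python
# Our addition: the overlapped result should equal the plain matmul.
C_ref = f(A_x, B)  # the jit'd matmul from the first snippet
assert jnp.array_equal(C, C_ref)
```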
A profile shows that the all-gather is gone, replaced by matmuls overlapped
with asynchronous collective permutes. This profile matches the collective
matmul paper's result very closely.
<img width="1147" alt="image" src="https://user-images.githubusercontent.com/1458824/216507064-139f032c-d869-4b67-9e11-1587d4fd2de9.png">
This collective matmul technique can be used to speed up feedforward blocks in
transformer layers. These blocks typically consist of two matrix
multiplications followed by a `ReduceScatter` (to resolve partial sums from a
parallelized matrix multiplication) and preceded by an `AllGather` (to collect
the sharded dimensions along some axes and allow partial sum computation).
Together, the `ReduceScatter` from one layer and the `AllGather` for the next
amount to an `AllReduce`.
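To make that identity concrete, here is a small sketch (ours, not from the
doc; it assumes the `jax.experimental.shard_map` import path) checking that a
`psum_scatter` (`ReduceScatter`) followed by an `all_gather` (`AllGather`)
matches a single `psum` (`AllReduce`):

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=('i',))
n_dev = mesh.size

@functools.partial(shard_map, mesh=mesh,
                   in_specs=P('i', None), out_specs=P('i', None))
def rs_then_ag_minus_ar(x):
  via_psum = jax.lax.psum(x, 'i')                             # AllReduce
  scattered = jax.lax.psum_scatter(x, 'i', tiled=True)        # ReduceScatter
  via_rs_ag = jax.lax.all_gather(scattered, 'i', tiled=True)  # AllGather
  return via_psum - via_rs_ag                                 # all zeros

x = jnp.ones((8 * n_dev * n_dev, 16))  # sized so the scatter tiles evenly
assert (rs_then_ag_minus_ar(x) == 0).all()
```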
In a typical profile, the two matmuls will be followed by an `AllReduce`, and
they will not be overlapped. Collective matmul can be used to achieve the
overlap, but it is difficult to trigger, has a minimum slice size, and does not
yet cover all topologies, tensor shapes, and variants of collective matmul
(i.e. latency- and throughput-optimized variants). [In a recent
paper](https://arxiv.org/abs/2211.05102), we found a ~40% gain in many
circumstances from manually implementing collective matmul variants in `shmap`
style.
But it isn't always more complex! We expect this to be a much more natural way
to think about pipelined computation, and plan to do some demos of that soon!
### Another realistic example
Here's how `shmap` might look in a transformer layer pass with a 2D
weight-gathered pattern ([paper](https://arxiv.org/abs/2211.05102), Sec 3.2.3 on p. 5):
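As a minimal sketch of the weight-gathered idea (ours; the axis names `'x'`
and `'y'`, the shapes, and the single-matmul body are assumptions, not the
paper's exact pattern): the weights stay sharded over both mesh axes and are
all-gathered on the fly just before use, so the full weight matrix is never
resident on one device between layers.

```python
# Sketch (ours): a weight-gathered matmul on a 2D mesh.
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(1, -1), axis_names=('x', 'y'))  # any 2D factoring works

@functools.partial(
    shard_map, mesh=mesh,
    in_specs=(P('x', None), P(('x', 'y'), None)),
    out_specs=P('x', None))
def ffn_weight_gathered(acts, w):
  # acts: [batch/x, embed]       batch-sharded over 'x', replicated over 'y'
  # w:    [embed/(x*y), hidden]  sharded over both axes on the contracting dim
  w_full = jax.lax.all_gather(w, ('x', 'y'), tiled=True)  # gather the weights
  return acts @ w_full                                    # local matmul

batch, embed, hidden = 8 * mesh.size, 16 * mesh.size, 256
acts = jnp.ones((batch, embed))
w = jnp.ones((embed, hidden))
out = jax.jit(ffn_weight_gathered)(acts, w)  # [batch, hidden]
```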