remove busted example from shmap jep
This commit is contained in:
parent 2a41c04fef
commit 26f70c9c16
@@ -3,6 +3,9 @@

*January 2023*

**This was the design doc proposing `shard_map`. You may instead want
[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).**

## Motivation

JAX supports two schools of thought for multi-device programming:
@@ -374,114 +377,8 @@ One philosophy is: it is almost always simpler to write a program in `jit==pjit`

— but if a given part of the program is less optimized by the compiler than it
could be, drop into `shmap`!
### A realistic transformer example

In fact, we can implement a simple version of the ["collective
matmul"](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959) algorithm
recently introduced in XLA to overlap communication and computation using `shmap`
and 30 lines of Python. The basic idea of the algorithm can be grasped with a
simple example.

Suppose we want to compute `C = A @ B` where `A` is sharded by a 1D mesh on the
0-th dimension while `B` and `C` are replicated.
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))

# 1D mesh over all devices, with a single axis named 'i'
mesh = Mesh(np.array(jax.devices()), axis_names=('i',))
# shard A along its 0-th (non-contracting) dimension; B stays replicated
A_x = jax.device_put(A, NamedSharding(mesh, P('i', None)))

@jax.jit
def f(lhs, rhs):
  return lhs @ rhs

C = f(A_x, B)
```

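Before reaching for a profiler, one rough way to see what the compiler does with this program is to dump the compiled HLO text. The snippet below is a sketch assuming the definitions above; the exact op names and formatting vary by backend and XLA version.

```python
# Sketch: inspect the compiled HLO for the sharded matmul above.
# With `A` sharded on a non-contracting dimension, an all-gather
# typically appears before the dot in the compiled program.
lowered = f.lower(A_x, B)
compiled = lowered.compile()
hlo_text = compiled.as_text()
print('all-gather' in hlo_text, 'dot' in hlo_text)
```
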
A profile shows the blocking all-gather across 8 devices before the matmul can
start. This is suboptimal because `A` is sharded on a non-contracting dimension:
each shard of `A` can be multiplied with `B` independently, and this chunked
computation can be overlapped with fetching the next shard of `A` from another
device.

<img width="1147" alt="image" src="https://user-images.githubusercontent.com/1458824/216507011-e854fb11-43d5-484d-993b-19a3349ed4b9.png">

This overlap can be implemented using `shmap` and explicit collectives.
```python
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
  axis_size = jax.lax.psum(1, axis_name='i')
  axis_index = jax.lax.axis_index(axis_name='i')
  chunk_size = lhs.shape[0]

  def f(i, carrys):
    accum, lhs = carrys
    # matmul for a chunk
    update = lhs @ rhs
    # circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='i',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    # device 0 computes chunks 0, 1, ...
    # device 1 computes chunks 1, 2, ...
    update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
    accum = jax.lax.dynamic_update_slice(accum, update, update_index)
    return accum, lhs

  accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
  # fori_loop causes a crash: hlo_sharding.cc:817 Check failed: !IsManual()
  # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
  for i in range(0, axis_size - 1):
    accum, lhs = f(i, (accum, lhs))

  # compute the last chunk, without the ppermute
  update = lhs @ rhs
  i = axis_size - 1
  update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
  accum = jax.lax.dynamic_update_slice(accum, update, update_index)

  return accum
```

```python
from jax.experimental.shard_map import shard_map  # import path at the time of this design doc

jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting, mesh,
  in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
```

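As a quick sanity check (a sketch assuming the definitions above, not part of the original example), the sharded result can be compared against the plain `jit`'ed matmul `f` from before:

```python
# Sanity check (sketch): compare against the reference result from `f` above.
# Depending on backend accumulation order/precision, allow a small tolerance.
C_ref = f(A_x, B)
C_cm = jit_sharded_f(A_x, B)
assert C_cm.shape == C_ref.shape == (M, N)
np.testing.assert_allclose(np.asarray(C_cm), np.asarray(C_ref), rtol=1e-6)
```
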
A profile shows that the all-gather is gone, replaced by a matmul overlapped
with an asynchronous collective permute. This profile matches the collective
matmul paper's result very closely.

<img width="1147" alt="image" src="https://user-images.githubusercontent.com/1458824/216507064-139f032c-d869-4b67-9e11-1587d4fd2de9.png">

This collective matmul technique can be used to speed up feedforward blocks in
transformer layers. A feedforward block typically consists of two matrix
multiplications followed by a `ReduceScatter` (to resolve partial sums from the
parallelized matrix multiplication) and preceded by an `AllGather` (to collect
the sharded dimensions along some axes and allow partial sum computation).
Together, the `ReduceScatter` from one layer and the `AllGather` for the next
amount to an `AllReduce`.

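For concreteness, here is a minimal sketch of that pattern written with `shard_map`, assuming the 1D mesh axis `'i'` from above and hypothetical sharded weights `w1_blk`/`w2_blk` (illustrative only, not the code from the paper):

```python
def ffn_block(x_blk, w1_blk, w2_blk):
  # x_blk: [batch/n, d_model], w1_blk: [d_model, d_ff/n], w2_blk: [d_ff/n, d_model]
  x = jax.lax.all_gather(x_blk, 'i', axis=0, tiled=True)  # AllGather the sharded batch dim
  h = jax.nn.relu(x @ w1_blk)                             # first matmul (local)
  partial = h @ w2_blk                                    # second matmul: partial sums over d_ff
  # ReduceScatter: sum the partial results and re-shard along the batch dim
  return jax.lax.psum_scatter(partial, 'i', scatter_dimension=0, tiled=True)

ffn = jax.jit(shard_map(
    ffn_block, mesh,
    in_specs=(P('i', None), P(None, 'i'), P('i', None)),
    out_specs=P('i', None)))
```

Written this way, the `AllGather`/`ReduceScatter` pair is explicit, which is exactly where the collective matmul trick above can be substituted to overlap communication with the matmuls.
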
In a typical profile, the two matmuls will be followed by an `AllReduce`, and
they will not be overlapped. Collective matmul can be used to achieve the
overlap, but it is difficult to trigger, has a minimum slice size, and does not
yet cover all topologies, tensor shapes, and variants of collective matmul (i.e.
latency- and throughput-optimized variants). [In a recent
paper](https://arxiv.org/abs/2211.05102), we found a ~40% gain in many
circumstances from manually implementing collective matmul variants in `shmap`
style.

But it isn't always more complex! We expect this to be a much more natural way
to think about pipelined computation, and we plan to do some demos of that soon!
### Another realistic example
### A realistic example

Here's how `shmap` might look in a transformer layer pass with a 2D weight
gathered pattern ([paper](https://arxiv.org/abs/2211.05102), Sec 3.2.3 on p. 5):