Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Merge pull request #14472 from nouiz:shmap_jep_fixes
PiperOrigin-RevId: 509617771
commit a9ef98992c
@@ -59,24 +59,29 @@ Or keep reading the next section to see some `shmap` examples and the API spec.
Here's a quick example:

```python
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

devices = mesh_utils.create_device_mesh((4, 2))
-mesh = Mesh(devices, axis_names=('x', 'y'))
+mesh = Mesh(devices, axis_names=('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

-@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
-         out_specs=P('x', None))
+@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
+         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
-  z_block = jax.lax.psum(z_partialsum, 'y')
+  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]
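# Sanity-check sketch (assumes the definitions above and an 8-device mesh):
# the sharded matmul should agree with an ordinary dense matmul, up to
# floating-point reduction order.
c_ref = jnp.dot(a, b)                     # unsharded reference result
assert jnp.allclose(c, c_ref, rtol=1e-5)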
@@ -99,13 +104,13 @@ Notice:
Here's another matmul variant with a fully sharded result:

```python
-@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
-         out_specs=P('x', 'y'))
+@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
+         out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
-  c_block = lax.psum_scatter(c_partialsum, 'y', scatter_dimension=1, tiled=True)
+  c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)
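# Sketch of a check (assumes the definitions above): the result is logically
# the same product as before, but out_specs=P('i', 'j') leaves it sharded over
# both mesh axes, so each device holds an f32[2, 16] block of the f32[8, 32]
# output.
assert jnp.allclose(c, jnp.dot(a, b), rtol=1e-5)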
@@ -153,7 +158,7 @@ jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
Recall that `jnp.split` slices its input into equally-sized blocks with the same
rank, so that if in the above example `y` has shape `f32[8,5]` then each `y_blk`
has shape `f32[2,5]`, and if each `f(y_blk)` has shape `f32[3,7]` then the final
-concatenated result `shmap(f, ...)(y)` has shape `f32[12,7]`. So `shmap`
+concatenated result `shard_map(f, ...)(y)` has shape `f32[12,7]`. So `shmap`
(`shard_map`) maps over shards, or blocks, of its inputs. We can say it's a
*rank-preserving map* with unconcatenating/concatenating of its inputs/outputs.

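To make the block-shape bookkeeping concrete, here is a small sketch of this
reference semantics (the per-block function `f` below is made up for
illustration; only the shapes come from the text above):

```python
import jax.numpy as jnp

# A stand-in per-block function: maps an f32[2, 5] block to an f32[3, 7] block.
def f(y_blk):
  return jnp.full((3, 7), y_blk.sum())

y = jnp.arange(8 * 5.).reshape(8, 5)                   # f32[8, 5]
blocks = jnp.split(y, 4)                               # four f32[2, 5] blocks
out = jnp.concatenate([f(y_blk) for y_blk in blocks])  # f32[12, 7]
assert out.shape == (12, 7)
```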
@@ -316,7 +321,7 @@ physical layout of a single logical `Array`.
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

-def shmap(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
+def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
              ) -> Callable:
  ...
```
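As a minimal usage sketch of this signature (reusing the 4x2 mesh with axis
names `('i', 'j')` from the first example; the elementwise function and spec
choices here are illustrative), `in_specs` and `out_specs` are pytrees of
`PartitionSpec`s matching the function's inputs and outputs:

```python
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def double(x_block):
  # Elementwise work needs no collectives; each device doubles its own block.
  return 2 * x_block

x = jnp.arange(8 * 16.).reshape(8, 16)
y = double(x)   # blocked over the mesh the same way as the input
```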
@@ -387,8 +392,8 @@ M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))

-mesh = Mesh(np.array(jax.devices()), axis_names=('x'))
-A_x = jax.device_put(A, NamedSharding(mesh, P('x', None)))
+mesh = Mesh(np.array(jax.devices()), axis_names=('i'))
+A_x = jax.device_put(A, NamedSharding(mesh, P('i', None)))

@jax.jit
def f(lhs, rhs):
@@ -410,8 +415,8 @@ This overlap can be implemented using `shmap` and explicit collectives.
```python
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
-  axis_size = jax.lax.psum(1, axis_name='x')
-  axis_index = jax.lax.axis_index(axis_name='x')
+  axis_size = jax.lax.psum(1, axis_name='i')
+  axis_index = jax.lax.axis_index(axis_name='i')
  chunk_size = lhs.shape[0]

  def f(i, carrys):
@@ -421,7 +426,7 @@ def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
    # circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
-        axis_name='x',
+        axis_name='i',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    # device 0 computes chunks 0, 1, ...
@@ -448,7 +453,7 @@ def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
```
jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting, mesh,
-  in_specs=(P('x', None), P()), out_specs=P()))
+  in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
```

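As a sketch of a correctness check (assuming `A`, `B`, `A_x`, `mesh`, and
`jit_sharded_f` as defined above; the check itself is not part of the example),
the manually overlapped collective matmul should produce the same product as a
single un-overlapped `jnp.dot`:

```python
C_ref = jnp.dot(A, B)          # plain matmul, no manual overlap
assert jnp.allclose(C, C_ref)
```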
@@ -498,7 +503,7 @@ def matmul_2D_wg_manual(xnorm, q_wi, layer):
      xnorm,
      params.q_wi,
      scatter_dimension=(0, 2),
-      axis_name='x',
+      axis_name='i',
      layer=layer)
  return q_wi