Merge pull request #14472 from nouiz:shmap_jep_fixes

PiperOrigin-RevId: 509617771
jax authors 2023-02-14 13:14:33 -08:00
commit a9ef98992c


@@ -59,24 +59,29 @@ Or keep reading the next section to see some `shmap` examples and the API spec.
Sho shick:
```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))
mesh = Mesh(devices, axis_names=('i', 'j'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)
@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'y')
  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block
c = matmul_basic(a, b) # c: f32[8, 32]
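# (Not part of the diff) A hedged sanity check one might append here: with the
# 8 devices assumed by the (4, 2) mesh above, the shard_mapped result should
# match an ordinary dot product on the unsplit operands.
np.testing.assert_allclose(c, jnp.dot(a, b), rtol=1e-5)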
@@ -99,13 +104,13 @@ Notice:
Here's another matmul variant with a fully sharded result:
```python
@partial(shmap, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', 'y'))
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
  c_block = lax.psum_scatter(c_partialsum, 'y', scatter_dimension=1, tiled=True)
  c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
  return c_block
c = matmul_reduce_scatter(a, b)
@@ -153,7 +158,7 @@ jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
Recall that `jnp.split` slices its input into equally-sized blocks with the same
rank, so that if in the above example `y` has shape `f32[8,5]` then each `y_blk`
has shape `f32[2,5]`, and if each `f(y_blk)` has shape `f32[3,7]` then the final
concatenated result `shmap(f, ...)(y)` has shape `f32[12,7]`. So `shmap`
concatenated result `shard_map(f, ...)(y)` has shape `f32[12,7]`. So `shmap`
(`shard_map`) maps over shards, or blocks, of its inputs. We can say it's a
*rank-preserving map* with unconcatenating/concatenating of its inputs/outputs.
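To make the shape accounting concrete, here is a small sketch (not part of the diff) of that reference semantics using plain `jnp.split`/`jnp.concatenate`; the particular block function `f` and the shapes are illustrative assumptions:

```python
import jax.numpy as jnp

def f(y_blk):                 # y_blk: f32[2, 5], one block of y
  return jnp.zeros((3, 7)) + y_blk.sum()    # some f32[3, 7] result per block

y = jnp.arange(8 * 5.).reshape(8, 5)        # f32[8, 5]
out = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
print(out.shape)   # (12, 7): same rank as each block result, leading dim scaled by 4
```

On a four-device mesh with axis name `'i'`, `shard_map(f, mesh, in_specs=(P('i', None),), out_specs=P('i', None))(y)` would compute the same values, with each block handled on its own device.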
@@ -316,7 +321,7 @@ physical layout of a single logical `Array`.
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shmap(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
              ) -> Callable:
  ...
```
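To illustrate the `Specs = PyTree[PartitionSpec]` part of the signature, here is a hedged sketch (not from the diff) in which the argument is a dict of arrays and `in_specs`/`out_specs` mirror that structure; the names and shapes are illustrative assumptions:

```python
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

# Assumes 8 devices arranged as the (4, 2) 'i' x 'j' mesh used earlier.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=('i', 'j'))

def scale(batch):
  # Each leaf arrives as its per-device block.
  return {'x': batch['x'] * 2., 'w': batch['w'] * 3.}

batch = {'x': jnp.zeros((8, 16)), 'w': jnp.zeros((4, 2))}
specs = {'x': P('i', 'j'), 'w': P('i', 'j')}    # a pytree of PartitionSpecs

scale_sharded = shard_map(scale, mesh, in_specs=(specs,), out_specs=specs)
out = scale_sharded(batch)   # same logical shapes as `batch`, sharded over the mesh
```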
@@ -387,8 +392,8 @@ M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))
mesh = Mesh(np.array(jax.devices()), axis_names=('x'))
A_x = jax.device_put(A, NamedSharding(mesh, P('x', None)))
mesh = Mesh(np.array(jax.devices()), axis_names=('i'))
A_x = jax.device_put(A, NamedSharding(mesh, P('i', None)))
@jax.jit
def f(lhs, rhs):
@@ -410,8 +415,8 @@ This overlap can be implemented using `shmap` and explicit collectives.
```python
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
  axis_size = jax.lax.psum(1, axis_name='x')
  axis_index = jax.lax.axis_index(axis_name='x')
  axis_size = jax.lax.psum(1, axis_name='i')
  axis_index = jax.lax.axis_index(axis_name='i')
  chunk_size = lhs.shape[0]

  def f(i, carrys):
@@ -421,7 +426,7 @@ def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
    # circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='x',
        axis_name='i',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    # device 0 computes chunks 0, 1, ...
@@ -448,7 +453,7 @@ def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
```
jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting, mesh,
  in_specs=(P('x', None), P()), out_specs=P()))
  in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
```
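As a small follow-up (not part of the diff), one can confirm the logical shape of the result and inspect how it is laid out across devices; `M`, `N`, and `C` are as defined above:

```python
# The collective matmul still produces the full (M, N) logical result.
assert C.shape == (M, N)
# `C` is a jax.Array; its .sharding attribute shows the device placement.
print(C.sharding)
```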
@@ -498,7 +503,7 @@ def matmul_2D_wg_manual(xnorm, q_wi, layer):
      xnorm,
      params.q_wi,
      scatter_dimension=(0, 2),
      axis_name='x',
      axis_name='i',
      layer=layer)
  return q_wi