Add examples to lax.psum to illustrate axis_index_groups better.

PiperOrigin-RevId: 497401892
This commit is contained in:
Qiao Zhang 2022-12-23 12:04:19 -08:00 committed by jax authors
parent 2f3d75aa03
commit 9fda20fc29

View File

@ -76,6 +76,35 @@ def psum(x, axis_name, *, axis_index_groups=None):
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0. 0.16666667 0.33333334 0.5 ]
Suppose we want to perform `psum` among two groups, one with `device0` and `device1`, the other with `device2` and `device3`,
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[1 1 5 5]
An example using 2D-shaped x. Each row is data from one device.
>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]]
Full `psum` across all devices:
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[[24 28 32 36]
[24 28 32 36]
[24 28 32 36]
[24 28 32 36]]
Perform `psum` among two groups:
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[[ 4 6 8 10]
[ 4 6 8 10]
[20 22 24 26]
[20 22 24 26]]
"""
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)
@ -1477,7 +1506,7 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, t
For example, with 4 XLA devices available:
>>> x = np.arange(16).reshape(4,4)
>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0 1 2 3]
[ 4 5 6 7]