mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 20:36:06 +00:00
Add examples to lax.psum to illustrate axis_index_groups better.
PiperOrigin-RevId: 497401892
This commit is contained in:
parent
2f3d75aa03
commit
9fda20fc29
@ -76,6 +76,35 @@ def psum(x, axis_name, *, axis_index_groups=None):
|
||||
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
|
||||
>>> print(y)
|
||||
[0. 0.16666667 0.33333334 0.5 ]
|
||||
|
||||
Suppose we want to perform `psum` among two groups, one with `device0` and `device1`, the other with `device2` and `device3`,
|
||||
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
|
||||
>>> print(y)
|
||||
[1 1 5 5]
|
||||
|
||||
An example using 2D-shaped x. Each row is data from one device.
|
||||
>>> x = np.arange(16).reshape(4, 4)
|
||||
>>> print(x)
|
||||
[[ 0 1 2 3]
|
||||
[ 4 5 6 7]
|
||||
[ 8 9 10 11]
|
||||
[12 13 14 15]]
|
||||
|
||||
Full `psum` across all devices:
|
||||
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
|
||||
>>> print(y)
|
||||
[[24 28 32 36]
|
||||
[24 28 32 36]
|
||||
[24 28 32 36]
|
||||
[24 28 32 36]]
|
||||
|
||||
Perform `psum` among two groups:
|
||||
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
|
||||
>>> print(y)
|
||||
[[ 4 6 8 10]
|
||||
[ 4 6 8 10]
|
||||
[20 22 24 26]
|
||||
[20 22 24 26]]
|
||||
"""
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
@ -1477,7 +1506,7 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, t
|
||||
|
||||
For example, with 4 XLA devices available:
|
||||
|
||||
>>> x = np.arange(16).reshape(4,4)
|
||||
>>> x = np.arange(16).reshape(4, 4)
|
||||
>>> print(x)
|
||||
[[ 0 1 2 3]
|
||||
[ 4 5 6 7]
|
||||
|
Loading…
x
Reference in New Issue
Block a user