add docstring / reference doc link for axis_index

fixes #2534
This commit is contained in:
Matthew Johnson 2020-03-29 13:56:26 -07:00
parent bcc5191c63
commit fcc1e76c5a
2 changed files with 38 additions and 0 deletions

View File

@ -163,3 +163,4 @@ Parallelism support is experimental.
pmin
ppermute
pswapaxes
axis_index

View File

@ -248,6 +248,43 @@ parallel_pure_rules: Dict[core.Primitive, Callable] = {}
def axis_index(axis_name):
"""Return the index along the pmapped axis ``axis_name``.
Args:
axis_name: hashable Python object used to name the pmapped axis (see the
``pmap`` docstring for more details).
Returns:
An integer representing the index.
For example, with 8 XLA devices available:
>>> from functools import partial
>>> @partial(pmap, axis_name='i')
... def f(_):
... return lax.axis_index('i')
...
>>> f(np.zeros(4))
ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
>>> f(np.zeros(8))
ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
>>> @partial(pmap, axis_name='i')
... @partial(pmap, axis_name='j')
... def f(_):
... return lax.axis_index('i'), lax.axis_index('j')
...
>>> x, y = f(np.zeros((4, 2)))
>>> print(x)
[[0 0]
[1 1]
[2 2]
[3 3]]
>>> print(y)
[[0 1]
[0 1]
[0 1]
[0 1]]
"""
dynamic_axis_env = _thread_local_state.dynamic_axis_env
frame = dynamic_axis_env[axis_name]
sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]