mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
parent
bcc5191c63
commit
fcc1e76c5a
@ -163,3 +163,4 @@ Parallelism support is experimental.
|
||||
pmin
|
||||
ppermute
|
||||
pswapaxes
|
||||
axis_index
|
||||
|
@ -248,6 +248,43 @@ parallel_pure_rules: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
|
||||
def axis_index(axis_name):
|
||||
"""Return the index along the pmapped axis ``axis_name``.
|
||||
|
||||
Args:
|
||||
axis_name: hashable Python object used to name the pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
|
||||
Returns:
|
||||
An integer representing the index.
|
||||
|
||||
For example, with 8 XLA devices available:
|
||||
|
||||
>>> from functools import partial
|
||||
>>> @partial(pmap, axis_name='i')
|
||||
... def f(_):
|
||||
... return lax.axis_index('i')
|
||||
...
|
||||
>>> f(np.zeros(4))
|
||||
ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
|
||||
>>> f(np.zeros(8))
|
||||
ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
|
||||
>>> @partial(pmap, axis_name='i')
|
||||
... @partial(pmap, axis_name='j')
|
||||
... def f(_):
|
||||
... return lax.axis_index('i'), lax.axis_index('j')
|
||||
...
|
||||
>>> x, y = f(np.zeros((4, 2)))
|
||||
>>> print(x)
|
||||
[[0 0]
|
||||
[1 1]
|
||||
[2 2]
|
||||
[3 3]]
|
||||
>>> print(y)
|
||||
[[0 1]
|
||||
[0 1]
|
||||
[0 1]
|
||||
[0 1]]
|
||||
"""
|
||||
dynamic_axis_env = _thread_local_state.dynamic_axis_env
|
||||
frame = dynamic_axis_env[axis_name]
|
||||
sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
|
||||
|
Loading…
x
Reference in New Issue
Block a user