1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Merge pull request from apaszke:axis_index_handle_list

PiperOrigin-RevId: 333304709
This commit is contained in:
jax authors 2020-09-23 09:11:33 -07:00
commit c875ab3ec9

@ -692,7 +692,7 @@ pxla.multi_host_supported_collectives.add(axis_index_p)
# wants to bind an axis name has to additionally implement `process_axis_index`
# and put its main trace on the axis env stack.
def _axis_index_bind(*, axis_name):
if not isinstance(axis_name, tuple):
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)
inner_size = 1
index = 0