mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Merge pull request #4391 from apaszke:axis_index_handle_list
PiperOrigin-RevId: 333304709
This commit is contained in:
commit
c875ab3ec9
@ -692,7 +692,7 @@ pxla.multi_host_supported_collectives.add(axis_index_p)
|
||||
# wants to bind an axis name has to additionally implement `process_axis_index`
|
||||
# and put its main trace on the axis env stack.
|
||||
def _axis_index_bind(*, axis_name):
|
||||
if not isinstance(axis_name, tuple):
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
inner_size = 1
|
||||
index = 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user