Fix different device order reshard for McJAX.

Previously, `np.take(tile_assignment_devices, permute_order)` did not preserve the invariant that, after the permutation, each tile still maps to the same concrete device it had in the old array.

Doing `np.take(permute_order, tile_assignment_devices)` preserves that invariant and is therefore the correct argument order.
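
To see why the argument order matters, here is a minimal sketch in plain NumPy with devices represented as strings. The definition of `permute_order` used below is an assumption for illustration (the position of each old device inside the target device assignment); it is not shown in this diff.

```python
import numpy as np

# Hypothetical concrete device assignments (strings stand in for devices).
old_da    = np.array(["D0", "D1", "D2", "D3"])
target_da = np.array(["D1", "D2", "D3", "D0"])

# Assumed definition: position of each old device inside the target assignment.
permute_order = np.array([list(target_da).index(d) for d in old_da])  # [3, 0, 1, 2]

# Old tile_assignment_devices: tile i is placed on old_da[old_tiles[i]].
old_tiles = np.array([2, 0, 3, 1])

# Correct composition: tile i keeps the same concrete device.
new_tiles = np.take(permute_order, old_tiles)                         # [1, 3, 2, 0]
assert (old_da[old_tiles] == target_da[new_tiles]).all()

# Swapped arguments (the old behavior): the concrete device order changes.
bad_tiles = np.take(old_tiles, permute_order)                         # [1, 2, 0, 3]
assert not (old_da[old_tiles] == target_da[bad_tiles]).all()
```

The two orders coincide for some symmetric device permutations, which is presumably why the existing reshard test passed; the asymmetric case above shows where they diverge.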

PiperOrigin-RevId: 654884965
commit 81afdaa9e8 (parent ce5f9a6da9)
Author: Yash Katariya (committed by jax authors)
Date:   2024-07-22 13:55:55 -07:00

3 changed files with 21 additions and 52 deletions


@@ -16,7 +16,6 @@
 from __future__ import annotations

 import atexit
-import collections
 from collections.abc import Callable, Iterator, Sequence
 import contextlib
 import dataclasses
@@ -38,7 +37,6 @@ from jax._src import dtypes
 from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import util
-from jax._src import xla_bridge as xb
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.abstract_arrays import array_types
@@ -324,10 +322,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
     raise FloatingPointError(f"invalid value (inf) encountered in {name}")


-def _override_get_device_assignment(sharding, *args, **kwargs):
-  da = sharding._device_assignment
-  return xb.get_device_backend(da[0]), da
-

 def _identity_fn(x):
   return x
@@ -361,7 +355,7 @@ def _different_device_order_reshard(x, target_sharding):
   new_op_sharding.iota_reshape_dims = []
   new_op_sharding.iota_transpose_perm = []
   new_op_sharding.tile_assignment_devices = np.take(
-      old_hlo_sharding.tile_assignment_devices(), permute_order
+      permute_order, old_hlo_sharding.tile_assignment_devices()
   )
   new_hlo_sharding = xc.HloSharding.from_proto(new_op_sharding)
   # TODO(yashkatariya): Enable this when HloSharding conversion is fixed in
@@ -370,40 +364,18 @@ def _different_device_order_reshard(x, target_sharding):
   #         == new_hlo_sharding.tile_assignment_dimensions())
   # assert (new_op_sharding.tile_assignment_devices
   #         == new_hlo_sharding.tile_assignment_devices())
-
-  new_sharding = GSPMDSharding(
-      target_sharding._device_assignment, new_hlo_sharding,
-      memory_kind=target_sharding.memory_kind)
-
-  old_device_to_index_buffer = collections.defaultdict()
-  old_index_to_buffer = collections.defaultdict()
-  for s in x.addressable_shards:
-    old_index_to_buffer[array.hashed_index(s.index)] = s.data
-    old_device_to_index_buffer[s.device] = (s.index, s.data)
-
-  new_arrays = []
-  for new_d, new_index in new_sharding.addressable_devices_indices_map(x.shape).items():
-    old_index, old_buf = old_device_to_index_buffer[new_d]
-    if old_index == new_index:
-      assert array._get_device(old_buf) == new_d, (
-          array._get_device(old_buf), new_d)
-      new_arrays.append(old_buf)
-    else:
-      old_buf = old_index_to_buffer[array.hashed_index(new_index)]
-      new_arrays.append(
-          pxla.batched_device_put(old_buf.aval, SingleDeviceSharding(new_d),
-                                  [old_buf], [new_d]))
-
-  new_x = array.ArrayImpl(
-      x.aval, new_sharding, new_arrays, committed=True, _skip_checks=True)
-
-  _orig_get_and_check_device_assignment = pxla._get_and_check_device_assignment.fn
-  pxla._get_and_check_device_assignment.fn = partial(
-      _override_get_device_assignment, target_sharding)
-  try:
-    return api.jit(_identity_fn, out_shardings=target_sharding)(new_x)
-  finally:
-    pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment
+  assert (list(np.take(inp_sharding._device_assignment,
+                       old_hlo_sharding.tile_assignment_devices()))
+          == list(np.take(target_sharding._device_assignment,
+                          new_op_sharding.tile_assignment_devices)))
+
+  new_x = array.make_array_from_single_device_arrays(
+      x.shape,
+      GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding,
+                    memory_kind=target_sharding.memory_kind),
+      x._arrays,
+  )
+  return api.jit(_identity_fn, out_shardings=target_sharding)(new_x)


 @dataclasses.dataclass(frozen=True)


@@ -25,7 +25,6 @@ from functools import partial, lru_cache, cached_property
 import itertools as it
 import logging
 import math
-import threading
 from typing import Any, NamedTuple, TypeVar, Union, cast
 import warnings
@@ -1740,16 +1739,6 @@ def _get_default_device() -> xc.Device:
   return config.default_device.value or xb.local_devices()[0]


-class _thread_local_decorator(threading.local):
-
-  def __init__(self, fn):
-    self.fn = fn
-
-  def __call__(self, *args, **kwargs):
-    return self.fn(*args, **kwargs)
-
-
-@_thread_local_decorator
 def _get_and_check_device_assignment(
     shardings: Iterable[ShardingInfo],
     devices: Sequence[xc.Device] | None,


@@ -4265,10 +4265,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
     s2 = NamedSharding(mesh2, P())
     x_s1 = jax.device_put(np_inp, s1)

     # Reshard!
     out = jax.device_put(x_s1, s2)
     self.assertArraysEqual(out, np_inp)
     self.assertEqual(out.sharding, s2)
+    del out
+
+    s3 = NamedSharding(mesh2, P('model_q'))
+    x_s3 = jax.device_put(np_inp, s3)
+    # Reshard to iota device assignment!
+    out2 = jax.device_put(x_s3, s1)
+    self.assertArraysEqual(out2, np_inp)
+    self.assertEqual(out2.sharding, s1)

   def test_convert_element_type_sharding(self):
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))