mirror of https://github.com/ROCm/jax.git
Fix different device order reshard for McJAX.
Previously, `np.take(tile_assignment_devices, permute_order)` did not preserve the invariant that each tile stays on the same concrete device after the permutation (i.e. the device order of the old array). Doing `np.take(permute_order, tile_assignment_devices)` does preserve that invariant, and is therefore the correct call.

PiperOrigin-RevId: 654884965
parent ce5f9a6da9
commit 81afdaa9e8
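To see why the argument order matters, here is a small, self-contained NumPy sketch of the invariant the message describes. The device ids, permutation, and tile assignment are made-up illustration values, and it assumes `permute_order[i]` gives the position in the target device assignment of the device found at position `i` of the old assignment:

import numpy as np

# Hypothetical values, for illustration only.
old_da = np.array([10, 11, 12, 13])   # old concrete device ids, in old order
new_da = np.array([12, 10, 13, 11])   # same devices, in the target order

# permute_order[i] = position of old device i in the new device assignment.
permute_order = np.array([int(np.where(new_da == d)[0][0]) for d in old_da])

old_tiles = np.array([1, 0, 3, 2])    # tile k -> index into old_da

# The fix: re-index each old device index through permute_order.
new_tiles = np.take(permute_order, old_tiles)
# Invariant: every tile still lands on the same concrete device.
assert (np.take(old_da, old_tiles) == np.take(new_da, new_tiles)).all()

# The previous call permuted the tile order instead; with these values it
# breaks the invariant.
wrong_tiles = np.take(old_tiles, permute_order)
assert not (np.take(old_da, old_tiles) == np.take(new_da, wrong_tiles)).all()

This is the same relationship checked by the assertion added to `_different_device_order_reshard` in the diff below.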
@@ -16,7 +16,6 @@
 from __future__ import annotations

 import atexit
-import collections
 from collections.abc import Callable, Iterator, Sequence
 import contextlib
 import dataclasses
@@ -38,7 +37,6 @@ from jax._src import dtypes
 from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import util
-from jax._src import xla_bridge as xb
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.abstract_arrays import array_types
@@ -324,10 +322,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
       raise FloatingPointError(f"invalid value (inf) encountered in {name}")


-def _override_get_device_assignment(sharding, *args, **kwargs):
-  da = sharding._device_assignment
-  return xb.get_device_backend(da[0]), da
-
 def _identity_fn(x):
   return x

@@ -361,7 +355,7 @@ def _different_device_order_reshard(x, target_sharding):
   new_op_sharding.iota_reshape_dims = []
   new_op_sharding.iota_transpose_perm = []
   new_op_sharding.tile_assignment_devices = np.take(
-      old_hlo_sharding.tile_assignment_devices(), permute_order
+      permute_order, old_hlo_sharding.tile_assignment_devices()
   )
   new_hlo_sharding = xc.HloSharding.from_proto(new_op_sharding)
   # TODO(yashkatariya): Enable this when HloSharding conversion is fixed in
@@ -370,40 +364,18 @@ def _different_device_order_reshard(x, target_sharding):
   #         == new_hlo_sharding.tile_assignment_dimensions())
   # assert (new_op_sharding.tile_assignment_devices
   #         == new_hlo_sharding.tile_assignment_devices())
+  assert (list(np.take(inp_sharding._device_assignment,
+                       old_hlo_sharding.tile_assignment_devices()))
+          == list(np.take(target_sharding._device_assignment,
+                          new_op_sharding.tile_assignment_devices)))

-  new_sharding = GSPMDSharding(
-      target_sharding._device_assignment, new_hlo_sharding,
-      memory_kind=target_sharding.memory_kind)
-
-  old_device_to_index_buffer = collections.defaultdict()
-  old_index_to_buffer = collections.defaultdict()
-  for s in x.addressable_shards:
-    old_index_to_buffer[array.hashed_index(s.index)] = s.data
-    old_device_to_index_buffer[s.device] = (s.index, s.data)
-
-  new_arrays = []
-  for new_d, new_index in new_sharding.addressable_devices_indices_map(x.shape).items():
-    old_index, old_buf = old_device_to_index_buffer[new_d]
-    if old_index == new_index:
-      assert array._get_device(old_buf) == new_d, (
-          array._get_device(old_buf), new_d)
-      new_arrays.append(old_buf)
-    else:
-      old_buf = old_index_to_buffer[array.hashed_index(new_index)]
-      new_arrays.append(
-          pxla.batched_device_put(old_buf.aval, SingleDeviceSharding(new_d),
-                                  [old_buf], [new_d]))
-
-  new_x = array.ArrayImpl(
-      x.aval, new_sharding, new_arrays, committed=True, _skip_checks=True)
-
-  _orig_get_and_check_device_assignment = pxla._get_and_check_device_assignment.fn
-  pxla._get_and_check_device_assignment.fn = partial(
-      _override_get_device_assignment, target_sharding)
-  try:
-    return api.jit(_identity_fn, out_shardings=target_sharding)(new_x)
-  finally:
-    pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment
+  new_x = array.make_array_from_single_device_arrays(
+      x.shape,
+      GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding,
+                    memory_kind=target_sharding.memory_kind),
+      x._arrays,
+  )
+  return api.jit(_identity_fn, out_shardings=target_sharding)(new_x)


 @dataclasses.dataclass(frozen=True)
@@ -25,7 +25,6 @@ from functools import partial, lru_cache, cached_property
 import itertools as it
 import logging
 import math
-import threading
 from typing import Any, NamedTuple, TypeVar, Union, cast
 import warnings

@@ -1740,16 +1739,6 @@ def _get_default_device() -> xc.Device:
   return config.default_device.value or xb.local_devices()[0]


-class _thread_local_decorator(threading.local):
-
-  def __init__(self, fn):
-    self.fn = fn
-
-  def __call__(self, *args, **kwargs):
-    return self.fn(*args, **kwargs)
-
-
-@_thread_local_decorator
 def _get_and_check_device_assignment(
     shardings: Iterable[ShardingInfo],
     devices: Sequence[xc.Device] | None,
@@ -4265,10 +4265,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
    s2 = NamedSharding(mesh2, P())

    x_s1 = jax.device_put(np_inp, s1)

    # Reshard!
    out = jax.device_put(x_s1, s2)
    self.assertArraysEqual(out, np_inp)
    self.assertEqual(out.sharding, s2)
    del out

    s3 = NamedSharding(mesh2, P('model_q'))
    x_s3 = jax.device_put(np_inp, s3)
    # Reshard to iota device assignment!
    out2 = jax.device_put(x_s3, s1)
    self.assertArraysEqual(out2, np_inp)
    self.assertEqual(out2.sharding, s1)

  def test_convert_element_type_sharding(self):
    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
