diff --git a/CHANGELOG.md b/CHANGELOG.md index 527e80d3c..d0a49ef6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ Remember to align the itemized text with the first line of an item within a list `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback). * jax2tf default serialization version is now 7, which introduces new shape [safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). + * Devices passed to `jax.sharding.Mesh` should be hashable. This specifically + applies to mock devices or user-created devices. The devices returned by + `jax.devices()` are already hashable. * Breaking changes: * jax2tf now uses native serialization by default. See diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 9a38f1af6..ea8d66c14 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -90,6 +90,35 @@ class ResourceEnv(NamedTuple): return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})" +@functools.lru_cache(maxsize=128) +def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: + if global_mesh.empty: + return global_mesh + is_local_device = np.vectorize( + lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices) + subcube_indices = [] + # We take the smallest slice of each dimension that doesn't skip any local device. + for axis in range(global_mesh.devices.ndim): + other_axes = util.tuple_delete(tuple(range(global_mesh.devices.ndim)), axis) + # NOTE: This re-reduces over many axes multiple times, so we could definitely + # optimize it, but I hope it won't be a bottleneck anytime soon. 
+ local_slices = is_local_device.any(other_axes, keepdims=False) + nonzero_indices = np.flatnonzero(local_slices) + start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices)) + subcube_indices.append(slice(start, end + 1)) + subcube_indices = tuple(subcube_indices) + # We only end up with all conditions being true if the local devices formed a + # subcube of the full array. This is because we were biased towards taking a + # "hull" spanned by the devices, and in case the local devices don't form a + # subcube that hull will contain non-local devices. + if not is_local_device[subcube_indices].all(): + raise ValueError( + "When passing host local inputs to pjit or xmap, devices " + "connected to a single host must form a contiguous subcube of the " + "global device mesh") + return Mesh(global_mesh.devices[subcube_indices], global_mesh.axis_names) + + _mesh_object_dict = {} # type: ignore @@ -156,28 +185,16 @@ class Mesh(contextlib.ContextDecorator): axis_names = tuple(axis_names) assert devices.ndim == len(axis_names) - flat_devices = tuple(devices.flat) - - # TODO(yashkatariya): Make Mock Devices hashable and them remove this - # workaround - _use_cache = True - try: - hash(flat_devices[0]) - except: - _use_cache = False - - if _use_cache: - key = (axis_names, devices.shape, flat_devices) - val = _mesh_object_dict.get(key, None) - if val is not None: - return val + key = (axis_names, devices.shape, tuple(devices.flat)) + val = _mesh_object_dict.get(key, None) + if val is not None: + return val self = super(Mesh, cls).__new__(cls) self.devices = devices.copy() self.devices.flags.writeable = False self.axis_names = axis_names - if _use_cache: - _mesh_object_dict[key] = self + _mesh_object_dict[key] = self return self def __reduce__(self): @@ -248,36 +265,12 @@ class Mesh(contextlib.ContextDecorator): def is_multi_process(self): return self.devices.size != len(self.local_devices) - @functools.cached_property + @property def local_mesh(self): return 
self._local_mesh(xb.process_index()) def _local_mesh(self, process_index): - if self.empty: - return self - is_local_device = np.vectorize( - lambda d: d.process_index == process_index, otypes=[bool])(self.devices) - subcube_indices = [] - # We take the smallest slice of each dimension that doesn't skip any local device. - for axis in range(self.devices.ndim): - other_axes = util.tuple_delete(tuple(range(self.devices.ndim)), axis) - # NOTE: This re-reduces over many axes multiple times, so we could definitely - # optimize it, but I hope it won't be a bottleneck anytime soon. - local_slices = is_local_device.any(other_axes, keepdims=False) - nonzero_indices = np.flatnonzero(local_slices) - start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices)) - subcube_indices.append(slice(start, end + 1)) - subcube_indices = tuple(subcube_indices) - # We only end up with all conditions being true if the local devices formed a - # subcube of the full array. This is because we were biased towards taking a - # "hull" spanned by the devices, and in case the local devices don't form a - # subcube that hull will contain non-local devices. - if not is_local_device[subcube_indices].all(): - raise ValueError( - "When passing host local inputs to pjit or xmap, devices " - "connected to a single host must form a contiguous subcube of the " - "global device mesh") - return Mesh(self.devices[subcube_indices], self.axis_names) + return _get_local_mesh(self, process_index) @functools.cached_property def device_ids(self): diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index b8caa5aeb..48504ed71 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -28,7 +28,7 @@ from jax.sharding import Mesh from jax._src import test_util -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class MockTpuDevice: """Mock TPU device for testing.""" id: int