Any devices passed to jax.sharding.Mesh are required to be hashable.

This is true for mock devices and user-specific devices, as well as for the devices returned by jax.devices().

Fix the tests so that the mock devices are hashable.

PiperOrigin-RevId: 561103167
This commit is contained in:
Yash Katariya 2023-08-29 12:17:37 -07:00 committed by jax authors
parent ff5b480c6b
commit 6072d5993e
3 changed files with 40 additions and 44 deletions

View File

@ -24,6 +24,9 @@ Remember to align the itemized text with the first line of an item within a list
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
  applies to mock devices or user-created devices. Devices returned by
  `jax.devices()` are already hashable.
* Breaking changes:
* jax2tf now uses native serialization by default. See

View File

@ -90,6 +90,35 @@ class ResourceEnv(NamedTuple):
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"
@functools.lru_cache(maxsize=128)
def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
  """Returns the sub-mesh of `global_mesh` made of devices on `process_index`.

  The local devices must form a contiguous subcube of the global device
  array; otherwise a ValueError is raised.
  """
  if global_mesh.empty:
    return global_mesh
  belongs_here = np.vectorize(
      lambda dev: dev.process_index == process_index,
      otypes=[bool])(global_mesh.devices)
  ndim = global_mesh.devices.ndim
  index_slices = []
  # For each dimension, take the tightest slice that still contains every
  # local device (the "hull" of local positions along that axis).
  for dim in range(ndim):
    reduce_axes = util.tuple_delete(tuple(range(ndim)), dim)
    # NOTE: this reduces over many axes repeatedly, so it could be optimized,
    # but it is not expected to be a bottleneck.
    has_local = belongs_here.any(reduce_axes, keepdims=False)
    present = np.flatnonzero(has_local)
    lo, hi = int(np.min(present)), int(np.max(present))
    index_slices.append(slice(lo, hi + 1))
  index_slices = tuple(index_slices)
  # The hull contains only local devices exactly when the local devices form
  # a subcube of the full array; a non-subcube set of local devices makes the
  # hull pick up devices belonging to other processes.
  if not belongs_here[index_slices].all():
    raise ValueError(
        "When passing host local inputs to pjit or xmap, devices "
        "connected to a single host must form a contiguous subcube of the "
        "global device mesh")
  return Mesh(global_mesh.devices[index_slices], global_mesh.axis_names)
_mesh_object_dict = {} # type: ignore
@ -156,28 +185,16 @@ class Mesh(contextlib.ContextDecorator):
axis_names = tuple(axis_names)
assert devices.ndim == len(axis_names)
flat_devices = tuple(devices.flat)
# TODO(yashkatariya): Make Mock Devices hashable and then remove this
# workaround
_use_cache = True
try:
hash(flat_devices[0])
except:
_use_cache = False
if _use_cache:
key = (axis_names, devices.shape, flat_devices)
val = _mesh_object_dict.get(key, None)
if val is not None:
return val
key = (axis_names, devices.shape, tuple(devices.flat))
val = _mesh_object_dict.get(key, None)
if val is not None:
return val
self = super(Mesh, cls).__new__(cls)
self.devices = devices.copy()
self.devices.flags.writeable = False
self.axis_names = axis_names
if _use_cache:
_mesh_object_dict[key] = self
_mesh_object_dict[key] = self
return self
def __reduce__(self):
@ -248,36 +265,12 @@ class Mesh(contextlib.ContextDecorator):
def is_multi_process(self):
  """True iff some devices in this mesh belong to other processes."""
  local_count = len(self.local_devices)
  return local_count != self.devices.size
@property
def local_mesh(self):
  """The sub-mesh of devices addressable by the current process.

  Delegates to `_local_mesh` with this process's index. A plain
  `@property` suffices here (rather than a cached property): the diff
  residue stacked both the old `@functools.cached_property` and the new
  `@property` decorators, which is invalid — only the new one is kept.
  """
  return self._local_mesh(xb.process_index())
def _local_mesh(self, process_index):
  """Returns the contiguous sub-mesh of devices local to `process_index`.

  Delegates to the module-level, lru_cache'd `_get_local_mesh` helper so
  that results are shared across equal Mesh instances. The diff residue
  contained both the old inline body and the new delegating return, which
  left the new `return` unreachable — only the delegation is kept.

  Raises:
    ValueError: if the local devices do not form a contiguous subcube of
      the global device mesh (raised inside `_get_local_mesh`).
  """
  return _get_local_mesh(self, process_index)
@functools.cached_property
def device_ids(self):

View File

@ -28,7 +28,7 @@ from jax.sharding import Mesh
from jax._src import test_util
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class MockTpuDevice:
"""Mock TPU device for testing."""
id: int