mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Any devices passed to jax.sharding.Mesh are required to be hashable.
This is true for mock devices or user specific devices and jax.devices() too. Fix the tests so that the mock devices are hashable. PiperOrigin-RevId: 561103167
This commit is contained in:
parent
ff5b480c6b
commit
6072d5993e
@ -24,6 +24,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
||||
* jax2tf default serialization version is now 7, which introduces new shape
|
||||
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
|
||||
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
|
||||
applies to mock devices or user created devices. `jax.devices()` are
|
||||
already hashable.
|
||||
|
||||
* Breaking changes:
|
||||
* jax2tf now uses native serialization by default. See
|
||||
|
@ -90,6 +90,35 @@ class ResourceEnv(NamedTuple):
|
||||
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
|
||||
if global_mesh.empty:
|
||||
return global_mesh
|
||||
is_local_device = np.vectorize(
|
||||
lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices)
|
||||
subcube_indices = []
|
||||
# We take the smallest slice of each dimension that doesn't skip any local device.
|
||||
for axis in range(global_mesh.devices.ndim):
|
||||
other_axes = util.tuple_delete(tuple(range(global_mesh.devices.ndim)), axis)
|
||||
# NOTE: This re-reduces over many axes multiple times, so we could definitely
|
||||
# optimize it, but I hope it won't be a bottleneck anytime soon.
|
||||
local_slices = is_local_device.any(other_axes, keepdims=False)
|
||||
nonzero_indices = np.flatnonzero(local_slices)
|
||||
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
|
||||
subcube_indices.append(slice(start, end + 1))
|
||||
subcube_indices = tuple(subcube_indices)
|
||||
# We only end up with all conditions being true if the local devices formed a
|
||||
# subcube of the full array. This is because we were biased towards taking a
|
||||
# "hull" spanned by the devices, and in case the local devices don't form a
|
||||
# subcube that hull will contain non-local devices.
|
||||
if not is_local_device[subcube_indices].all():
|
||||
raise ValueError(
|
||||
"When passing host local inputs to pjit or xmap, devices "
|
||||
"connected to a single host must form a contiguous subcube of the "
|
||||
"global device mesh")
|
||||
return Mesh(global_mesh.devices[subcube_indices], global_mesh.axis_names)
|
||||
|
||||
|
||||
_mesh_object_dict = {} # type: ignore
|
||||
|
||||
|
||||
@ -156,18 +185,7 @@ class Mesh(contextlib.ContextDecorator):
|
||||
axis_names = tuple(axis_names)
|
||||
assert devices.ndim == len(axis_names)
|
||||
|
||||
flat_devices = tuple(devices.flat)
|
||||
|
||||
# TODO(yashkatariya): Make Mock Devices hashable and them remove this
|
||||
# workaround
|
||||
_use_cache = True
|
||||
try:
|
||||
hash(flat_devices[0])
|
||||
except:
|
||||
_use_cache = False
|
||||
|
||||
if _use_cache:
|
||||
key = (axis_names, devices.shape, flat_devices)
|
||||
key = (axis_names, devices.shape, tuple(devices.flat))
|
||||
val = _mesh_object_dict.get(key, None)
|
||||
if val is not None:
|
||||
return val
|
||||
@ -176,7 +194,6 @@ class Mesh(contextlib.ContextDecorator):
|
||||
self.devices = devices.copy()
|
||||
self.devices.flags.writeable = False
|
||||
self.axis_names = axis_names
|
||||
if _use_cache:
|
||||
_mesh_object_dict[key] = self
|
||||
return self
|
||||
|
||||
@ -248,36 +265,12 @@ class Mesh(contextlib.ContextDecorator):
|
||||
def is_multi_process(self):
|
||||
return self.devices.size != len(self.local_devices)
|
||||
|
||||
@functools.cached_property
|
||||
@property
|
||||
def local_mesh(self):
|
||||
return self._local_mesh(xb.process_index())
|
||||
|
||||
def _local_mesh(self, process_index):
|
||||
if self.empty:
|
||||
return self
|
||||
is_local_device = np.vectorize(
|
||||
lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
|
||||
subcube_indices = []
|
||||
# We take the smallest slice of each dimension that doesn't skip any local device.
|
||||
for axis in range(self.devices.ndim):
|
||||
other_axes = util.tuple_delete(tuple(range(self.devices.ndim)), axis)
|
||||
# NOTE: This re-reduces over many axes multiple times, so we could definitely
|
||||
# optimize it, but I hope it won't be a bottleneck anytime soon.
|
||||
local_slices = is_local_device.any(other_axes, keepdims=False)
|
||||
nonzero_indices = np.flatnonzero(local_slices)
|
||||
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
|
||||
subcube_indices.append(slice(start, end + 1))
|
||||
subcube_indices = tuple(subcube_indices)
|
||||
# We only end up with all conditions being true if the local devices formed a
|
||||
# subcube of the full array. This is because we were biased towards taking a
|
||||
# "hull" spanned by the devices, and in case the local devices don't form a
|
||||
# subcube that hull will contain non-local devices.
|
||||
if not is_local_device[subcube_indices].all():
|
||||
raise ValueError(
|
||||
"When passing host local inputs to pjit or xmap, devices "
|
||||
"connected to a single host must form a contiguous subcube of the "
|
||||
"global device mesh")
|
||||
return Mesh(self.devices[subcube_indices], self.axis_names)
|
||||
return _get_local_mesh(self, process_index)
|
||||
|
||||
@functools.cached_property
|
||||
def device_ids(self):
|
||||
|
@ -28,7 +28,7 @@ from jax.sharding import Mesh
|
||||
from jax._src import test_util
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class MockTpuDevice:
|
||||
"""Mock TPU device for testing."""
|
||||
id: int
|
||||
|
Loading…
x
Reference in New Issue
Block a user