mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Any devices passed to jax.sharding.Mesh are required to be hashable.
This is true for mock devices or user specific devices and jax.devices() too. Fix the tests so that the mock devices are hashable. PiperOrigin-RevId: 561103167
This commit is contained in:
parent
ff5b480c6b
commit
6072d5993e
@ -24,6 +24,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
|
||||||
* jax2tf default serialization version is now 7, which introduces new shape
|
* jax2tf default serialization version is now 7, which introduces new shape
|
||||||
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
|
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
|
||||||
|
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
|
||||||
|
applies to mock devices or user created devices. `jax.devices()` are
|
||||||
|
already hashable.
|
||||||
|
|
||||||
* Breaking changes:
|
* Breaking changes:
|
||||||
* jax2tf now uses native serialization by default. See
|
* jax2tf now uses native serialization by default. See
|
||||||
|
@ -90,6 +90,35 @@ class ResourceEnv(NamedTuple):
|
|||||||
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"
|
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=128)
|
||||||
|
def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
|
||||||
|
if global_mesh.empty:
|
||||||
|
return global_mesh
|
||||||
|
is_local_device = np.vectorize(
|
||||||
|
lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices)
|
||||||
|
subcube_indices = []
|
||||||
|
# We take the smallest slice of each dimension that doesn't skip any local device.
|
||||||
|
for axis in range(global_mesh.devices.ndim):
|
||||||
|
other_axes = util.tuple_delete(tuple(range(global_mesh.devices.ndim)), axis)
|
||||||
|
# NOTE: This re-reduces over many axes multiple times, so we could definitely
|
||||||
|
# optimize it, but I hope it won't be a bottleneck anytime soon.
|
||||||
|
local_slices = is_local_device.any(other_axes, keepdims=False)
|
||||||
|
nonzero_indices = np.flatnonzero(local_slices)
|
||||||
|
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
|
||||||
|
subcube_indices.append(slice(start, end + 1))
|
||||||
|
subcube_indices = tuple(subcube_indices)
|
||||||
|
# We only end up with all conditions being true if the local devices formed a
|
||||||
|
# subcube of the full array. This is because we were biased towards taking a
|
||||||
|
# "hull" spanned by the devices, and in case the local devices don't form a
|
||||||
|
# subcube that hull will contain non-local devices.
|
||||||
|
if not is_local_device[subcube_indices].all():
|
||||||
|
raise ValueError(
|
||||||
|
"When passing host local inputs to pjit or xmap, devices "
|
||||||
|
"connected to a single host must form a contiguous subcube of the "
|
||||||
|
"global device mesh")
|
||||||
|
return Mesh(global_mesh.devices[subcube_indices], global_mesh.axis_names)
|
||||||
|
|
||||||
|
|
||||||
_mesh_object_dict = {} # type: ignore
|
_mesh_object_dict = {} # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@ -156,28 +185,16 @@ class Mesh(contextlib.ContextDecorator):
|
|||||||
axis_names = tuple(axis_names)
|
axis_names = tuple(axis_names)
|
||||||
assert devices.ndim == len(axis_names)
|
assert devices.ndim == len(axis_names)
|
||||||
|
|
||||||
flat_devices = tuple(devices.flat)
|
key = (axis_names, devices.shape, tuple(devices.flat))
|
||||||
|
val = _mesh_object_dict.get(key, None)
|
||||||
# TODO(yashkatariya): Make Mock Devices hashable and them remove this
|
if val is not None:
|
||||||
# workaround
|
return val
|
||||||
_use_cache = True
|
|
||||||
try:
|
|
||||||
hash(flat_devices[0])
|
|
||||||
except:
|
|
||||||
_use_cache = False
|
|
||||||
|
|
||||||
if _use_cache:
|
|
||||||
key = (axis_names, devices.shape, flat_devices)
|
|
||||||
val = _mesh_object_dict.get(key, None)
|
|
||||||
if val is not None:
|
|
||||||
return val
|
|
||||||
|
|
||||||
self = super(Mesh, cls).__new__(cls)
|
self = super(Mesh, cls).__new__(cls)
|
||||||
self.devices = devices.copy()
|
self.devices = devices.copy()
|
||||||
self.devices.flags.writeable = False
|
self.devices.flags.writeable = False
|
||||||
self.axis_names = axis_names
|
self.axis_names = axis_names
|
||||||
if _use_cache:
|
_mesh_object_dict[key] = self
|
||||||
_mesh_object_dict[key] = self
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
@ -248,36 +265,12 @@ class Mesh(contextlib.ContextDecorator):
|
|||||||
def is_multi_process(self):
|
def is_multi_process(self):
|
||||||
return self.devices.size != len(self.local_devices)
|
return self.devices.size != len(self.local_devices)
|
||||||
|
|
||||||
@functools.cached_property
|
@property
|
||||||
def local_mesh(self):
|
def local_mesh(self):
|
||||||
return self._local_mesh(xb.process_index())
|
return self._local_mesh(xb.process_index())
|
||||||
|
|
||||||
def _local_mesh(self, process_index):
|
def _local_mesh(self, process_index):
|
||||||
if self.empty:
|
return _get_local_mesh(self, process_index)
|
||||||
return self
|
|
||||||
is_local_device = np.vectorize(
|
|
||||||
lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
|
|
||||||
subcube_indices = []
|
|
||||||
# We take the smallest slice of each dimension that doesn't skip any local device.
|
|
||||||
for axis in range(self.devices.ndim):
|
|
||||||
other_axes = util.tuple_delete(tuple(range(self.devices.ndim)), axis)
|
|
||||||
# NOTE: This re-reduces over many axes multiple times, so we could definitely
|
|
||||||
# optimize it, but I hope it won't be a bottleneck anytime soon.
|
|
||||||
local_slices = is_local_device.any(other_axes, keepdims=False)
|
|
||||||
nonzero_indices = np.flatnonzero(local_slices)
|
|
||||||
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
|
|
||||||
subcube_indices.append(slice(start, end + 1))
|
|
||||||
subcube_indices = tuple(subcube_indices)
|
|
||||||
# We only end up with all conditions being true if the local devices formed a
|
|
||||||
# subcube of the full array. This is because we were biased towards taking a
|
|
||||||
# "hull" spanned by the devices, and in case the local devices don't form a
|
|
||||||
# subcube that hull will contain non-local devices.
|
|
||||||
if not is_local_device[subcube_indices].all():
|
|
||||||
raise ValueError(
|
|
||||||
"When passing host local inputs to pjit or xmap, devices "
|
|
||||||
"connected to a single host must form a contiguous subcube of the "
|
|
||||||
"global device mesh")
|
|
||||||
return Mesh(self.devices[subcube_indices], self.axis_names)
|
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def device_ids(self):
|
def device_ids(self):
|
||||||
|
@ -28,7 +28,7 @@ from jax.sharding import Mesh
|
|||||||
from jax._src import test_util
|
from jax._src import test_util
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass(frozen=True)
|
||||||
class MockTpuDevice:
|
class MockTpuDevice:
|
||||||
"""Mock TPU device for testing."""
|
"""Mock TPU device for testing."""
|
||||||
id: int
|
id: int
|
||||||
|
Loading…
x
Reference in New Issue
Block a user