mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix cache misses when re-creating equivalent mesh objects
The `Mesh` class was missing `__eq__` and `__hash__` and inherited the (bad) Python defaults of comparison and hashing by identity. PiperOrigin-RevId: 369407380
This commit is contained in:
parent
14acd070c2
commit
93c63d0341
@ -1254,16 +1254,33 @@ mesh devices ndarray would have to be transposed before flattening and assignmen
|
||||
# Type alias for a mapping keyed by mesh axis name with int values.
# NOTE(review): what the int denotes is not visible in this hunk, and
# OrderedDictType is presumably the typing alias for OrderedDict - confirm.
ArrayMapping = OrderedDictType[MeshAxisName, int]
|
||||
|
||||
class Mesh:
|
||||
__slots__ = ('devices', 'axis_names')
|
||||
__slots__ = ('devices', 'axis_names', '_hash')
|
||||
|
||||
def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
  """Create a mesh from an array of devices and one name per array axis.

  Args:
    devices: array of devices. A private copy is stored and marked
      read-only, so the caller's array is neither aliased nor modified.
    axis_names: one name per dimension of ``devices``; stored as a tuple.

  Raises:
    AssertionError: if ``len(axis_names)`` differs from ``devices.ndim``.
  """
  assert devices.ndim == len(axis_names)
  # TODO: Make sure that devices are unique? At least with the quick and
  # dirty check that the array size is not larger than the number of
  # available devices?
  # Each attribute is assigned exactly once. Storing the caller's array
  # first and the copy afterwards would both alias external state and count
  # as a reassignment of `devices`.
  self.devices = devices.copy()
  # Freeze the copy so the stored contents cannot change after construction.
  self.devices.flags.writeable = False
  self.axis_names = tuple(axis_names)
|
||||
|
||||
def __eq__(self, other):
  """Return True only for another Mesh with equal axis names and devices.

  Any object that is not a Mesh compares unequal. Devices are compared
  with np.array_equal, which is only evaluated when the axis names match.
  """
  if isinstance(other, Mesh):
    same_axes = self.axis_names == other.axis_names
    return same_axes and np.array_equal(self.devices, other.devices)
  return False
|
||||
|
||||
def __hash__(self):
  """Return a hash of the axis names and the flattened devices.

  The value is computed on first use and stored in ``_hash``; later calls
  return the stored value.
  """
  try:
    return self._hash
  except AttributeError:
    # First call: `_hash` has not been set yet, so compute and remember it.
    flat_devices = tuple(self.devices.flat)
    self._hash = hash((self.axis_names, flat_devices))
    return self._hash
|
||||
|
||||
def __setattr__(self, name, value):
  """Allow an attribute to be set once; refuse any later reassignment.

  Raises:
    RuntimeError: if ``hasattr(self, name)`` is already true.
  """
  already_present = hasattr(self, name)
  if not already_present:
    super().__setattr__(name, value)
    return
  raise RuntimeError("Cannot reassign attributes of immutable mesh objects")
|
||||
|
||||
@property
def shape(self):
  """Ordered mapping from each axis name to the matching devices dimension."""
  sizes_by_axis = OrderedDict()
  for axis_name, axis_size in safe_zip(self.axis_names, self.devices.shape):
    sizes_by_axis[axis_name] = axis_size
  return sizes_by_axis
|
||||
|
@ -206,6 +206,26 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
# Annotation from pjit
|
||||
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
|
||||
|
||||
def testCaching(self):
  # Checks that repeated pjit calls, including calls under a re-created but
  # equal mesh, do not run the Python body of `f` again.
  def f(x):
    # The flag is read from the enclosing scope each time this body runs,
    # so a call made while it is False fails if `f` is executed again.
    assert expect_tracing
    return jnp.sin(x) * 2

  local_devices = list(jax.local_devices())
  device_grid = np.array(local_devices[:4])
  if device_grid.size < 4:
    raise SkipTest("Test requires 4 devices")
  device_grid = device_grid.reshape((2, 2))
  x = np.arange(16).reshape(4, 4)

  with mesh(device_grid, ('x', 'y')):
    expect_tracing = True
    pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
    expect_tracing = False
    pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)

  # Re-create the mesh to make sure that has no influence on caching
  with mesh(device_grid, ('x', 'y')):
    expect_tracing = False
    pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
|
||||
|
||||
# TODO(skye): add more unit tests once API is more finalized
|
||||
|
||||
@curry
|
||||
|
@ -362,19 +362,25 @@ class XMapTest(XMapTestCase):
|
||||
self.assertAllClose(run({'i': 'x'}), run({'i': 'y'}))
|
||||
|
||||
# NOTE(review): this hunk appears to interleave the pre-change test
# (`testCompilationCache`, decorated with `@with_mesh` and calling a
# pre-built `fm`) with the post-change test (`testCaching`, which enters
# explicit `mesh(...)` contexts and builds `xmap(...)` inline). The diff's
# +/- markers are not visible here - confirm against the commit which lines
# are current before editing.
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2)])
|
||||
def testCompilationCache(self):
|
||||
def testCaching(self):
|
||||
def f(x):
|
||||
# `f` asserts this flag whenever its Python body runs, so a call made while
# the flag is False fails if `f` is executed again.
assert python_should_be_executing
|
||||
return x * 2
|
||||
fm = xmap(f,
|
||||
in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': 'x'})
|
||||
devices = np.array(jax.local_devices()[:2])
|
||||
if devices.size < 2:
|
||||
raise SkipTest("Test requires 2 devices")
|
||||
x = np.arange(8).reshape((2, 2, 2))
|
||||
python_should_be_executing = True
|
||||
fm(x)
|
||||
python_should_be_executing = False
|
||||
fm(x)
|
||||
with mesh(devices, ('x',)):
|
||||
python_should_be_executing = True
|
||||
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': 'x'})(x)
|
||||
python_should_be_executing = False
|
||||
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': 'x'})(x)
|
||||
# A second mesh is built from the same devices and axis names; the flag stays
# False for the call made inside it.
with mesh(devices, ('x',)):
|
||||
python_should_be_executing = False
|
||||
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': 'x'})(x)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
|
Loading…
x
Reference in New Issue
Block a user