Fix cache misses when re-creating equivalent mesh objects

The `Mesh` class was missing `__eq__` and `__hash__` and inherited the
(bad) Python defaults of comparison and hashing by identity.

PiperOrigin-RevId: 369407380
This commit is contained in:
Adam Paszke 2021-04-20 03:48:07 -07:00 committed by jax authors
parent 14acd070c2
commit 93c63d0341
3 changed files with 54 additions and 11 deletions

View File

@ -1254,16 +1254,33 @@ mesh devices ndarray would have to be transposed before flattening and assignmen
ArrayMapping = OrderedDictType[MeshAxisName, int]
class Mesh:
__slots__ = ('devices', 'axis_names')
__slots__ = ('devices', 'axis_names', '_hash')
def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
assert devices.ndim == len(axis_names)
# TODO: Make sure that devices are unique? At least with the quick and
# dirty check that the array size is not larger than the number of
# available devices?
self.devices = devices
self.devices = devices.copy()
self.devices.flags.writeable = False
self.axis_names = tuple(axis_names)
def __eq__(self, other):
if not isinstance(other, Mesh):
return False
return (self.axis_names == other.axis_names and
np.array_equal(self.devices, other.devices))
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self.axis_names, tuple(self.devices.flat)))
return self._hash
def __setattr__(self, name, value):
if hasattr(self, name):
raise RuntimeError("Cannot reassign attributes of immutable mesh objects")
super().__setattr__(name, value)
@property
def shape(self):
return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape))

View File

@ -206,6 +206,26 @@ class PJitTest(jtu.BufferDonationTestCase):
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
def testCaching(self):
def f(x):
assert should_be_tracing
return jnp.sin(x) * 2
x = np.arange(16).reshape(4, 4)
devices = np.array(list(jax.local_devices())[:4])
if devices.size < 4:
raise SkipTest("Test requires 4 devices")
devices = devices.reshape((2, 2))
with mesh(devices, ('x', 'y')):
should_be_tracing = True
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
should_be_tracing = False
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
# Re-create the mesh to make sure that has no influence on caching
with mesh(devices, ('x', 'y')):
should_be_tracing = False
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
# TODO(skye): add more unit tests once API is more finalized
@curry

View File

@ -362,19 +362,25 @@ class XMapTest(XMapTestCase):
self.assertAllClose(run({'i': 'x'}), run({'i': 'y'}))
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testCompilationCache(self):
def testCaching(self):
def f(x):
assert python_should_be_executing
return x * 2
fm = xmap(f,
in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})
devices = np.array(jax.local_devices()[:2])
if devices.size < 2:
raise SkipTest("Test requires 2 devices")
x = np.arange(8).reshape((2, 2, 2))
python_should_be_executing = True
fm(x)
python_should_be_executing = False
fm(x)
with mesh(devices, ('x',)):
python_should_be_executing = True
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})(x)
python_should_be_executing = False
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})(x)
with mesh(devices, ('x',)):
python_should_be_executing = False
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})(x)
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}