Enable use of GlobalDeviceArray (GDA) in T5X Checkpointer. Add a separate unit test, gda_checkpoints_test, to cover this use case.

GDA is locked behind a `use_gda` bool in Checkpointer. The feature is currently not enabled anywhere.

Our follow-up plan is to add code which would enable GDA use throughout T5X, and to fix any remaining issues with Checkpointer.

PiperOrigin-RevId: 439358913
This commit is contained in:
Colin Gaffney 2022-04-04 10:55:31 -07:00 committed by jax authors
parent 1b8be90801
commit 41b6e00141
2 changed files with 53 additions and 1 deletions

View File

@ -64,6 +64,13 @@ def _get_metadata(gda):
}
def _spec_has_metadata(tree):
if not isinstance(tree, dict):
return False
return 'metadata' in tree or any(
_spec_has_metadata(subtree) for _, subtree in tree.items())
def get_tensorstore_spec(ckpt_path: str):
spec = {'driver': 'zarr', 'kvstore': {}}
@ -82,7 +89,9 @@ def get_tensorstore_spec(ckpt_path: str):
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
if not tensorstore_spec.get('metadata'):
# 'metadata' may not be present at the top level (for example, if we are using
# a 'cast' driver).
if not _spec_has_metadata(tensorstore_spec):
tensorstore_spec['metadata'] = _get_metadata(gda_inp)
t = await ts.open(

View File

@ -139,5 +139,48 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.local_shards:
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
def test_spec_has_metadata(self):
  """'metadata' is detected both in a nested dict and at the top level."""
  # 'metadata' is only present one level down, under 'e'.
  nested_spec = {
      'a': {'b': 1, 'c': 2},
      'd': 3,
      'e': {'a': 2, 'metadata': 3},
      'f': 4,
  }
  self.assertTrue(serialization._spec_has_metadata(nested_spec))

  # 'metadata' is a direct key of the spec.
  top_level_spec = {
      'driver': 'zarr',
      'kvstore': 'gfile',
      'metadata': {'chunks': 4, 'shape': (32, 64)},
      'one_more': 'thing',
  }
  self.assertTrue(serialization._spec_has_metadata(top_level_spec))
def test_spec_has_no_metadata(self):
  """A nested spec without any 'metadata' key is reported as having none."""
  spec_without_metadata = {
      'a': {'b': 1, 'c': 2},
      'd': 3,
      'e': {'a': 2},
      'f': 4,
  }
  found = serialization._spec_has_metadata(spec_without_metadata)
  self.assertFalse(found)
def test_empty_spec_has_no_metadata(self):
  """An empty spec has no keys at all, so no 'metadata' can be found."""
  self.assertFalse(serialization._spec_has_metadata({}))
# When this file is executed as a script, run its tests through absltest,
# passing a jtu.JaxTestLoader instance as the test loader.
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())