mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Enable use of GlobalDeviceArray
(GDA) in T5X Checkpointer. Add a separate unit test, gda_checkpoints_test
, to cover this use case.
GDA is locked behind a `use_gda` bool in Checkpointer. The feature is currently not enabled anywhere. Our follow-up plan is to add code which would enable GDA use throughout T5X, and to fix any remaining issues with Checkpointer. PiperOrigin-RevId: 439358913
This commit is contained in:
parent
1b8be90801
commit
41b6e00141
@ -64,6 +64,13 @@ def _get_metadata(gda):
|
||||
}
|
||||
|
||||
|
||||
def _spec_has_metadata(tree):
|
||||
if not isinstance(tree, dict):
|
||||
return False
|
||||
return 'metadata' in tree or any(
|
||||
_spec_has_metadata(subtree) for _, subtree in tree.items())
|
||||
|
||||
|
||||
def get_tensorstore_spec(ckpt_path: str):
|
||||
spec = {'driver': 'zarr', 'kvstore': {}}
|
||||
|
||||
@ -82,7 +89,9 @@ def get_tensorstore_spec(ckpt_path: str):
|
||||
|
||||
|
||||
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
|
||||
if not tensorstore_spec.get('metadata'):
|
||||
# 'metadata' may not be present at the top level (for example, if we are using
|
||||
# a 'cast' driver).
|
||||
if not _spec_has_metadata(tensorstore_spec):
|
||||
tensorstore_spec['metadata'] = _get_metadata(gda_inp)
|
||||
|
||||
t = await ts.open(
|
||||
|
@ -139,5 +139,48 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
for l in m1.local_shards:
|
||||
self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
|
||||
|
||||
def test_spec_has_metadata(self):
|
||||
spec = {
|
||||
'a': {
|
||||
'b': 1,
|
||||
'c': 2,
|
||||
},
|
||||
'd': 3,
|
||||
'e': {
|
||||
'a': 2,
|
||||
'metadata': 3
|
||||
},
|
||||
'f': 4
|
||||
}
|
||||
self.assertTrue(serialization._spec_has_metadata(spec))
|
||||
self.assertTrue(
|
||||
serialization._spec_has_metadata({
|
||||
'driver': 'zarr',
|
||||
'kvstore': 'gfile',
|
||||
'metadata': {
|
||||
'chunks': 4,
|
||||
'shape': (32, 64)
|
||||
},
|
||||
'one_more': 'thing'
|
||||
}))
|
||||
|
||||
def test_spec_has_no_metadata(self):
|
||||
spec = {
|
||||
'a': {
|
||||
'b': 1,
|
||||
'c': 2,
|
||||
},
|
||||
'd': 3,
|
||||
'e': {
|
||||
'a': 2,
|
||||
},
|
||||
'f': 4
|
||||
}
|
||||
self.assertFalse(serialization._spec_has_metadata(spec))
|
||||
|
||||
def test_empty_spec_has_no_metadata(self):
|
||||
spec = {}
|
||||
self.assertFalse(serialization._spec_has_metadata(spec))
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user