Add gcs support to serialization/deserialization of GDA

PiperOrigin-RevId: 415356134
This commit is contained in:
Yash Katariya 2021-12-09 14:24:47 -08:00 committed by jax authors
parent 4c3b1d2f00
commit bb04cf08db

View File

@ -14,6 +14,7 @@
"""GlobalDeviceArray serialization and deserialization."""
import asyncio
import re
from typing import Callable
import jax
@ -58,13 +59,20 @@ def _get_metadata(gda):
}
def get_tensorstore_spec(ckpt_path):
def get_tensorstore_spec(ckpt_path: str):
spec = {'driver': 'zarr', 'kvstore': {}}
# TODO(yashkatariya): Add GCS kvstore too.
spec['kvstore'] = {
'driver': 'file',
'path': ckpt_path,
}
if ckpt_path.startswith('gs://'):
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
if m is None:
raise ValueError('The ckpt_path should contain the bucket name and the '
f'file path inside the bucket. Got: {ckpt_path}')
gcs_bucket = m.group(1)
path_without_bucket = m.group(2)
spec['kvstore'] = {'driver': 'gcs', 'bucket': gcs_bucket,
'path': path_without_bucket}
else:
spec['kvstore'] = {'driver': 'file', 'path': ckpt_path}
return spec