mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add gcs support to serialization/deserialization of GDA
PiperOrigin-RevId: 415356134
This commit is contained in:
parent
4c3b1d2f00
commit
bb04cf08db
@ -14,6 +14,7 @@
|
||||
"""GlobalDeviceArray serialization and deserialization."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
@ -58,13 +59,20 @@ def _get_metadata(gda):
|
||||
}
|
||||
|
||||
|
||||
def get_tensorstore_spec(ckpt_path):
|
||||
def get_tensorstore_spec(ckpt_path: str):
|
||||
spec = {'driver': 'zarr', 'kvstore': {}}
|
||||
# TODO(yashkatariya): Add GCS kvstore too.
|
||||
spec['kvstore'] = {
|
||||
'driver': 'file',
|
||||
'path': ckpt_path,
|
||||
}
|
||||
|
||||
if ckpt_path.startswith('gs://'):
|
||||
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
|
||||
if m is None:
|
||||
raise ValueError('The ckpt_path should contain the bucket name and the '
|
||||
f'file path inside the bucket. Got: {ckpt_path}')
|
||||
gcs_bucket = m.group(1)
|
||||
path_without_bucket = m.group(2)
|
||||
spec['kvstore'] = {'driver': 'gcs', 'bucket': gcs_bucket,
|
||||
'path': path_without_bucket}
|
||||
else:
|
||||
spec['kvstore'] = {'driver': 'file', 'path': ckpt_path}
|
||||
return spec
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user