Abstract the array_serialization error message to a global variable so that it can be overridden.

PiperOrigin-RevId: 561224461
This commit is contained in:
Yash Katariya 2023-08-29 21:45:50 -07:00 committed by jax authors
parent e785f89470
commit 6b574708ee

View File

@ -41,6 +41,9 @@ _REMOVED_VALUE = 'Value removed'
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
_module_unique_count = itertools.count()
_DEFAULT_DRIVER = 'file'
_DISTRIBUTED_SYSTEM_MSG = (
'Please initialize the distributed system via '
'`jax.distributed.initialize()` at the start of your program.')
logger = logging.getLogger(__name__)
@ -392,9 +395,7 @@ class AsyncManager:
self._exception = None
if jax.process_count() > 1 and distributed.global_state.client is None:
raise ValueError('Please initialize the distributed system via '
'`jax.distributed.initialize()` at the start of your '
'program.')
raise ValueError(_DISTRIBUTED_SYSTEM_MSG)
if jax.process_count() > 1:
self._client = distributed.global_state.client
self._count = None