mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Abstract the array_serialization error message to a global variable so that it can be overridden.
PiperOrigin-RevId: 561224461
This commit is contained in:
parent
e785f89470
commit
6b574708ee
@ -41,6 +41,9 @@ _REMOVED_VALUE = 'Value removed'
|
||||
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
|
||||
_module_unique_count = itertools.count()
|
||||
_DEFAULT_DRIVER = 'file'
|
||||
_DISTRIBUTED_SYSTEM_MSG = (
|
||||
'Please initialize the distributed system via '
|
||||
'`jax.distributed.initialize()` at the start of your program.')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -392,9 +395,7 @@ class AsyncManager:
|
||||
self._exception = None
|
||||
|
||||
if jax.process_count() > 1 and distributed.global_state.client is None:
|
||||
raise ValueError('Please initialize the distributed system via '
|
||||
'`jax.distributed.initialize()` at the start of your '
|
||||
'program.')
|
||||
raise ValueError(_DISTRIBUTED_SYSTEM_MSG)
|
||||
if jax.process_count() > 1:
|
||||
self._client = distributed.global_state.client
|
||||
self._count = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user