mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
set tensorstore settings in cloud_tpu that make serialization more robust
This commit is contained in:
parent
8c2425e571
commit
6a891f2cd9
@ -68,3 +68,7 @@ def cloud_tpu_init() -> None:
|
||||
os.environ['TPU_ML_PLATFORM'] = 'JAX'
|
||||
if hardware_utils.tpu_enhanced_barrier_supported():
|
||||
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
|
||||
|
||||
# this makes tensorstore serialization work better on TPU
|
||||
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
|
||||
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')
|
||||
|
@ -63,9 +63,6 @@ _BARRIER_TIMED_OUT_MSG = (
|
||||
"Suggestions for possible fixes:\n"
|
||||
"* Check the logs to see if one or more processes failed.\n"
|
||||
"* Make sure the training and checkpointing endpoints are close geographically.\n"
|
||||
"* Try setting these environment variables: "
|
||||
"`TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS=60` "
|
||||
"`TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES=256` which will force a http retry\n"
|
||||
"* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
Loading…
x
Reference in New Issue
Block a user