diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 4b89cf12d..e12080027 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -68,3 +68,7 @@ def cloud_tpu_init() -> None: os.environ['TPU_ML_PLATFORM'] = 'JAX' if hardware_utils.tpu_enhanced_barrier_supported(): os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" + + # Default the tensorstore curl low-speed settings, which force an HTTP retry on slow transfers and make tensorstore serialization work better on TPU; values already set by the user are kept. + os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') + os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256') diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 585fe02cd..e94ecaeaf 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -63,9 +63,6 @@ _BARRIER_TIMED_OUT_MSG = ( "Suggestions for possible fixes:\n" "* Check the logs to see if one or more processes failed.\n" "* Make sure the training and checkpointing endpoints are close geographically.\n" - "* Try setting these environment variables: " - "`TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS=60` " - "`TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES=256` which will force a http retry\n" "* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.") logger = logging.getLogger(__name__)