Set tensorstore environment defaults in cloud_tpu_init to make serialization more robust

This commit is contained in:
David Hall 2024-04-26 14:19:06 -07:00
parent 8c2425e571
commit 6a891f2cd9
2 changed files with 4 additions and 3 deletions

View File

@@ -68,3 +68,7 @@ def cloud_tpu_init() -> None:
os.environ['TPU_ML_PLATFORM'] = 'JAX'
if hardware_utils.tpu_enhanced_barrier_supported():
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
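# Default curl low-speed limits that force an HTTP retry on stalled transfers, making tensorstore serialization more robust on TPU; setdefault keeps any value already set by the user.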
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')

View File

@@ -63,9 +63,6 @@ _BARRIER_TIMED_OUT_MSG = (
"Suggestions for possible fixes:\n"
"* Check the logs to see if one or more processes failed.\n"
"* Make sure the training and checkpointing endpoints are close geographically.\n"
"* Try setting these environment variables: "
"`TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS=60` "
"`TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES=256` which will force a http retry\n"
"* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.")
logger = logging.getLogger(__name__)