Remove option to use StreamExecutor Cloud TPU client in JAX

It's been over three months since the new PJRT C API client was
enabled by default
(https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023).

PiperOrigin-RevId: 554935166
This commit is contained in:
Skye Wanderman-Milne 2023-08-08 14:04:34 -07:00 committed by jax authors
parent f05f197874
commit 3e50fea29e
2 changed files with 7 additions and 8 deletions

View File

@ -24,6 +24,7 @@ Remember to align the itemized text with the first line of an item within a list
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
* `jax.jaxpr_util` has been removed from the public JAX namespace.
* `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
* Internal deprecations:
* The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype`

View File

@ -69,12 +69,10 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'
if 'JAX_USE_PJRT_C_API_ON_TPU' not in os.environ:
os.environ['JAX_USE_PJRT_C_API_ON_TPU'] = 'true'
use_pjrt_c_api = os.environ['JAX_USE_PJRT_C_API_ON_TPU']
if use_pjrt_c_api in ("false", "0"):
# TODO(skyewm): remove this warning at some point, say around Sept 2023.
use_pjrt_c_api = os.environ.get('JAX_USE_PJRT_C_API_ON_TPU', None)
if use_pjrt_c_api:
warnings.warn(
f"JAX_USE_PJRT_C_API_ON_TPU={use_pjrt_c_api} will no longer be "
"supported in an upcoming future release. Please file an issue at "
"https://github.com/google/jax/issues if you need this setting.")
"JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (the new TPU "
"runtime is always enabled now). Unset the environment variable "
"to disable this warning.")