diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ea1ade65..bd721d66f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ Remember to align the itemized text with the first line of an item within a list * The option `--jax_coordination_service` has been removed. It is now always `True`. * `jax.jaxpr_util` has been removed from the public JAX namespace. + * `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true). * Internal deprecations: * The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype` diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 7ecd13308..e1cb601b1 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -69,12 +69,10 @@ def cloud_tpu_init() -> None: os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' - if 'JAX_USE_PJRT_C_API_ON_TPU' not in os.environ: - os.environ['JAX_USE_PJRT_C_API_ON_TPU'] = 'true' - - use_pjrt_c_api = os.environ['JAX_USE_PJRT_C_API_ON_TPU'] - if use_pjrt_c_api in ("false", "0"): + # TODO(skyewm): remove this warning at some point, say around Sept 2023. + use_pjrt_c_api = os.environ.get('JAX_USE_PJRT_C_API_ON_TPU', None) + if use_pjrt_c_api: warnings.warn( - f"JAX_USE_PJRT_C_API_ON_TPU={use_pjrt_c_api} will no longer be " - "supported in an upcoming future release. Please file an issue at " - "https://github.com/google/jax/issues if you need this setting.") + "JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (the new TPU " + "runtime is always enabled now). Unset the environment variable " + "to disable this warning.")