diff --git a/CHANGELOG.md b/CHANGELOG.md index e8f8d7c6d..33d4a2e64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. that directly accesses shards accordingly. The rank of the per-shard-shape now matches that of the global shape which is the same behavior as jit. This avoids costly reshapes when passing results from pmap into jit. + * `jax.experimental.host_callback` has been deprecated since March 2024, with + JAX version 0.4.26. Now we set the default value of the + `--jax_host_callback_legacy` configuration value to `True`, which means that + if your code uses `jax.experimental.host_callback` APIs, those API calls + will be implemented in terms of the new `jax.experimental.io_callback` API. + If this breaks your code, for a very limited time, you can set the + `--jax_host_callback_legacy` to `True`. Soon we will remove that + configuration option, so you should instead transition to using the + new JAX callback APIs. See {jax-issue}`#20385` for a discussion. * Deprecations * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 49162809a..bc5477ebc 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -568,10 +568,10 @@ _HOST_CALLBACK_OUTFEED = config.bool_flag( ) _HOST_CALLBACK_LEGACY = config.bool_flag( 'jax_host_callback_legacy', - config.bool_env('JAX_HOST_CALLBACK_LEGACY', True), + config.bool_env('JAX_HOST_CALLBACK_LEGACY', False), help=( 'Use old implementation of host_callback, documented in the module docstring.' - 'If False, use the jax.experimental.io_callback implementation. ' + 'If False, use the new jax.experimental.io_callback implementation. ' 'See https://github.com/jax-ml/jax/issues/20385.' ) ) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 244fe6ef4..5e624ef9a 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -349,6 +349,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): ( 6.00 9.00 )""") def test_tap_eval_exception(self): + self.supported_only_in_legacy_mode() if not hcb._HOST_CALLBACK_OUTFEED.value: raise SkipTest("TODO: implement error handling for customcall") @@ -852,6 +853,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): self.assertEqual(100, count) def test_tap_jit_tap_exception(self): + self.supported_only_in_legacy_mode() if not hcb._HOST_CALLBACK_OUTFEED.value: raise SkipTest("TODO: implement error handling for customcall") # Simulate a tap error