[host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False

`jax.experimental.host_callback` has been deprecated since March 2024
 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.

If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 681004255
This commit is contained in:
George Necula 2024-10-01 07:06:48 -07:00 committed by jax authors
parent 0cfed4efad
commit 2228115cf4
3 changed files with 13 additions and 2 deletions

View File

@ -27,6 +27,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
that directly accesses shards accordingly. The rank of the per-shard-shape that directly accesses shards accordingly. The rank of the per-shard-shape
now matches that of the global shape which is the same behavior as jit. now matches that of the global shape which is the same behavior as jit.
This avoids costly reshapes when passing results from pmap into jit. This avoids costly reshapes when passing results from pmap into jit.
* `jax.experimental.host_callback` has been deprecated since March 2024, with
JAX version 0.4.26. Now we set the default value of the
`--jax_host_callback_legacy` configuration value to `True`, which means that
if your code uses `jax.experimental.host_callback` APIs, those API calls
will be implemented in terms of the new `jax.experimental.io_callback` API.
If this breaks your code, for a very limited time, you can set the
`--jax_host_callback_legacy` to `True`. Soon we will remove that
configuration option, so you should instead transition to using the
new JAX callback APIs. See {jax-issue}`#20385` for a discussion.
* Deprecations * Deprecations
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike

View File

@ -568,10 +568,10 @@ _HOST_CALLBACK_OUTFEED = config.bool_flag(
) )
_HOST_CALLBACK_LEGACY = config.bool_flag( _HOST_CALLBACK_LEGACY = config.bool_flag(
'jax_host_callback_legacy', 'jax_host_callback_legacy',
config.bool_env('JAX_HOST_CALLBACK_LEGACY', True), config.bool_env('JAX_HOST_CALLBACK_LEGACY', False),
help=( help=(
'Use old implementation of host_callback, documented in the module docstring.' 'Use old implementation of host_callback, documented in the module docstring.'
'If False, use the jax.experimental.io_callback implementation. ' 'If False, use the new jax.experimental.io_callback implementation. '
'See https://github.com/jax-ml/jax/issues/20385.' 'See https://github.com/jax-ml/jax/issues/20385.'
) )
) )

View File

@ -349,6 +349,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( 6.00 9.00 )""") ( 6.00 9.00 )""")
def test_tap_eval_exception(self): def test_tap_eval_exception(self):
self.supported_only_in_legacy_mode()
if not hcb._HOST_CALLBACK_OUTFEED.value: if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall") raise SkipTest("TODO: implement error handling for customcall")
@ -852,6 +853,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(100, count) self.assertEqual(100, count)
def test_tap_jit_tap_exception(self): def test_tap_jit_tap_exception(self):
self.supported_only_in_legacy_mode()
if not hcb._HOST_CALLBACK_OUTFEED.value: if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall") raise SkipTest("TODO: implement error handling for customcall")
# Simulate a tap error # Simulate a tap error