[host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False

`jax.experimental.host_callback` has been deprecated since March 2024
 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.

If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 681004255
This commit is contained in:
George Necula 2024-10-01 07:06:48 -07:00 committed by jax authors
parent 0cfed4efad
commit 2228115cf4
3 changed files with 13 additions and 2 deletions

View File

@ -27,6 +27,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
that directly accesses shards accordingly. The rank of the per-shard-shape
now matches that of the global shape which is the same behavior as jit.
This avoids costly reshapes when passing results from pmap into jit.
* `jax.experimental.host_callback` has been deprecated since March 2024, with
JAX version 0.4.26. Now we set the default value of the
`--jax_host_callback_legacy` configuration value to `True`, which means that
if your code uses `jax.experimental.host_callback` APIs, those API calls
will be implemented in terms of the new `jax.experimental.io_callback` API.
If this breaks your code, for a very limited time, you can set the
`--jax_host_callback_legacy` to `True`. Soon we will remove that
configuration option, so you should instead transition to using the
new JAX callback APIs. See {jax-issue}`#20385` for a discussion.
* Deprecations
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike

View File

@ -568,10 +568,10 @@ _HOST_CALLBACK_OUTFEED = config.bool_flag(
)
_HOST_CALLBACK_LEGACY = config.bool_flag(
'jax_host_callback_legacy',
config.bool_env('JAX_HOST_CALLBACK_LEGACY', True),
config.bool_env('JAX_HOST_CALLBACK_LEGACY', False),
help=(
'Use old implementation of host_callback, documented in the module docstring.'
'If False, use the jax.experimental.io_callback implementation. '
'If False, use the new jax.experimental.io_callback implementation. '
'See https://github.com/jax-ml/jax/issues/20385.'
)
)

View File

@ -349,6 +349,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( 6.00 9.00 )""")
def test_tap_eval_exception(self):
self.supported_only_in_legacy_mode()
if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall")
@ -852,6 +853,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(100, count)
def test_tap_jit_tap_exception(self):
self.supported_only_in_legacy_mode()
if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall")
# Simulate a tap error