mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False
`jax.experimental.host_callback` has been deprecated since March 2024 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API. If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs. See https://github.com/google/jax/issues/20385 for a discussion. PiperOrigin-RevId: 681004255
This commit is contained in:
parent
0cfed4efad
commit
2228115cf4
@ -27,6 +27,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
that directly accesses shards accordingly. The rank of the per-shard-shape
|
||||
now matches that of the global shape which is the same behavior as jit.
|
||||
This avoids costly reshapes when passing results from pmap into jit.
|
||||
* `jax.experimental.host_callback` has been deprecated since March 2024, with
|
||||
JAX version 0.4.26. Now we set the default value of the
|
||||
`--jax_host_callback_legacy` configuration value to `True`, which means that
|
||||
if your code uses `jax.experimental.host_callback` APIs, those API calls
|
||||
will be implemented in terms of the new `jax.experimental.io_callback` API.
|
||||
If this breaks your code, for a very limited time, you can set the
|
||||
`--jax_host_callback_legacy` to `True`. Soon we will remove that
|
||||
configuration option, so you should instead transition to using the
|
||||
new JAX callback APIs. See {jax-issue}`#20385` for a discussion.
|
||||
|
||||
* Deprecations
|
||||
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
|
||||
|
@ -568,10 +568,10 @@ _HOST_CALLBACK_OUTFEED = config.bool_flag(
|
||||
)
|
||||
_HOST_CALLBACK_LEGACY = config.bool_flag(
|
||||
'jax_host_callback_legacy',
|
||||
config.bool_env('JAX_HOST_CALLBACK_LEGACY', True),
|
||||
config.bool_env('JAX_HOST_CALLBACK_LEGACY', False),
|
||||
help=(
|
||||
'Use old implementation of host_callback, documented in the module docstring.'
|
||||
'If False, use the jax.experimental.io_callback implementation. '
|
||||
'If False, use the new jax.experimental.io_callback implementation. '
|
||||
'See https://github.com/jax-ml/jax/issues/20385.'
|
||||
)
|
||||
)
|
||||
|
@ -349,6 +349,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
( 6.00 9.00 )""")
|
||||
|
||||
def test_tap_eval_exception(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
if not hcb._HOST_CALLBACK_OUTFEED.value:
|
||||
raise SkipTest("TODO: implement error handling for customcall")
|
||||
|
||||
@ -852,6 +853,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(100, count)
|
||||
|
||||
def test_tap_jit_tap_exception(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
if not hcb._HOST_CALLBACK_OUTFEED.value:
|
||||
raise SkipTest("TODO: implement error handling for customcall")
|
||||
# Simulate a tap error
|
||||
|
Loading…
x
Reference in New Issue
Block a user