diff --git a/CHANGELOG.md b/CHANGELOG.md index c1e9fdc43..eb301ba45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. * {func}`jax.nonzero` has a new optional `size` argument that allows it to be used within `jit` ({jax-issue}`#6501`) * {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`). + * {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`). * Breaking changes: * The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed @@ -32,6 +33,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. * Bug fixes: * The {func}`jax2tf.convert` now works in presence of gradients for functions with integer inputs ({jax-issue}`#6360`). + * Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured + `tf.Variable` ({jax-issue}`#6572`). ## jaxlib 0.1.66 (unreleased) * New features: diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index b7a9271d0..75f4f1688 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -817,9 +817,7 @@ def _outside_call_translation_rule( current_token = args_op[-2] current_itoken = args_op[-1] # TODO: expose shape.is_token - assert not comp.get_shape(current_token).is_array() and not comp.get_shape(current_token).is_array(), ( - "The last two arguments must be tokens") - assert not comp.get_shape(current_itoken).is_array() and not comp.get_shape(current_itoken).is_array(), ( + assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), ( "The last two arguments must be tokens") args_to_outfeed = args_op[:-2] diff --git a/jax/tools/jax_to_hlo.py b/jax/tools/jax_to_hlo.py index 3985dab4e..454125888 100644 --- a/jax/tools/jax_to_hlo.py +++ b/jax/tools/jax_to_hlo.py @@ -105,7 +105,7 @@ def jax_to_hlo(fn, input_shapes, constants=None): for arg_name, shape in input_shapes: if not shape.is_array(): raise ValueError('Shape %s is not an array, but currently only arrays ' - 'are supported (i.e., no tuples).' % str(shape)) + 'are supported (i.e., no tuples, nor tokens).' % str(shape)) # Check that `shape` either doesn't have a layout or has the default layout. #