[host_callback] Minor fix to use the new xla_shape.is_token

This commit is contained in:
George Necula 2021-04-28 12:22:32 +03:00
parent 2c7556e014
commit d762ec1d21
3 changed files with 5 additions and 4 deletions

View File

@ -17,6 +17,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`#6501`)
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
* {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`).
* Breaking changes:
* The following function names have changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed
@ -32,6 +33,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* Bug fixes:
* The {func}`jax2tf.convert` now works in presence of gradients for functions
with integer inputs ({jax-issue}`#6360`).
* Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured
`tf.Variable` ({jax-issue}`#6572`).
## jaxlib 0.1.66 (unreleased)
* New features:

View File

@ -817,9 +817,7 @@ def _outside_call_translation_rule(
current_token = args_op[-2]
current_itoken = args_op[-1]
# TODO: expose shape.is_token
assert not comp.get_shape(current_token).is_array() and not comp.get_shape(current_token).is_array(), (
"The last two arguments must be tokens")
assert not comp.get_shape(current_itoken).is_array() and not comp.get_shape(current_itoken).is_array(), (
assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), (
"The last two arguments must be tokens")
args_to_outfeed = args_op[:-2]

View File

@ -105,7 +105,7 @@ def jax_to_hlo(fn, input_shapes, constants=None):
for arg_name, shape in input_shapes:
if not shape.is_array():
raise ValueError('Shape %s is not an array, but currently only arrays '
'are supported (i.e., no tuples).' % str(shape))
'are supported (i.e., no tuples, nor tokens).' % str(shape))
# Check that `shape` either doesn't have a layout or has the default layout.
#