mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[host_callback] Minor fix to use the new xla_shape.is_token
This commit is contained in:
parent
2c7556e014
commit
d762ec1d21
@ -17,6 +17,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
|
||||
be used within `jit` ({jax-issue}`#6501`)
|
||||
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
|
||||
* {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`).
|
||||
* Breaking changes:
|
||||
* The following function names have changed. There are still aliases, so this
|
||||
should not break existing code, but the aliases will eventually be removed
|
||||
@ -32,6 +33,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* Bug fixes:
|
||||
* The {func}`jax2tf.convert` now works in presence of gradients for functions
|
||||
with integer inputs ({jax-issue}`#6360`).
|
||||
* Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured
|
||||
`tf.Variable` ({jax-issue}`#6572`).
|
||||
|
||||
## jaxlib 0.1.66 (unreleased)
|
||||
* New features:
|
||||
|
@ -817,9 +817,7 @@ def _outside_call_translation_rule(
|
||||
current_token = args_op[-2]
|
||||
current_itoken = args_op[-1]
|
||||
# TODO: expose shape.is_token
|
||||
assert not comp.get_shape(current_token).is_array() and not comp.get_shape(current_token).is_array(), (
|
||||
"The last two arguments must be tokens")
|
||||
assert not comp.get_shape(current_itoken).is_array() and not comp.get_shape(current_itoken).is_array(), (
|
||||
assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), (
|
||||
"The last two arguments must be tokens")
|
||||
|
||||
args_to_outfeed = args_op[:-2]
|
||||
|
@ -105,7 +105,7 @@ def jax_to_hlo(fn, input_shapes, constants=None):
|
||||
for arg_name, shape in input_shapes:
|
||||
if not shape.is_array():
|
||||
raise ValueError('Shape %s is not an array, but currently only arrays '
|
||||
'are supported (i.e., no tuples).' % str(shape))
|
||||
'are supported (i.e., no tuples, nor tokens).' % str(shape))
|
||||
|
||||
# Check that `shape` either doesn't have a layout or has the default layout.
|
||||
#
|
||||
|
Loading…
x
Reference in New Issue
Block a user