[shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program

This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.

The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into a TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the called TF function returns statically known output shapes.

The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in the presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.
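
As an illustrative sketch (mirroring the README example added in this change, not
the exact test code), a caller can express the output shape of the called TF
function in terms of the input shapes like this:

```python
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def fun_jax(x):
  # The called TF function doubles the leading dimension. We declare the
  # result shape symbolically in terms of x.shape, so JAX tracing works
  # even when x.shape[0] is a polymorphic dimension.
  y = jax2tf.call_tf(
      lambda x: tf.concat([x, x], axis=0),
      output_shape_dtype=jax.ShapeDtypeStruct((2 * x.shape[0],) + x.shape[1:],
                                              x.dtype))(x)
  return jnp.cos(y)

# Serialize with a polymorphic leading dimension "b".
fun_tf = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])
fun_tf(np.arange(4, dtype=np.float32))
```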

This change should be enough for old-style jax2tf, but will require more
work for native serialization.

We also removed some old code that was trying to work around some limitations
of TF shape inference. I think those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.

PiperOrigin-RevId: 515137407
George Necula 2023-03-08 14:09:35 -08:00 committed by jax authors
parent 942e79ffe3
commit 961e09e614
5 changed files with 305 additions and 129 deletions


@@ -11,13 +11,17 @@ Remember to align the itemized text with the first line of an item within a list
* Changes
* `jax.tree_util` now contains a set of APIs that allow users to define keys for their
custom pytree nodes. This includes:
* `tree_flatten_with_path` that flattens a tree and returns not only each leaf but
also its key path.
* `tree_map_with_paths` that can map a function that takes the key path as an argument.
* `register_pytree_with_keys` to register how the key path and leaves should look
in a custom pytree node.
* `keystr` that pretty-prints a key path.
* {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`)
that can be used to declare the output shape and type of the result. This enables
{func}`jax2tf.call_tf` to work in the presence of shape polymorphism. ({jax-issue}`#14734`).
* Deprecations
* The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months
from Mar 10 2023:


@@ -1229,6 +1229,57 @@ DeviceArray data or for np.ndarray that are aligned on 16-byte
boundaries) and on GPU (for DeviceArray).
The zero-copy does not yet work on TPU.
`call_tf` works even with shape polymorphism, but in that case
the user must pass the `output_shape_dtype` parameter to `call_tf` to declare
the expected output shapes. This allows JAX tracing to know the shape and
dtype of the results so that it can continue tracing the rest of the program.
When `output_shape_dtype` is not given (the default case), `call_tf` will
form a `tf.Graph` for the called TF function and will use the inferred
type and shape. However, in the presence of dynamic shapes the inferred TF
type will contain `None` for the dynamic dimensions, which is not enough
information for JAX shape polymorphism.
For example:
```python
def fun_jax(x):
  y_shape = (x.shape[0] * 2, *x.shape[1:])
  y = jax2tf.call_tf(
      lambda x: tf.concat([x, x], axis=0),
      output_shape_dtype=jax.ShapeDtypeStruct(y_shape, x.dtype))(x)
  # JAX will know y.shape
  return jnp.ones(y.shape, dtype=y.dtype) + y

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
An even simpler example for a function that returns the same shape as the input:
```python
def fun_jax(x):
  return jax2tf.call_tf(tf.math.sin,
                        output_shape_dtype=x)(x)

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
If all the output shapes of the TF function are static, JAX does not need the
`output_shape_dtype` argument:
```python
def fun_tf(x):
  return tf.math.reduce_sum(tf.math.sin(x))

def fun_jax(x):
  return jax2tf.call_tf(fun_tf)(x)

# The following will not throw an error because the output shape of fun_tf is static.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
The shape polymorphism support for `call_tf` does not yet work for native lowering.
### Limitations of call_tf
The TF function must be compileable (`tf.function(func, jit_compile=True)`)
@@ -1312,38 +1363,14 @@ JAX computation runs on TPU. This will fail if the computation captures
variables on some other devices. It is best to use ``call_tf``
with TF functions that do not capture variables.
A TF function wrapped with `call_tf` cannot be applied to inputs whose
shapes are not constants, unless all the output shapes of the TF function
are static. This may arise when you try to apply `jax2tf.convert` with
polymorphic shapes on the result of `call_tf`:
```python
def fun_jax(x):
  return jax2tf.call_tf(tf.math.sin)(x)

# The following will throw an error.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
This is unsatisfying, because the result of the above conversion
could be simply `tf.math.sin`, which is batch polymorphic. But
JAX cannot keep track of shapes through a `call_tf` call, and it
cannot be sure that the shape-polymorphic conversion is safe.
If all the output shapes of the TF function are static, JAX does not need to
keep track of shapes after a `call_tf` call, hence allows shape-polymorphic
inputs in such cases:
```python
def fun_tf(x):
  return tf.math.reduce_sum(tf.math.sin(x))

def fun_jax(x):
  return jax2tf.call_tf(fun_tf)(x)

# The following will not throw an error because the output shape of fun_tf is static.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
In some rare cases your called TF function may contain ops with outputs
of statically known shape, but for which shape inference is not fully
implemented, so they appear to `call_tf` to have dynamically shaped
outputs. In these cases you may get an error that
`call_tf cannot call functions whose output has dynamic shape`. Try using
the `output_shape_dtype` parameter to specify the expected output shape
(this essentially allows you to override the shape inference for the
purposes of `call_tf`).
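A hedged sketch of that workaround, following the conventions of the examples
above (`fun_tf_weak_inference` is a hypothetical stand-in for a TF function
whose output shape TF fails to infer, but which is known to preserve the shape
and dtype of its input):
```python
def fun_jax(x):
  # Override TF's incomplete shape inference by declaring the statically
  # known result shape and dtype explicitly.
  return jax2tf.call_tf(
      fun_tf_weak_inference,  # hypothetical shape-preserving TF function
      output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

jax.jit(fun_jax)(x)
```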
# Misc notes


@@ -55,13 +55,15 @@ map = util.safe_map
zip = util.safe_zip
TfConcreteFunction = Any
TfVal = jax2tf_internal.TfVal
# The platforms for which to use DLPack to avoid copying (only works on GPU
# and CPU at the moment, and only for DeviceArray). For CPU we don't need
# DLPack, if we are careful.
_DLPACK_PLATFORMS = ("gpu",)
def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
def call_tf(callable_tf: Callable, has_side_effects=True,
output_shape_dtype=None) -> Callable:
"""Calls a TensorFlow function from JAX, with support for reverse autodiff.
The ``callable_tf`` will be called with TensorFlow-compatible arguments (
@@ -90,6 +92,14 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
has_side_effects: if True then it ensures that instances of this primitive
are not removed or replicated by JAX optimizations such as dead-code
elimination.
output_shape_dtype: An optional declaration of the expected shapes and dtypes
of the results of the called TensorFlow function. If given, it is used during
JAX tracing to form the abstract values of the results of `call_tf`. If
not given, we form a `tf.Graph` for the called TensorFlow function and
use the TensorFlow-inferred shapes and types. Must be a pytree matching the
structure of the results returned from the TensorFlow function,
containing objects with `.shape` and `.dtype` attributes,
e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
Returns: a JAX callable that can be invoked with JAX pytree arguments, in
op-by-op mode or in a staged context. This callable can be used with
@@ -113,67 +123,58 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
def make_tensorspec(a_jax):
a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
a_tf_shape = [
d if core.is_constant_dim(d) else None for d in a_jax.shape
]
d if core.is_constant_dim(d) else None for d in a_jax.shape]
return tf.TensorSpec(a_tf_shape, a_tf_dtype)
args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))
def check_tf_result(r_tf):
# Check that the TF function returns values of expected types. This
# improves error reporting, preventing hard-to-diagnose errors downstream
try:
jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
except Exception as e:
msg = ("The called TF function returns a result that is not "
f"convertible to JAX: {r_tf}.")
raise ValueError(msg) from e
if output_shape_dtype is not None:
output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
else:
output_avals, output_shape_dtype_tree = None, None
res_treedef = None # We'll store here the result treedef
res_tf_flat = None # For error reporting
# The function below will be called at least once, either in eager
# or in graph mode.
# mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
args_tf = args_treedef.unflatten(args_tf_flat)
res_tf = callable_tf(*args_tf)
nonlocal res_treedef, res_tf_flat
res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
for r_tf in res_tf_flat:
check_tf_result(r_tf)
assert res_treedef is None or res_treedef == res_treedef_now, f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}"
assert res_treedef is None or res_treedef == res_treedef_now, (
f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
res_treedef = res_treedef_now
return res_tf_flat
if output_avals is not None:
if res_treedef != output_shape_dtype_tree:
raise ValueError(
"The pytree of the TensorFlow function results does not match the "
"pytree of the declared output_shape_dtype:\n"
f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
assert len(output_avals) == len(res_tf_flat)
checked_res_tf_flat = [
check_tf_result(i, r_tf, r_aval)
for i, (r_tf, r_aval) in enumerate(
zip(res_tf_flat,
(output_avals if output_avals is not None
else (None,) * len(res_tf_flat))))]
return checked_res_tf_flat
# Prepare a tf.function ahead of time, to cache the concrete functions. This
# won't be used in op-by-op execution mode.
function_flat_tf = tf.function(callable_flat_tf, autograph=False, jit_compile=True)
input_shapes_tf = [s.shape for s in args_flat_sig_tf]
output_shapes_tf = _get_concrete_function_tf(
function_flat_tf, args_flat_sig_tf
).output_shapes
if not all(s.is_fully_defined() for s in input_shapes_tf) and not all(
s.is_fully_defined() for s in output_shapes_tf
):
for a_jax, a_tf_shape in zip(args_flat_jax, input_shapes_tf):
if not a_tf_shape.is_fully_defined():
msg = (
"call_tf cannot be applied to shape-polymorphic arguments unless"
" all the output shapes are static. Found argument shape:"
f" {a_jax.shape}. See"
" https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
" for a discussion."
)
raise ValueError(msg)
res_jax_flat = call_tf_p.bind(
*args_flat_jax,
# Carry the actual function such that op-by-op call can call in TF eager mode.
callable_flat_tf=callable_flat_tf,
function_flat_tf=function_flat_tf,
args_flat_sig_tf=args_flat_sig_tf,
output_avals=output_avals,
has_side_effects=has_side_effects)
# We must have called callable_flat_tf by now
assert res_treedef is not None
# Sometimes, in compiled mode, we get a different number of results than we
# got when tracing the TF function (and building the res_treedef). This
@@ -248,6 +249,44 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
return util.wraps(callable_tf)(make_call)
def check_tf_result(idx: int, r_tf: TfVal, r_aval: Optional[core.ShapedArray]) -> TfVal:
# Check that the TF function returns values of expected types. This
# improves error reporting, preventing hard-to-diagnose errors downstream
try:
jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
except Exception as e:
msg = ("The called TF function returns a result that is not "
f"convertible to JAX: {r_tf}.")
raise ValueError(msg) from e
if r_aval is None:
return r_tf
# We convert to TF type, and canonicalize to 32-bit if necessary
r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
# Checking shapes is trickier in the presence of dynamic shapes. I wish we could
# check at runtime that the returned shape matches the declared shape. I wish
# that tf.ensure_shape did this, but it can only take shapes that contain None,
# not computed shapes. However, in eager mode we should be able to resolve
# the declared shapes to constants and we get better checking.
if tf.executing_eagerly():
r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
else:
r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
# We do as much checking as we can here, instead of relying on tf.ensure_shape
# because the latter gives different errors in eager vs. compiled mode.
if (r_tf.dtype != r_aval_dtype_tf or
len(r_tf.shape) != len(r_aval_shape_tf) or
any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
msg = ("The shapes or dtypes returned by the TensorFlow function "
"do not match the declared output_shape_dtype:\n"
f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
raise ValueError(msg)
# At this point tf.ensure_shape does not do much, it should never throw an
# error, albeit it may refine the shape a bit.
return tf.ensure_shape(r_tf, r_aval_shape_tf)
call_tf_p = core.Primitive("call_tf")
call_tf_p.multiple_results = True
@@ -309,39 +348,42 @@ effects.remat_allowed_effects.add_type(CallTfEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
def _call_tf_abstract_eval(*_,
def _call_tf_abstract_eval(*args_flat_avals,
function_flat_tf,
args_flat_sig_tf,
has_side_effects, **__):
has_side_effects,
output_avals, **__):
# Called only when we form a Jaxpr, i.e., under jit, scan, etc.
effects = {call_tf_effect} if has_side_effects else set()
# If output_avals is not given, then we ask TF to infer the output shapes.
# We call this even if output_avals is given because it will ensure that
# callable_flat_tf is called. Since _get_concrete_function_tf is cached
# there is a small cost of calling it more often than needed.
concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
args_flat_sig_tf)
if output_avals is not None:
return output_avals, effects
def is_fully_known_shape(s):
return s.rank is not None and all([d is not None for d in s])
effects = {call_tf_effect} if has_side_effects else set()
if all([is_fully_known_shape(s)
for s in concrete_function_flat_tf.output_shapes]):
return (
tuple([
# We convert to JAX type, and canonicalize to 32-bit if necessary
core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
concrete_function_flat_tf.output_shapes)
]),
effects)
# There are some cases when TF shape inference is not powerful enough to
# figure out the output shapes (e.g., b/128924522), even in situations where
# XLA can compile the code, from which we can get the shapes.
if all(is_fully_known_shape(s)
for s in concrete_function_flat_tf.output_shapes):
avals_from_tf = tuple(
# We convert to JAX type, and canonicalize to 32-bit if necessary
core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
concrete_function_flat_tf.output_shapes))
return avals_from_tf, effects
# We use the "cpu" as the platform, since JAX abstract eval is not platform
# specific; the "cpu" backend is always available and for abstract evaluation
# it should not matter which platform we use.
_, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,
"CPU")
return tuple(result_avals), effects
msg = ("call_tf cannot call functions whose output has dynamic shape. "
f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
"Consider using the `output_shape_dtype` argument to call_tf. "
"\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
" for a discussion.")
raise ValueError(msg)
call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
@@ -372,6 +414,11 @@ def _code_generator_and_avals(
) -> Tuple[Optional[Callable[[mlir.ModuleContext, Sequence[ir.Value]],
Sequence[ir.Value]]],
Sequence[core.ShapedArray]]:
# TODO(necula): we have refactored the code to not need to lower the code
# just in order to get the avals, so in fact the returned avals from this
# function are never used. We keep it here for now in case we detect
a regression, but if not we should simplify this function.
# Returns and caches a code generator (taking a builder and the
# XlaOps for the arguments) and a sequence of result abstract shapes.
@@ -478,8 +525,7 @@ def _register_call_lowering(platform):
for platform in ("cpu", "cuda", "tpu"):
_register_call_lowering(platform)
# Support the call_tf under jax2tf.convert
TfVal = jax2tf_internal.TfVal
# Support the call_tf under jax2tf.convert in eager mode
def _jax2tf_call_tf(*args: TfVal,
callable_flat_tf: Callable,
**_) -> TfVal:


@@ -335,8 +335,7 @@ def convert(fun_jax: Callable,
_thread_local_state.tf_outer_name_scope = tf.get_current_name_scope()
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean(
) and not _thread_local_state.inside_call_tf:
if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
# It is Ok to nest convert when we are inside a call_tf
raise ValueError(
"convert must be used outside all JAX transformations." +


@@ -44,6 +44,12 @@ def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
else:
return func
def _maybe_tf_jit(with_jit: bool, func: Callable) -> Callable:
if with_jit:
return tf.function(func, autograph=False, jit_compile=True)
else:
return func
def _named_test(**kwargs):
return dict(kwargs,
testcase_name = "_".join([f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())]))
@@ -53,8 +59,7 @@ _parameterized_jit = parameterized.named_parameters(
for with_jit in [True, False])
_call_tf_non_compileable_error = "Error compiling TensorFlow function. call_tf can used in a staged context .* only with compileable functions"
_call_tf_dynamic_shape_error = "Compiled TensorFlow function has dynamic output shape.* call_tf can used in a staged context .* only with compileable functions"
_call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape"
class CallTfTest(tf_test_util.JaxToTfTestCase):
@@ -171,8 +176,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
x = np.array([True, False], dtype=np.bool_)
self.assertAllClose(f_tf_non_compileable(x), f_jax(x)) # Works in eager mode
with self.assertRaisesRegex(ValueError,
_call_tf_dynamic_shape_error):
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
jax.jit(f_jax)(x)
def test_error_bad_result_tensorarray(self):
@@ -569,9 +573,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(x[0:x[1]], res1)
# Now under jit, should fail because the function is not compileable
with self.assertRaisesRegex(
ValueError, "Compiled TensorFlow function has dynamic output shape"
):
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
fun_jax(x)
@@ -1099,29 +1101,143 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(expected, res1, check_dtypes=False)
# Now under jit, should fail because the function is not compileable
with self.assertRaisesRegex(ValueError,
_call_tf_dynamic_shape_error):
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
fun_jax(x)
# TODO(necula): this should work in op-by-op mode, but it fails because
# jax2tf.convert does abstract evaluation.
with self.assertRaisesRegex(ValueError,
_call_tf_dynamic_shape_error):
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
fun_tf_rt(x)
def test_shape_polymorphism_error(self):
x = np.array([.7, .8], dtype=np.float32)
@_parameterized_jit
def test_shape_poly_static_output_shape(self, with_jit=True):
if config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.")
x = np.array([0.7, 0.8], dtype=np.float32)
def fun_tf(x):
return tf.math.sin(x)
return tf.math.reduce_sum(tf.math.sin(x))
fun_jax = jax2tf.call_tf(fun_tf)
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
self.assertAllClose(fun_tf(x), fun_tf_rt(x))
@_parameterized_jit
def test_shape_poly(self, with_jit=False):
if config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.")
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
y = jax2tf.call_tf(tf.math.sin,
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
z = jnp.cos(y)
w = jax2tf.call_tf(lambda z: tf.concat([z, z], axis=0),
output_shape_dtype=jax.ShapeDtypeStruct((2 * z.shape[0],), z.dtype))(z)
assert w.shape[0] == 2 * x.shape[0]
return w
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
res_tf = fun_tf_rt(x)
self.assertAllClose(fun_jax(x), res_tf)
@_parameterized_jit
def test_shape_poly_pytree_result(self, with_jit=True):
if config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.")
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
# Returns a tuple
y = jax2tf.call_tf(lambda x: (x, tf.concat([x, x], axis=0)),
output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x)
assert y[0].shape[0] == x.shape[0]
assert y[1].shape[0] == 2 * x.shape[0]
return y
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
res_tf = fun_tf_rt(x)
self.assertAllClose(fun_jax(x), res_tf)
@_parameterized_jit
def test_shape_poly_error_no_output_shape_dtype(self, with_jit=True):
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
return jax2tf.call_tf(tf.math.sin)(x)
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
fun_tf_rt(x)
@_parameterized_jit
def test_shape_poly_error_mismatch_output_shape_dtype_tree(self, with_jit=False):
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
return jax2tf.call_tf(tf.math.sin,
output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape, x.dtype)))(x)
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])
with self.assertRaisesRegex(
ValueError, "call_tf cannot be applied to shape-polymorphic arguments"
):
ValueError,
"The pytree of the TensorFlow function results does not match the pytree of the declared output_shape_dtype"):
fun_tf_rt(x)
@parameterized.named_parameters(
_named_test(with_jit=with_jit, kind=kind)
for with_jit in [True, False]
for kind in ["bad_rank", "bad_dim", "bad_dtype", "bad_dtype_x64"])
def test_shape_poly_error_mismatch_output_shape_dtype(self, with_jit=False, kind="bad_rank"):
x = np.array([7, 8, 9, 10], dtype=np.float32)
if kind == "bad_rank":
def fun_jax(x):
return jax2tf.call_tf(lambda x: x,
# Wrong shape rank
output_shape_dtype=jax.ShapeDtypeStruct((), x.dtype))(x)
elif kind == "bad_dim":
def fun_jax(x):
bad_shape = (5 + x.shape[0],)
y = jax2tf.call_tf(lambda x: x,
# Wrong dimension
output_shape_dtype=jax.ShapeDtypeStruct(bad_shape, x.dtype))(x)
# JAX will believe that the following is Ok, leading to downstream error in TF
return y + jnp.ones(bad_shape, dtype=x.dtype)
elif kind == "bad_dtype":
def fun_jax(x):
return jax2tf.call_tf(lambda x: x,
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.int32))(x)
elif kind == "bad_dtype_x64":
def fun_jax(x):
return jax2tf.call_tf(lambda x: x * np.float64(3.),
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.float64))(x)
else:
assert False
expect_ex = ValueError
expect_error = r"The shapes or dtypes returned by the TensorFlow function do not match the declared output_shape_dtype"
# Call without shape polymorphism
fun_tf_rt = _maybe_tf_jit(with_jit, jax2tf.convert(fun_jax))
with self.assertRaisesRegex(expect_ex, expect_error):
fun_tf_rt(x)
# Now with shape polymorphism
if kind == "bad_dim" and with_jit:
# TODO: in jit mode the error pops up later, at AddV2
expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2"
if kind == "bad_dim" and config.jax2tf_default_experimental_native_lowering:
# TODO(b/268386622): call_tf with shape polymorphism and native lowering.
expect_error = "Error compiling TensorFlow function. call_tf can used .* only with compileable functions with static output shapes"
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
with self.assertRaisesRegex(expect_ex, expect_error):
fun_tf_rt(x)
def test_inner_native_lowering(self):
@@ -1142,22 +1258,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
self.assertIn('op: "XlaCallModule"', f_outer_graph)
self.assertNotIn('op: "Sin"', f_outer_graph)
@_parameterized_jit
def test_shape_polymorphism_static_output_shape(self, with_jit=True):
# TODO(b/268386622) Dynamic shapes not yet supported.
if config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("Skip test because of dynamic shapes.")
x = np.array([0.7, 0.8], dtype=np.float32)
def fun_tf(x):
return tf.math.reduce_sum(tf.math.sin(x))
fun_jax = jax2tf.call_tf(fun_tf)
fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])
if with_jit:
fun_tf_rt = tf.function(jit_compile=True, autograph=False)(fun_tf_rt)
self.assertAllClose(fun_tf(x), fun_tf_rt(x))
@parameterized.named_parameters(
_named_test(f2_function=f2_function, f2_saved_model=f2_saved_model,
f4_function=f4_function, f4_saved_model=f4_saved_model)