From 961e09e61493842f5c057a2b2086813f2af58009 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 8 Mar 2023 14:09:35 -0800 Subject: [PATCH] [shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997). See b/272154366 for more details. The use case for call_tf with shape polymorphism is when we have a JAX program that calls into TF function, and we want to serialize the JAX program with some shapes unknown. Previously this use case did not work, except in the special case when the output shape of the called TF function returns statically known shapes. The idea is that we allow the user of call_tf to specify the output shape. This can be done even in presence of shape polymorphism, by writing the output shape as an expression in terms of the input shapes. This is what other JAX primitives do, e.g., concat, so we are simply enabling call_tf to get the same behavior. This change should be enough for old-style jax2tf, but will require more work for native serialization. We also removed some old code that was trying to workaround some limitations in shape inference in TF. I think that those workarounds are ugly, and I am prepared to give error messages rather than keep that code. So far no tests fail. PiperOrigin-RevId: 515137407 --- CHANGELOG.md | 8 +- jax/experimental/jax2tf/README.md | 91 ++++++---- jax/experimental/jax2tf/call_tf.py | 166 +++++++++++------- jax/experimental/jax2tf/jax2tf.py | 3 +- jax/experimental/jax2tf/tests/call_tf_test.py | 166 ++++++++++++++---- 5 files changed, 305 insertions(+), 129 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71f8e795a..dd4d1cf03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,13 +11,17 @@ Remember to align the itemized text with the first line of an item within a list * Changes * `jax.tree_util` now contain a set of APIs that allow user to define keys for their custom pytree node. This includes: - * `tree_flatten_with_path` that flattens a tree and return not only each leaf but + * `tree_flatten_with_path` that flattens a tree and return not only each leaf but also their key paths. * `tree_map_with_paths` that can map a function that takes the key path as argument. - * `register_pytree_with_keys`` to register how the key path and leaves should looks + * `register_pytree_with_keys`` to register how the key path and leaves should looks like in a custom pytree node. * `keystr` that pretty-prints a key path. + * {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`) + that can be used to declare the output shape and type of the result. This enables + {func}`jax2tf.call_tf` to work in the presence of shape polymorphism. ({jax-issue}`#14734`). + * Deprecations * The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months from Mar 10 2023: diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 4a40c8c5f..139673079 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -1229,6 +1229,57 @@ DeviceArray data or for np.ndarray that are aligned on 16-byte boundaries) and on GPU (for DeviceArray). The zero-copy does not yet work on TPU. +`call_tf` works even with shape polymorphism, but in that case +the user must pass the `output_shape_dtype` parameter to `call_tf` to declare +the expected output shapes. This allows JAX tracing to know the shape and +dtype of the results so that it can continue tracing the rest of the program. +When `output_shape_dtype` is not given (the default case), `call_tf` will +form a `tf.Graph` for the called TF function and will use the inferred +type and shape. However, in presence of dynamic shape the inferred TF +type will contain `None` for the dynamic dimensions, which is not enough +information for JAX shape polymorphism. + +For example: + +```python +def fun_jax(x): + y_shape = (x.shape[0] * 2, y.shape[1:]) + y = jax2tf.call_tf( + lambda x: tf.concat([x, x], axis=0), + output_shape_dype=jax.ShapeDtypeStruct(y_shape, x.dtype))(x) + # JAX will know the y.shape + return jnp.ones(y.shape, dtype=y.dtype) + y + +jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) +``` + +An even simpler example for a function that returns the same shape as the input: + +```python +def fun_jax(x): + return jax2tf.call_tf(tf.math.sin, + output_shape_dtype=x) + )(x) + +jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) +``` + +If all the output shapes of the TF function are static, JAX does not need the +`output_shape_dtype` argument: + +```python +def fun_tf(x): + return tf.math.reduce_sum(tf.math.sin(x)) + +def fun_jax(x): + return jax2tf.call_tf(fun_tf)(x) + +# The following will not throw an error because the output shape of fun_tf is static. +jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) +``` + +The shape polymorphism support for `call_tf` does not yet work for native lowering. + ### Limitations of call_tf The TF function must be compileable (`tf.function(func, jit_compile=True)`) @@ -1312,38 +1363,14 @@ JAX computation runs on TPU. This will fail if the computation captures variables on some other devices. It is best to use ``call_tf`` with TF functions that do not capture variables. -A TF function wrapped with `call_tf` cannot be applied to inputs whose -shapes are not constants, unless all the output shapes of the TF function -are static. The may arise when you try to apply `jax2tf.convert` with -polymorphic shapes on the result of `call_tf`: - -```python -def fun_jax(x): - return jax2tf.call_tf(tf.math.sin)(x) - -# The following will throw an error. -jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) -``` - -This is unsatisfying, because the result of the above conversion -could be simply `tf.math.sin`, which is batch polymorphic. But -JAX cannot keep track of shapes through a `call_tf` call, and it -cannot be sure that the shape-polymorphic conversion is safe. - -If all the output shapes of the TF function are static, JAX does not need to -keep track of shapes after a `call_tf` call, hence allows shape-polymorphic -inputs in such cases: - -```python -def fun_tf(x): - return tf.math.reduce_sum(tf.math.sin(x)) - -def fun_jax(x): - return jax2tf.call_tf(fun_tf)(x) - -# The following will not throw an error because the output shape of fun_tf is static. -jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x) -``` +In some rare cases your called TF function may contain ops with output +of statically known shape, but for which the shape inference is not implemented +completely and will appear to `call_tf` as if they have dynamically-shaped +outputs. In these cases you may get an error that +`call_tf cannot call functions whose output has dynamic shape`. Try using +the `output_shape_dtype` parameter to specify the expected output shape +(this essentially allows you to override the shape inference for the +purposes of `call_tf`.) # Misc notes diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 78773021e..b2a849a8c 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -55,13 +55,15 @@ map = util.safe_map zip = util.safe_zip TfConcreteFunction = Any +TfVal = jax2tf_internal.TfVal # The platforms for which to use DLPack to avoid copying (only works on GPU # and CPU at the moment, and only for DeviceArray). For CPU we don't need # DLPack, if we are careful. _DLPACK_PLATFORMS = ("gpu",) -def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: +def call_tf(callable_tf: Callable, has_side_effects=True, + output_shape_dtype=None) -> Callable: """Calls a TensorFlow function from JAX, with support for reverse autodiff. The ``callable_tf`` will be called with TensorFlow-compatible arguments ( @@ -90,6 +92,14 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: has_side_effects: if True then it ensures that instances of this primitive are not removed or replicated by JAX optimizations such as dead-code elimination. + output_shape_dtype: An optional declaration of the expected shapes and dtypes + from the called TensorFlow function. If given it will be used during JAX + tracing to form the abstract values of the results of the `call_tf`. If + not given then we form a `tf.Graph` for the called TensorFlow function and + we use the TensorFlow-inferred shapes and types. Must be a pytree matching the + structure of the nested structure returned from the TensorFlow function, + containing objects with `.shape` and `.dtype` attributes, + e.g., `jax.ShapeDtypeStruct` or `jax.Array`. Returns: a JAX callable that can be invoked with JAX pytree arguments, in op-by-op mode or in a staged context. This callable can be used with @@ -113,67 +123,58 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: def make_tensorspec(a_jax): a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype) a_tf_shape = [ - d if core.is_constant_dim(d) else None for d in a_jax.shape - ] + d if core.is_constant_dim(d) else None for d in a_jax.shape] return tf.TensorSpec(a_tf_shape, a_tf_dtype) args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax)) - def check_tf_result(r_tf): - # Check that the TF function returns values of expected types. This - # improves error reporting, preventing hard-to-diagnose errors downstream - try: - jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf) - except Exception as e: - msg = ("The called TF function returns a result that is not " - f"convertible to JAX: {r_tf}.") - raise ValueError(msg) from e + if output_shape_dtype is not None: + output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype) + output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat) + else: + output_avals, output_shape_dtype_tree = None, None res_treedef = None # We'll store here the result treedef res_tf_flat = None # For error reporting # The function below will be called at least once, either in eager - # or in graph mode. + # mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf() def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]: args_tf = args_treedef.unflatten(args_tf_flat) res_tf = callable_tf(*args_tf) nonlocal res_treedef, res_tf_flat res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf) - for r_tf in res_tf_flat: - check_tf_result(r_tf) - assert res_treedef is None or res_treedef == res_treedef_now, f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}" + assert res_treedef is None or res_treedef == res_treedef_now, ( + f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}") res_treedef = res_treedef_now - return res_tf_flat + if output_avals is not None: + if res_treedef != output_shape_dtype_tree: + raise ValueError( + "The pytree of the TensorFlow function results does not match the " + "pytree of the declared output_shape_dtype:\n" + f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}") + assert len(output_avals) == len(res_tf_flat) + + checked_res_tf_flat = [ + check_tf_result(i, r_tf, r_aval) + for i, (r_tf, r_aval) in enumerate( + zip(res_tf_flat, + (output_avals if output_avals is not None + else (None,) * len(res_tf_flat))))] + return checked_res_tf_flat # Prepare a tf.function ahead of time, to cache the concrete functions. This # won't be used in op-by-op execution mode. function_flat_tf = tf.function(callable_flat_tf, autograph=False, jit_compile=True) - input_shapes_tf = [s.shape for s in args_flat_sig_tf] - output_shapes_tf = _get_concrete_function_tf( - function_flat_tf, args_flat_sig_tf - ).output_shapes - - if not all(s.is_fully_defined() for s in input_shapes_tf) and not all( - s.is_fully_defined() for s in output_shapes_tf - ): - for a_jax, a_tf_shape in zip(args_flat_jax, input_shapes_tf): - if not a_tf_shape.is_fully_defined(): - msg = ( - "call_tf cannot be applied to shape-polymorphic arguments unless" - " all the output shapes are static. Found argument shape:" - f" {a_jax.shape}. See" - " https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" - " for a discussion." - ) - raise ValueError(msg) - res_jax_flat = call_tf_p.bind( *args_flat_jax, # Carry the actual function such that op-by-op call can call in TF eager mode. callable_flat_tf=callable_flat_tf, function_flat_tf=function_flat_tf, args_flat_sig_tf=args_flat_sig_tf, + output_avals=output_avals, has_side_effects=has_side_effects) + # We must have called callable_flat_tf by nοw assert res_treedef is not None # Sometimes, in compiled mode, we get a different number of results than we # got when tracing the TF function (and building the res_treedef). This @@ -248,6 +249,44 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable: return util.wraps(callable_tf)(make_call) +def check_tf_result(idx: int, r_tf: TfVal, r_aval: Optional[core.ShapedArray]) -> TfVal: + # Check that the TF function returns values of expected types. This + # improves error reporting, preventing hard-to-diagnose errors downstream + try: + jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf) + except Exception as e: + msg = ("The called TF function returns a result that is not " + f"convertible to JAX: {r_tf}.") + raise ValueError(msg) from e + + if r_aval is None: + return r_tf + # We convert to TF type, and canonicalize to 32-bit if necessary + r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype) + # Checking shapes is trickier in presence of dynamic shapes. I wish we could + # check at runtime that the returned shape matches the declared shape. I wish + # that tf.ensure_shape did this, but it can only take shapes that contain None + # not computed shapes. However, in eager mode we should be able to resolve + # the declared shapes to constants and we get better checking. + if tf.executing_eagerly(): + r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape) + else: + r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval) + # We do as much checking as we can here, instead of relying on tf.ensure_shape + # because the latter gives different errors in eager vs. compiled mode. + if (r_tf.dtype != r_aval_dtype_tf or + len(r_tf.shape) != len(r_aval_shape_tf) or + any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d + for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))): + msg = ("The shapes or dtypes returned by the TensorFlow function " + "do not match the declared output_shape_dtype:\n" + f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]") + raise ValueError(msg) + # At this point tf.ensure_shape does not do much, it should never throw an + # error, albeit it may refine the shape a bit. + return tf.ensure_shape(r_tf, r_aval_shape_tf) + + call_tf_p = core.Primitive("call_tf") call_tf_p.multiple_results = True @@ -309,39 +348,42 @@ effects.remat_allowed_effects.add_type(CallTfEffect) effects.custom_derivatives_allowed_effects.add_type(CallTfEffect) -def _call_tf_abstract_eval(*_, +def _call_tf_abstract_eval(*args_flat_avals, function_flat_tf, args_flat_sig_tf, - has_side_effects, **__): + has_side_effects, + output_avals, **__): # Called only when we form a Jaxpr, i.e., under jit, scan, etc. + effects = {call_tf_effect} if has_side_effects else set() + # If not output_avals is given, then we ask TF to infer the output shapes. + # We call this even if output_avals is given because it will ensure that + # callable_flat_tf is called. Since _get_concrete_function_tf is cached + # there is a small cost of calling it more often than needed. concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf) + if output_avals is not None: + return output_avals, effects + def is_fully_known_shape(s): return s.rank is not None and all([d is not None for d in s]) - effects = {call_tf_effect} if has_side_effects else set() - if all([is_fully_known_shape(s) - for s in concrete_function_flat_tf.output_shapes]): - return ( - tuple([ - # We convert to JAX type, and canonicalize to 32-bit if necessary - core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype)) - for dtype, shape in zip(concrete_function_flat_tf.output_dtypes, - concrete_function_flat_tf.output_shapes) - ]), - effects) - # There are some cases when TF shape inference is not powerful enough to - # figure out the output shapes (e.g., b/128924522), even in situations where - # XLA can compile the code, from which we can get the shapes. + if all(is_fully_known_shape(s) + for s in concrete_function_flat_tf.output_shapes): + avals_from_tf = tuple( + # We convert to JAX type, and canonicalize to 32-bit if necessary + core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype)) + for dtype, shape in zip(concrete_function_flat_tf.output_dtypes, + concrete_function_flat_tf.output_shapes)) + return avals_from_tf, effects - # We use the "cpu" as the platform, since JAX abstract eval is not platform - # specific; the "cpu" backend is always available and for abstract evaluation - # it should not matter which platform we use. - _, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, - "CPU") - return tuple(result_avals), effects + msg = ("call_tf cannot call functions whose output has dynamic shape. " + f"Found output shapes: {concrete_function_flat_tf.output_shapes}. " + "Consider using the `output_shape_dtype` argument to call_tf. " + "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + " for a discussion.") + raise ValueError(msg) call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval) @@ -372,6 +414,11 @@ def _code_generator_and_avals( ) -> Tuple[Optional[Callable[[mlir.ModuleContext, Sequence[ir.Value]], Sequence[ir.Value]]], Sequence[core.ShapedArray]]: + # TODO(necula): we have refactored the code to not need to lower the code + # just in order to get the avals, so in fact the returned avals from this + # function are never used. We keep it here for now in case we detect + # a regressions, but if not we should simplify this function. + # Returns and caches a code generator (taking a builder and the # XlaOps for the arguments) and a sequence of result abstract shapes. @@ -478,8 +525,7 @@ def _register_call_lowering(platform): for platform in ("cpu", "cuda", "tpu"): _register_call_lowering(platform) -# Support the call_tf under jax2tf.convert -TfVal = jax2tf_internal.TfVal +# Support the call_tf under jax2tf.convert in eager mode def _jax2tf_call_tf(*args: TfVal, callable_flat_tf: Callable, **_) -> TfVal: diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 7acd92831..e8c6f9b22 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -335,8 +335,7 @@ def convert(fun_jax: Callable, _thread_local_state.tf_outer_name_scope = tf.get_current_name_scope() # TODO: is there a better way to check if we are inside a transformation? - if not core.trace_state_clean( - ) and not _thread_local_state.inside_call_tf: + if not core.trace_state_clean() and not _thread_local_state.inside_call_tf: # It is Ok to nest convert when we are inside a call_tf raise ValueError( "convert must be used outside all JAX transformations." + diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 370e3e925..f4ecb8f89 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -44,6 +44,12 @@ def _maybe_jit(with_jit: bool, func: Callable) -> Callable: else: return func +def _maybe_tf_jit(with_jit: bool, func: Callable) -> Callable: + if with_jit: + return tf.function(func, autograph=False, jit_compile=True) + else: + return func + def _named_test(**kwargs): return dict(kwargs, testcase_name = "_".join([f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())])) @@ -53,8 +59,7 @@ _parameterized_jit = parameterized.named_parameters( for with_jit in [True, False]) _call_tf_non_compileable_error = "Error compiling TensorFlow function. call_tf can used in a staged context .* only with compileable functions" -_call_tf_dynamic_shape_error = "Compiled TensorFlow function has dynamic output shape.* call_tf can used in a staged context .* only with compileable functions" - +_call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape" class CallTfTest(tf_test_util.JaxToTfTestCase): @@ -171,8 +176,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): x = np.array([True, False], dtype=np.bool_) self.assertAllClose(f_tf_non_compileable(x), f_jax(x)) # Works in eager mode - with self.assertRaisesRegex(ValueError, - _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): jax.jit(f_jax)(x) def test_error_bad_result_tensorarray(self): @@ -569,9 +573,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(x[0:x[1]], res1) # Now under jit, should fail because the function is not compileable - with self.assertRaisesRegex( - ValueError, "Compiled TensorFlow function has dynamic output shape" - ): + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): fun_jax = jax.jit(jax2tf.call_tf(fun_tf)) fun_jax(x) @@ -1099,29 +1101,143 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(expected, res1, check_dtypes=False) # Now under jit, should fail because the function is not compileable - with self.assertRaisesRegex(ValueError, - _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): fun_jax = jax.jit(jax2tf.call_tf(fun_tf)) fun_jax(x) # TODO(necula): this should work in op-by-op mode, but it fails because # jax2tf.convert does abstract evaluation. - with self.assertRaisesRegex(ValueError, - _call_tf_dynamic_shape_error): + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf)) fun_tf_rt(x) - def test_shape_polymorphism_error(self): - x = np.array([.7, .8], dtype=np.float32) + @_parameterized_jit + def test_shape_poly_static_output_shape(self, with_jit=True): + if config.jax2tf_default_experimental_native_lowering: + raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") + x = np.array([0.7, 0.8], dtype=np.float32) + def fun_tf(x): - return tf.math.sin(x) + return tf.math.reduce_sum(tf.math.sin(x)) fun_jax = jax2tf.call_tf(fun_tf) + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + self.assertAllClose(fun_tf(x), fun_tf_rt(x)) + + @_parameterized_jit + def test_shape_poly(self, with_jit=False): + if config.jax2tf_default_experimental_native_lowering: + raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") + x = np.array([7, 8, 9, 10], dtype=np.float32) + def fun_jax(x): + y = jax2tf.call_tf(tf.math.sin, + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x) + z = jnp.cos(y) + w = jax2tf.call_tf(lambda z: tf.concat([z, z], axis=0), + output_shape_dtype=jax.ShapeDtypeStruct((2 * z.shape[0],), z.dtype))(z) + assert w.shape[0] == 2 * x.shape[0] + return w + + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + res_tf = fun_tf_rt(x) + self.assertAllClose(fun_jax(x), res_tf) + + @_parameterized_jit + def test_shape_poly_pytree_result(self, with_jit=True): + if config.jax2tf_default_experimental_native_lowering: + raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.") + x = np.array([7, 8, 9, 10], dtype=np.float32) + def fun_jax(x): + # Returns a tuple + y = jax2tf.call_tf(lambda x: (x, tf.concat([x, x], axis=0)), + output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x) + assert y[0].shape[0] == x.shape[0] + assert y[1].shape[0] == 2 * x.shape[0] + return y + + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + res_tf = fun_tf_rt(x) + self.assertAllClose(fun_jax(x), res_tf) + + @_parameterized_jit + def test_shape_poly_error_no_output_shape_dtype(self, with_jit=True): + x = np.array([7, 8, 9, 10], dtype=np.float32) + def fun_jax(x): + return jax2tf.call_tf(tf.math.sin)(x) + + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error): + fun_tf_rt(x) + + @_parameterized_jit + def test_shape_poly_error_mismatch_output_shape_dtype_tree(self, with_jit=False): + x = np.array([7, 8, 9, 10], dtype=np.float32) + def fun_jax(x): + return jax2tf.call_tf(tf.math.sin, + output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape, x.dtype)))(x) + + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) - fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]) with self.assertRaisesRegex( - ValueError, "call_tf cannot be applied to shape-polymorphic arguments" - ): + ValueError, + "The pytree of the TensorFlow function results does not match the pytree of the declared output_shape_dtype"): + fun_tf_rt(x) + + @parameterized.named_parameters( + _named_test(with_jit=with_jit, kind=kind) + for with_jit in [True, False] + for kind in ["bad_rank", "bad_dim", "bad_dtype", "bad_dtype_x64"]) + def test_shape_poly_error_mismatch_output_shape_dtype(self, with_jit=False, kind="bad_rank"): + x = np.array([7, 8, 9, 10], dtype=np.float32) + + if kind == "bad_rank": + def fun_jax(x): + return jax2tf.call_tf(lambda x: x, + # Wrong shape rank + output_shape_dtype=jax.ShapeDtypeStruct((), x.dtype))(x) + elif kind == "bad_dim": + def fun_jax(x): + bad_shape = (5 + x.shape[0],) + y = jax2tf.call_tf(lambda x: x, + # Wrong dimension + output_shape_dtype=jax.ShapeDtypeStruct(bad_shape, x.dtype))(x) + # JAX will believe that the following is Ok, leading to downstream error in TF + return y + jnp.ones(bad_shape, dtype=x.dtype) + elif kind == "bad_dtype": + def fun_jax(x): + return jax2tf.call_tf(lambda x: x, + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.int32))(x) + elif kind == "bad_dtype_x64": + def fun_jax(x): + return jax2tf.call_tf(lambda x: x * np.float64(3.), + output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.float64))(x) + else: + assert False + expect_ex = ValueError + expect_error = r"The shapes or dtypes returned by the TensorFlow function do not match the declared output_shape_dtype" + + # Call without shape polymorphism + fun_tf_rt = _maybe_tf_jit(with_jit, jax2tf.convert(fun_jax)) + with self.assertRaisesRegex(expect_ex, expect_error): + fun_tf_rt(x) + + # Now with shape polymorphism + if kind == "bad_dim" and with_jit: + # TODO: in jit more the error pops up later, at AddV2 + expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2" + if kind == "bad_dim" and config.jax2tf_default_experimental_native_lowering: + # TODO(b/268386622): call_tf with shape polymorphism and native lowering. + expect_error = "Error compiling TensorFlow function. call_tf can used .* only with compileable functions with static output shapes" + fun_tf_rt = _maybe_tf_jit(with_jit, + jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])) + with self.assertRaisesRegex(expect_ex, expect_error): fun_tf_rt(x) def test_inner_native_lowering(self): @@ -1142,22 +1258,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): self.assertIn('op: "XlaCallModule"', f_outer_graph) self.assertNotIn('op: "Sin"', f_outer_graph) - @_parameterized_jit - def test_shape_polymorphism_static_output_shape(self, with_jit=True): - # TODO(b/268386622) Dynamic shapes not yet supported. - if config.jax2tf_default_experimental_native_lowering: - raise unittest.SkipTest("Skip test because of dynamic shapes.") - x = np.array([0.7, 0.8], dtype=np.float32) - - def fun_tf(x): - return tf.math.reduce_sum(tf.math.sin(x)) - - fun_jax = jax2tf.call_tf(fun_tf) - fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]) - if with_jit: - fun_tf_rt = tf.function(jit_compile=True, autograph=False)(fun_tf_rt) - self.assertAllClose(fun_tf(x), fun_tf_rt(x)) - @parameterized.named_parameters( _named_test(f2_function=f2_function, f2_saved_model=f2_saved_model, f4_function=f4_function, f4_saved_model=f4_saved_model)