[shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program

This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.

The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into a TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the called TF function returns results with statically known shapes.
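
For context, a minimal sketch (illustrative only, not taken from the diffs below) of the kind of program that previously failed:

```python
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def fun_jax(x):
  # The output shape of the called TF function depends on the input shape.
  return jax2tf.call_tf(tf.math.sin)(x)

# Previously, tracing with a polymorphic batch dimension failed because
# call_tf could not determine the output shape of the TF function.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(
    np.array([0.7, 0.8], dtype=np.float32))
```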

The idea is that we allow the user of call_tf to specify the output shape,
even in the presence of shape polymorphism, by writing the output shape as an
expression in terms of the input shapes. This is what other JAX primitives do,
e.g., `concat`, so we are simply enabling call_tf to get the same behavior.
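
A sketch of the intended new usage, assuming the caller passes the new `output_shape_dtype` parameter introduced by this change (the concrete function and data are illustrative):

```python
import jax
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def fun_jax(x):
  # The output shape is written as an expression of the (possibly symbolic)
  # input shape, just like other shape-polymorphic JAX primitives do.
  out_spec = jax.ShapeDtypeStruct((2 * x.shape[0],) + tuple(x.shape[1:]), x.dtype)
  return jax2tf.call_tf(lambda t: tf.concat([t, t], axis=0),
                        output_shape_dtype=out_spec)(x)

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(
    np.array([7., 8., 9., 10.], dtype=np.float32))
```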

This change should be enough for old-style jax2tf, but will require more
work for native serialization.

We also removed some old code that was trying to work around some limitations
of shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.

PiperOrigin-RevId: 515137407
George Necula 2023-03-08 14:09:35 -08:00 committed by jax authors
parent 942e79ffe3
commit 961e09e614
5 changed files with 305 additions and 129 deletions


@@ -11,13 +11,17 @@ Remember to align the itemized text with the first line of an item within a list
* Changes
  * `jax.tree_util` now contains a set of APIs that allow users to define keys for their
    custom pytree nodes. This includes:
    * `tree_flatten_with_path` that flattens a tree and returns not only each leaf but
      also their key paths.
    * `tree_map_with_paths` that can map a function that takes the key path as an argument.
    * `register_pytree_with_keys` to register how the key path and leaves should look
      in a custom pytree node.
    * `keystr` that pretty-prints a key path.
  * {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`)
    that can be used to declare the output shape and type of the result. This enables
    {func}`jax2tf.call_tf` to work in the presence of shape polymorphism. ({jax-issue}`#14734`).
* Deprecations
  * The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months
    from Mar 10 2023:


@@ -1229,6 +1229,57 @@ DeviceArray data or for np.ndarray that are aligned on 16-byte
boundaries) and on GPU (for DeviceArray).
The zero-copy does not yet work on TPU.
`call_tf` works even with shape polymorphism, but in that case
the user must pass the `output_shape_dtype` parameter to `call_tf` to declare
the expected output shapes. This allows JAX tracing to know the shape and
dtype of the results so that it can continue tracing the rest of the program.
When `output_shape_dtype` is not given (the default case), `call_tf` will
form a `tf.Graph` for the called TF function and will use the inferred
type and shape. However, in the presence of dynamic shapes the inferred TF
type will contain `None` for the dynamic dimensions, which is not enough
information for JAX shape polymorphism.
For example:
```python
def fun_jax(x):
  y_shape = (x.shape[0] * 2, *x.shape[1:])
  y = jax2tf.call_tf(
      lambda x: tf.concat([x, x], axis=0),
      output_shape_dtype=jax.ShapeDtypeStruct(y_shape, x.dtype))(x)
  # JAX will know y.shape
  return jnp.ones(y.shape, dtype=y.dtype) + y

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
An even simpler example for a function that returns the same shape as the input:
```python
def fun_jax(x):
  return jax2tf.call_tf(tf.math.sin,
                        output_shape_dtype=x)(x)

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
If all the output shapes of the TF function are static, JAX does not need the
`output_shape_dtype` argument:
```python
def fun_tf(x):
  return tf.math.reduce_sum(tf.math.sin(x))

def fun_jax(x):
  return jax2tf.call_tf(fun_tf)(x)

# The following will not throw an error because the output shape of fun_tf is static.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
The shape polymorphism support for `call_tf` does not yet work for native lowering.
### Limitations of call_tf

The TF function must be compileable (`tf.function(func, jit_compile=True)`)
@@ -1312,38 +1363,14 @@ JAX computation runs on TPU. This will fail if the computation captures
variables on some other devices. It is best to use ``call_tf``
with TF functions that do not capture variables.
A TF function wrapped with `call_tf` cannot be applied to inputs whose
shapes are not constants, unless all the output shapes of the TF function
are static. This may arise when you try to apply `jax2tf.convert` with
polymorphic shapes on the result of `call_tf`:

```python
def fun_jax(x):
  return jax2tf.call_tf(tf.math.sin)(x)

# The following will throw an error.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```

This is unsatisfying, because the result of the above conversion
could be simply `tf.math.sin`, which is batch polymorphic. But
JAX cannot keep track of shapes through a `call_tf` call, and it
cannot be sure that the shape-polymorphic conversion is safe.
If all the output shapes of the TF function are static, JAX does not need to
keep track of shapes after a `call_tf` call, hence it allows shape-polymorphic
inputs in such cases, as in the example shown earlier.

In some rare cases your called TF function may contain ops with outputs
of statically known shape, but for which shape inference is not implemented
completely; these will appear to `call_tf` as if they have dynamically-shaped
outputs. In these cases you may get the error
`call_tf cannot call functions whose output has dynamic shape`. Try using
the `output_shape_dtype` parameter to specify the expected output shape
(this essentially allows you to override the shape inference for the
purposes of `call_tf`).
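
For illustration only (a hedged sketch, not part of the README diff), the same `output_shape_dtype` override pattern; here the missing shape information comes from a polymorphic input dimension rather than from an op with incomplete shape inference:

```python
import jax
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def fun_jax(x):
  # Without output_shape_dtype, conversion with a polymorphic batch dimension
  # fails with "call_tf cannot call functions whose output has dynamic shape";
  # the declared shape supplies what call_tf could not infer from TF.
  return jax2tf.call_tf(
      tf.math.sin,
      output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(
    np.array([0.7, 0.8], dtype=np.float32))
```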
# Misc notes


@@ -55,13 +55,15 @@ map = util.safe_map
zip = util.safe_zip

TfConcreteFunction = Any
TfVal = jax2tf_internal.TfVal

# The platforms for which to use DLPack to avoid copying (only works on GPU
# and CPU at the moment, and only for DeviceArray). For CPU we don't need
# DLPack, if we are careful.
_DLPACK_PLATFORMS = ("gpu",)
def call_tf(callable_tf: Callable, has_side_effects=True,
            output_shape_dtype=None) -> Callable:
"""Calls a TensorFlow function from JAX, with support for reverse autodiff.
The ``callable_tf`` will be called with TensorFlow-compatible arguments (
@@ -90,6 +92,14 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
has_side_effects: if True then it ensures that instances of this primitive
  are not removed or replicated by JAX optimizations such as dead-code
  elimination.
output_shape_dtype: An optional declaration of the expected shapes and dtypes
from the called TensorFlow function. If given it will be used during JAX
tracing to form the abstract values of the results of the `call_tf`. If
not given then we form a `tf.Graph` for the called TensorFlow function and
we use the TensorFlow-inferred shapes and types. Must be a pytree matching the
structure of the nested structure returned from the TensorFlow function,
containing objects with `.shape` and `.dtype` attributes,
e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
Returns: a JAX callable that can be invoked with JAX pytree arguments, in
  op-by-op mode or in a staged context. This callable can be used with
@@ -113,67 +123,58 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
def make_tensorspec(a_jax):
  a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
  a_tf_shape = [
      d if core.is_constant_dim(d) else None for d in a_jax.shape]
  return tf.TensorSpec(a_tf_shape, a_tf_dtype)
args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

if output_shape_dtype is not None:
  output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
  output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
else:
  output_avals, output_shape_dtype_tree = None, None

res_treedef = None  # We'll store here the result treedef
res_tf_flat = None  # For error reporting
# The function below will be called at least once, either in eager
# mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
  args_tf = args_treedef.unflatten(args_tf_flat)
  res_tf = callable_tf(*args_tf)
  nonlocal res_treedef, res_tf_flat
  res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
  assert res_treedef is None or res_treedef == res_treedef_now, (
      f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
  res_treedef = res_treedef_now
  if output_avals is not None:
    if res_treedef != output_shape_dtype_tree:
      raise ValueError(
          "The pytree of the TensorFlow function results does not match the "
          "pytree of the declared output_shape_dtype:\n"
          f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
    assert len(output_avals) == len(res_tf_flat)
  checked_res_tf_flat = [
      check_tf_result(i, r_tf, r_aval)
      for i, (r_tf, r_aval) in enumerate(
          zip(res_tf_flat,
              (output_avals if output_avals is not None
               else (None,) * len(res_tf_flat))))]
  return checked_res_tf_flat

# Prepare a tf.function ahead of time, to cache the concrete functions. This
# won't be used in op-by-op execution mode.
function_flat_tf = tf.function(callable_flat_tf, autograph=False, jit_compile=True)
input_shapes_tf = [s.shape for s in args_flat_sig_tf]
output_shapes_tf = _get_concrete_function_tf(
function_flat_tf, args_flat_sig_tf
).output_shapes
if not all(s.is_fully_defined() for s in input_shapes_tf) and not all(
s.is_fully_defined() for s in output_shapes_tf
):
for a_jax, a_tf_shape in zip(args_flat_jax, input_shapes_tf):
if not a_tf_shape.is_fully_defined():
msg = (
"call_tf cannot be applied to shape-polymorphic arguments unless"
" all the output shapes are static. Found argument shape:"
f" {a_jax.shape}. See"
" https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
" for a discussion."
)
raise ValueError(msg)
res_jax_flat = call_tf_p.bind(
    *args_flat_jax,
    # Carry the actual function such that op-by-op call can call in TF eager mode.
    callable_flat_tf=callable_flat_tf,
    function_flat_tf=function_flat_tf,
    args_flat_sig_tf=args_flat_sig_tf,
    output_avals=output_avals,
    has_side_effects=has_side_effects)
# We must have called callable_flat_tf by now
assert res_treedef is not None
# Sometimes, in compiled mode, we get a different number of results than we
# got when tracing the TF function (and building the res_treedef). This
@@ -248,6 +249,44 @@ def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
return util.wraps(callable_tf)(make_call)
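
For illustration only (a sketch, not part of call_tf.py), the pytree form of `output_shape_dtype` described in the docstring above; it mirrors the new `test_shape_poly_pytree_result` test added further below:

```python
import jax
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def fun_jax(x):
  # The TF function returns a tuple, so output_shape_dtype is a matching tuple.
  return jax2tf.call_tf(
      lambda x: (x, tf.concat([x, x], axis=0)),
      output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
                          jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x)

jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(
    np.array([7., 8., 9., 10.], dtype=np.float32))
```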
def check_tf_result(idx: int, r_tf: TfVal, r_aval: Optional[core.ShapedArray]) -> TfVal:
# Check that the TF function returns values of expected types. This
# improves error reporting, preventing hard-to-diagnose errors downstream
try:
jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
except Exception as e:
msg = ("The called TF function returns a result that is not "
f"convertible to JAX: {r_tf}.")
raise ValueError(msg) from e
if r_aval is None:
return r_tf
# We convert to TF type, and canonicalize to 32-bit if necessary
r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
# Checking shapes is trickier in presence of dynamic shapes. I wish we could
# check at runtime that the returned shape matches the declared shape. I wish
# that tf.ensure_shape did this, but it can only take shapes that contain None
# not computed shapes. However, in eager mode we should be able to resolve
# the declared shapes to constants and we get better checking.
if tf.executing_eagerly():
r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
else:
r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
# We do as much checking as we can here, instead of relying on tf.ensure_shape
# because the latter gives different errors in eager vs. compiled mode.
if (r_tf.dtype != r_aval_dtype_tf or
len(r_tf.shape) != len(r_aval_shape_tf) or
any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
msg = ("The shapes or dtypes returned by the TensorFlow function "
"do not match the declared output_shape_dtype:\n"
f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
raise ValueError(msg)
# At this point tf.ensure_shape does not do much, it should never throw an
# error, albeit it may refine the shape a bit.
return tf.ensure_shape(r_tf, r_aval_shape_tf)
call_tf_p = core.Primitive("call_tf")
call_tf_p.multiple_results = True

@@ -309,39 +348,42 @@ effects.remat_allowed_effects.add_type(CallTfEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
def _call_tf_abstract_eval(*args_flat_avals,
                           function_flat_tf,
                           args_flat_sig_tf,
                           has_side_effects,
                           output_avals, **__):
  # Called only when we form a Jaxpr, i.e., under jit, scan, etc.
  effects = {call_tf_effect} if has_side_effects else set()

  # If no output_avals is given, then we ask TF to infer the output shapes.
  # We call this even if output_avals is given because it will ensure that
  # callable_flat_tf is called. Since _get_concrete_function_tf is cached
  # there is a small cost of calling it more often than needed.
  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
                                                        args_flat_sig_tf)
  if output_avals is not None:
    return output_avals, effects

  def is_fully_known_shape(s):
    return s.rank is not None and all([d is not None for d in s])

  if all(is_fully_known_shape(s)
         for s in concrete_function_flat_tf.output_shapes):
    avals_from_tf = tuple(
        # We convert to JAX type, and canonicalize to 32-bit if necessary
        core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
        for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                                concrete_function_flat_tf.output_shapes))
    return avals_from_tf, effects

  msg = ("call_tf cannot call functions whose output has dynamic shape. "
         f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
         "Consider using the `output_shape_dtype` argument to call_tf. "
         "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
         " for a discussion.")
  raise ValueError(msg)

call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
@@ -372,6 +414,11 @@ def _code_generator_and_avals(
) -> Tuple[Optional[Callable[[mlir.ModuleContext, Sequence[ir.Value]],
                             Sequence[ir.Value]]],
           Sequence[core.ShapedArray]]:
# TODO(necula): we have refactored the code to not need to lower the code
# just in order to get the avals, so in fact the returned avals from this
# function are never used. We keep it here for now in case we detect
# a regression, but if not we should simplify this function.
# Returns and caches a code generator (taking a builder and the
# XlaOps for the arguments) and a sequence of result abstract shapes.
@@ -478,8 +525,7 @@ def _register_call_lowering(platform):
for platform in ("cpu", "cuda", "tpu"):
  _register_call_lowering(platform)

# Support the call_tf under jax2tf.convert in eager mode
def _jax2tf_call_tf(*args: TfVal,
                    callable_flat_tf: Callable,
                    **_) -> TfVal:


@@ -335,8 +335,7 @@ def convert(fun_jax: Callable,
_thread_local_state.tf_outer_name_scope = tf.get_current_name_scope()
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
  # It is Ok to nest convert when we are inside a call_tf
  raise ValueError(
      "convert must be used outside all JAX transformations." +


@@ -44,6 +44,12 @@ def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
else:
  return func

def _maybe_tf_jit(with_jit: bool, func: Callable) -> Callable:
  if with_jit:
    return tf.function(func, autograph=False, jit_compile=True)
  else:
    return func
def _named_test(**kwargs):
  return dict(kwargs,
              testcase_name = "_".join([f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())]))
@@ -53,8 +59,7 @@ _parameterized_jit = parameterized.named_parameters(
    for with_jit in [True, False])

_call_tf_non_compileable_error = "Error compiling TensorFlow function. call_tf can used in a staged context .* only with compileable functions"
_call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape"
class CallTfTest(tf_test_util.JaxToTfTestCase):

@@ -171,8 +176,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
x = np.array([True, False], dtype=np.bool_)
self.assertAllClose(f_tf_non_compileable(x), f_jax(x))  # Works in eager mode
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
  jax.jit(f_jax)(x)

def test_error_bad_result_tensorarray(self):
@@ -569,9 +573,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(x[0:x[1]], res1)
# Now under jit, should fail because the function is not compileable
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
  fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
  fun_jax(x)
@@ -1099,29 +1101,143 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(expected, res1, check_dtypes=False)
# Now under jit, should fail because the function is not compileable
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
  fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
  fun_jax(x)
# TODO(necula): this should work in op-by-op mode, but it fails because
# jax2tf.convert does abstract evaluation.
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
  fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
  fun_tf_rt(x)
@_parameterized_jit
def test_shape_poly_static_output_shape(self, with_jit=True):
  if config.jax2tf_default_experimental_native_lowering:
    raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.")
  x = np.array([0.7, 0.8], dtype=np.float32)
  def fun_tf(x):
    return tf.math.reduce_sum(tf.math.sin(x))
  fun_jax = jax2tf.call_tf(fun_tf)
  fun_tf_rt = _maybe_tf_jit(with_jit,
                            jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
  self.assertAllClose(fun_tf(x), fun_tf_rt(x))
@_parameterized_jit
def test_shape_poly(self, with_jit=False):
if config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.")
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
y = jax2tf.call_tf(tf.math.sin,
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
z = jnp.cos(y)
w = jax2tf.call_tf(lambda z: tf.concat([z, z], axis=0),
output_shape_dtype=jax.ShapeDtypeStruct((2 * z.shape[0],), z.dtype))(z)
assert w.shape[0] == 2 * x.shape[0]
return w
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
res_tf = fun_tf_rt(x)
self.assertAllClose(fun_jax(x), res_tf)
@_parameterized_jit
def test_shape_poly_pytree_result(self, with_jit=True):
if config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native lowering.")
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
# Returns a tuple
y = jax2tf.call_tf(lambda x: (x, tf.concat([x, x], axis=0)),
output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x)
assert y[0].shape[0] == x.shape[0]
assert y[1].shape[0] == 2 * x.shape[0]
return y
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
res_tf = fun_tf_rt(x)
self.assertAllClose(fun_jax(x), res_tf)
@_parameterized_jit
def test_shape_poly_error_no_output_shape_dtype(self, with_jit=True):
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
return jax2tf.call_tf(tf.math.sin)(x)
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
fun_tf_rt(x)
@_parameterized_jit
def test_shape_poly_error_mismatch_output_shape_dtype_tree(self, with_jit=False):
x = np.array([7, 8, 9, 10], dtype=np.float32)
def fun_jax(x):
return jax2tf.call_tf(tf.math.sin,
output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape, x.dtype)))(x)
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
with self.assertRaisesRegex(
    ValueError,
    "The pytree of the TensorFlow function results does not match the pytree of the declared output_shape_dtype"):
  fun_tf_rt(x)
@parameterized.named_parameters(
_named_test(with_jit=with_jit, kind=kind)
for with_jit in [True, False]
for kind in ["bad_rank", "bad_dim", "bad_dtype", "bad_dtype_x64"])
def test_shape_poly_error_mismatch_output_shape_dtype(self, with_jit=False, kind="bad_rank"):
x = np.array([7, 8, 9, 10], dtype=np.float32)
if kind == "bad_rank":
def fun_jax(x):
return jax2tf.call_tf(lambda x: x,
# Wrong shape rank
output_shape_dtype=jax.ShapeDtypeStruct((), x.dtype))(x)
elif kind == "bad_dim":
def fun_jax(x):
bad_shape = (5 + x.shape[0],)
y = jax2tf.call_tf(lambda x: x,
# Wrong dimension
output_shape_dtype=jax.ShapeDtypeStruct(bad_shape, x.dtype))(x)
# JAX will believe that the following is Ok, leading to downstream error in TF
return y + jnp.ones(bad_shape, dtype=x.dtype)
elif kind == "bad_dtype":
def fun_jax(x):
return jax2tf.call_tf(lambda x: x,
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.int32))(x)
elif kind == "bad_dtype_x64":
def fun_jax(x):
return jax2tf.call_tf(lambda x: x * np.float64(3.),
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.float64))(x)
else:
assert False
expect_ex = ValueError
expect_error = r"The shapes or dtypes returned by the TensorFlow function do not match the declared output_shape_dtype"
# Call without shape polymorphism
fun_tf_rt = _maybe_tf_jit(with_jit, jax2tf.convert(fun_jax))
with self.assertRaisesRegex(expect_ex, expect_error):
fun_tf_rt(x)
# Now with shape polymorphism
if kind == "bad_dim" and with_jit:
# TODO: in jit mode the error pops up later, at AddV2
expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2"
if kind == "bad_dim" and config.jax2tf_default_experimental_native_lowering:
# TODO(b/268386622): call_tf with shape polymorphism and native lowering.
expect_error = "Error compiling TensorFlow function. call_tf can used .* only with compileable functions with static output shapes"
fun_tf_rt = _maybe_tf_jit(with_jit,
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
with self.assertRaisesRegex(expect_ex, expect_error):
fun_tf_rt(x) fun_tf_rt(x)
def test_inner_native_lowering(self):
@@ -1142,22 +1258,6 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
self.assertIn('op: "XlaCallModule"', f_outer_graph)
self.assertNotIn('op: "Sin"', f_outer_graph)
@_parameterized_jit
def test_shape_polymorphism_static_output_shape(self, with_jit=True):
# TODO(b/268386622) Dynamic shapes not yet supported.
if config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("Skip test because of dynamic shapes.")
x = np.array([0.7, 0.8], dtype=np.float32)
def fun_tf(x):
return tf.math.reduce_sum(tf.math.sin(x))
fun_jax = jax2tf.call_tf(fun_tf)
fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])
if with_jit:
fun_tf_rt = tf.function(jit_compile=True, autograph=False)(fun_tf_rt)
self.assertAllClose(fun_tf(x), fun_tf_rt(x))
@parameterized.named_parameters(
    _named_test(f2_function=f2_function, f2_saved_model=f2_saved_model,
                f4_function=f4_function, f4_saved_model=f4_saved_model)