[jax2tf] Improved error checking for call_tf.

Cleaned up the abstract evaluation for call_tf to
work around a bug in TF whereby experimental_get_compiler_ir
cannot be used in a tf.function context.

Added more error messages, e.g., for the case when
the TF function has shape-influencing inputs.
George Necula 2021-07-10 18:49:25 +03:00
parent 36d06dbb61
commit a966157548
7 changed files with 653 additions and 257 deletions

View File

@ -906,30 +906,66 @@ custom gradients,
see the test `test_round_trip_custom_grad_saved_model`
in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py).
## Notes:
All the metadata inserted by TF during tracing and compilation, e.g.,
source location information and op names, is carried through to the
JAX XLA computation.
The TF custom gradients are respected, since it is TF that generates the
gradient computation.
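For example, a gradient defined with `tf.custom_gradient` on the TF side is the one
that `jax.grad` ends up using. A minimal sketch (the function and its deliberately
constant gradient are made up for illustration):
```python
import numpy as np
import tensorflow as tf
import jax
from jax.experimental import jax2tf

@tf.custom_gradient
def f_tf(x):
  def grad_fn(dy):
    return dy * 3.  # a custom (deliberately "wrong") gradient, to show it is the one used
  return tf.math.square(x), grad_fn

# jax.grad differentiates through call_tf using TF's custom gradient.
print(jax.grad(jax2tf.call_tf(f_tf))(np.float32(0.7)))  # ~3.0, not 2 * 0.7
```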
In op-by-op mode, when we call TensorFlow in eager mode, we use
DLPack to try to avoid copying the data. This works for CPU (for
DeviceArray data or for np.ndarray that are aligned on 16-byte
boundaries) and on GPU (for DeviceArray).
The zero-copy does not yet work on TPU.
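A minimal sketch of the op-by-op case (the `tf.math.sin` example and the shape are
arbitrary): outside of `jax.jit` the call below runs TF eagerly, and on CPU/GPU the
`DeviceArray` argument may be handed over via DLPack without a copy.
```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

x = jnp.ones((1024, 1024), dtype=np.float32)  # a DeviceArray
y = jax2tf.call_tf(tf.math.sin)(x)            # eager TF call; zero-copy where supported
```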
### Limitations of call_tf
The TF function must be compileable (`tf.function(func, jit_compile=True)`)
when used in a JAX staging context, e.g., `jax.jit`, `lax.scan`, `lax.cond`,
but not when used in JAX op-by-op mode. For example, the following
function uses string operations that are not supported by XLA:
```python
def f_tf_non_compileable(x):
  return tf.strings.length(tf.strings.format("Hello {}!", [x]))

f_jax = jax2tf.call_tf(f_tf_non_compileable)
# Works in op-by-op mode
f_jax(np.float32(42.))
# Fails in jit mode
jax.jit(f_jax)(np.float32(42.))
```
Another similar situation is when a function uses input values to determine
shapes. In this case TF actually does compile the function, but it
re-compiles it for each distinct value of the shape-influencing input. This
is not allowed when used from JAX:
```python
def f_tf_dynamic_shape(x):
  return x[x[0]:5]

x = np.array([1, 2], dtype=np.int32)
f_jax = jax2tf.call_tf(f_tf_dynamic_shape)
# Works in op-by-op mode
f_jax(x)
# Fails in jit mode
jax.jit(f_jax)(x)
```
``call_tf`` works best with pure TF functions, i.e., functions that do not
capture ``tf.Variable``s or tensors from the environment; all such
context should be passed in explicitly through arguments, and if variables
are modified, the resulting values should be passed out through results.
There is a best-effort mechanism that can handle variable capture and
variable updates, except in the case of a function that modifies
``tf.Variable``s and is used in a JAX jitted context. Calling the
``impure_func_tf`` below will give an error:
```python
var1 = tf.Variable(1.)
@ -940,7 +976,7 @@ in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/ja
jax.jit(jax2tf.call_tf(impure_func_tf))(tf.constant(2.)) # Fails in jit mode
```
The error can be avoided by passing the variable explicitly:
```python
def pure_func_tf(x, var1):
@ -948,25 +984,35 @@ in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/ja
return x + new_var1, new_var1
```
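A minimal usage sketch for this pattern (hypothetical; it assumes `var1`,
`pure_func_tf`, and the imports from the snippets above): the variable's value
is passed in explicitly and the update is applied outside the jitted call.
```python
x = np.float32(2.)
res, new_var1 = jax.jit(jax2tf.call_tf(pure_func_tf))(x, var1.read_value())
var1.assign(new_var1)  # apply the state change outside of JAX
```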
This use case is likely to be revisited.
A TF function wrapped with `call_tf` cannot be applied to inputs whose
shapes are not constants. This may arise when you try to apply
`jax2tf.convert` with polymorphic shapes on the result of
`call_tf`:
```python
def fun_jax(x):
  return jax2tf.call_tf(tf.math.sin)(x)

# The following will throw an error.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```
This is unsatisfying, because the result of the above conversion
could be simply `tf.math.sin`, which is batch polymorphic. But
JAX cannot keep track of shapes through a `call_tf` call, and it
cannot be sure that the shape-polymorphic conversion is safe.
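By contrast, the same round trip with concrete shapes works. A sketch, reusing
`fun_jax` from above:
```python
x = np.array([0.7, 0.8], dtype=np.float32)
# Without polymorphic_shapes all shapes are concrete at conversion time.
print(jax2tf.convert(fun_jax)(x))  # same as tf.math.sin(x)
```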
## TODO

* Ensure that there is no array copy through the host when running in eager
  mode (JAX op-by-op).

# Misc notes
## TensorFlow versions supported
The ``jax2tf.convert`` and ``call_tf`` functions require very recent versions of TensorFlow.
As of today, the tests are run using `tf_nightly==2.7.0.dev20210715`.
## Running on GPU
To run jax2tf on GPU, both jaxlib and TensorFlow must be installed with support
for CUDA. One must be mindful to install a version of CUDA that is compatible
with both [jaxlib](https://github.com/google/jax/blob/main/README.md#pip-installation) and

View File

@ -24,7 +24,7 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#callin
"""
import functools
import logging
from typing import Any, Callable, Sequence, Tuple
from typing import Any, Callable, Optional, Sequence, Tuple
import jax
from jax import core
@ -97,10 +97,16 @@ def call_tf(callable_tf: Callable) -> Callable:
return v
args_flat_jax = tuple(map(canonical_arg, args_flat_jax))
args_flat_sig_tf = tuple(
tf.TensorSpec(a_jax.shape, jax2tf_internal._to_tf_dtype(a_jax.dtype))
for a_jax in args_flat_jax
)
def make_tensorspec(a_jax):
a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
if any(not core.is_constant_dim(d) for d in a_jax.shape):
msg = ("call_tf cannot be applies to shape-polymorphic arguments. "
f"Found argument shape: {a_jax.shape}. "
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
raise ValueError(msg)
return tf.TensorSpec(a_jax.shape, a_tf_dtype)
args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))
res_treedef = None # We'll store here the result treedef
# The function below will be called at least once, either in eager
@ -116,7 +122,7 @@ def call_tf(callable_tf: Callable) -> Callable:
# Prepare a tf.function ahead of time, to cache the concrete functions. This
# won't be used in op-by-op execution mode.
function_flat_tf = tf.function(callable_flat_tf, jit_compile=True)
function_flat_tf = tf.function(callable_flat_tf, autograph=False, jit_compile=True)
res_jax_flat = call_tf_p.bind(
*args_flat_jax,
@ -222,29 +228,11 @@ call_tf_p.def_impl(_call_tf_impl)
def _call_tf_abstract_eval(*_,
function_flat_tf,
args_flat_sig_tf, **__):
# It seems that we cannot count on just TF shape inference to get the
# resulting shapes, because tf.function.get_concrete_function sometimes
# returns partially known shapes for a TF graph that was loaded with unknown
# shapes (e.g., b/128924522). So, we just compile
# the code and use the shapes that XLA has figured out. This is safe here
# because we only need to get an abstract value when we form a Jaxpr, which
# will eventually be lowered to XLA.
_, callee_xla_comp = _concrete_function_and_xla_comp(function_flat_tf, args_flat_sig_tf)
result_shape = callee_xla_comp.program_shape().result_shape()
if not result_shape.is_tuple():
# TF does not wrap singletons as tuples, but JAX expects tuples because
# call_tf is a multiple_results primitive.
result_shapes = (result_shape,)
else:
result_shapes = result_shape.tuple_shapes()
# Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
def res_shape_to_aval(res_shape: xla.XlaShape) -> core.AbstractValue:
return core.ShapedArray(res_shape.dimensions(),
dtypes.canonicalize_dtype(res_shape.numpy_dtype()))
return tuple(map(res_shape_to_aval, result_shapes))
# See comments in _code_generator_and_avals for why we go as far as a
# full compilation just to get the abstract avals.
_, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,
code_gen_optional=True)
return tuple(result_avals)
call_tf_p.def_abstract_eval(_call_tf_abstract_eval)
@ -253,19 +241,56 @@ def _call_tf_translation_rule(builder: xla.XlaComputationBuilder, *args_op,
function_flat_tf,
args_flat_sig_tf,
**_):
# This will most likely hit the cache, because use used it for abstract_eval
concrete_function_flat_tf, callee_xla_comp = _concrete_function_and_xla_comp(function_flat_tf, args_flat_sig_tf)
# This will most likely hit the cache, because we used it for abstract_eval
code_gen, _ = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf, # type: ignore
code_gen_optional=False)
assert code_gen is not None
return code_gen(builder, args_op)
captured_ops = [] # Same order as captured_inputs
@functools.lru_cache(maxsize=128)
def _code_generator_and_avals(
function_flat_tf,
args_flat_sig_tf,
code_gen_optional=False
) -> Tuple[Optional[Callable[[xla.XlaComputationBuilder, Sequence[xla.XlaOp]], xla.XlaOp]],
Sequence[core.ShapedArray]]:
# Returns and caches a code generator (taking a builder and the
# XlaOps for the arguments) and a sequence of result abstract shapes.
# It turns out that both for abstract evaluation and for actual compilation
# it is useful to actually generate the HLO. This is true because in some
# cases just TF-level shape inference is not precise enough to recover the
# output shapes (e.g., b/128924522), even in situations where XLA can compile
# the code, from which we can get the shapes.
# Due to bugs like b/193754660, the compilation may fail. To work around this
# issue we pass `code_gen_optional=True` when in an abstract evaluation context,
# in which case we fall back on TF shape inference.
# TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
# We know of one case when TF is sensitive to the values of the tensors that
# affect shapes in the computation. In those cases, however, those tensors
# are inlined in the computation, which we detect below.
args_tf_flat = [
tf.constant((0 if a.dtype != tf.bool else False),
shape=a.shape,
dtype=a.dtype) for a in args_flat_sig_tf]
# TODO(necula): For unoptimized HLO, does it make a difference which device we use?
tf_device_name = "/device:CPU:0"
with jax2tf_internal.inside_call_tf():
concrete_function_flat_tf = function_flat_tf.get_concrete_function(*args_tf_flat)
captured_inputs = []
if concrete_function_flat_tf.captured_inputs:
# The function uses either captured variables or tensors.
msg = (
"call_tf works best with a TensorFlow function that does not capture "
"variables or tensors from the context. "
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax for a discussion. "
f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
"call_tf works best with a TensorFlow function that does not capture "
"variables or tensors from the context. "
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion. "
f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
logging.warning(msg)
next_var_idx = 0
for inp in concrete_function_flat_tf.captured_inputs:
if inp.dtype == tf.resource: # A variable; assume the next variable
@ -273,75 +298,105 @@ def _call_tf_translation_rule(builder: xla.XlaComputationBuilder, *args_op,
# TODO(necula): better checking that we are picking the right variable
var = concrete_function_flat_tf.variables[next_var_idx]
next_var_idx += 1
inp_const = np.asarray(var)
captured_inputs.append(var)
else:
inp_const = np.asarray(inp)
captured_ops.append(xops.ConstantLiteral(builder, np.asarray(inp_const)))
captured_inputs.append(inp)
res_tf = xops.Call(builder, callee_xla_comp, args_op + tuple(captured_ops))
result_shape = callee_xla_comp.program_shape().result_shape()
if not result_shape.is_tuple():
# TF does not wrap singletons as tuples, but JAX expects tuples because
# call_tf is a multiple_results primitive.
res_untupled = (res_tf,)
else:
res_untupled = tuple(xops.GetTupleElement(res_tf, i) # type: ignore
for i in range(len(result_shape.tuple_shapes())))
# We may have to cast the results to x32 for JAX
def canonicalize_res(res):
res_dtype = builder.get_shape(res).numpy_dtype()
jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
if res_dtype != jax_res_dtype:
new_etype = xla_client.dtype_to_etype(jax_res_dtype)
return xops.ConvertElementType(res, new_element_type=new_etype)
else:
return res
canonical_res_untupled = tuple(map(canonicalize_res,
res_untupled))
return xops.Tuple(builder, canonical_res_untupled)
@functools.lru_cache(maxsize=128)
def _concrete_function_and_xla_comp(
function_flat_tf,
args_flat_sig_tf) -> Tuple[TfConcreteFunction,
xla_client.XlaComputation]:
# TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
args_tf_flat = [
tf.constant((0 if a.dtype != tf.bool else False),
shape=a.shape,
dtype=a.dtype) for a in args_flat_sig_tf
]
# TODO(necula): For unoptimized HLO, does it make a difference which device we use?
tf_device_name = "/device:CPU:0"
with jax2tf_internal.inside_call_tf():
try:
func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(*args_tf_flat)(
stage="hlo_serialized", device_name=tf_device_name)
except Exception as e:
msg = ("Error compiling TensorFlow function. call_tf can used " +
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
"compileable functions.")
raise ValueError(msg) from e
# The above has traced the function and in fact has cached a ConcreteFunction
# Grab it now, so that we don't have to construct `args_tf_flat` only to
# get a cache hit.
concrete_function_flat_tf = function_flat_tf.get_concrete_function(*args_tf_flat)
return concrete_function_flat_tf, xla_client.XlaComputation(func_tf_hlo)
try:
func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(*args_tf_flat)(
stage="hlo_serialized", device_name=tf_device_name)
except Exception as e:
if type(e) is TypeError and "An op outside of the function building code" in str(e):
# TODO(b/193754660): this may happen if we are in a function context
# Try to salvage the situation if we are just doing abstract_eval, maybe
# for jax2tf.convert. We can do that if all the output_shapes are known.
def is_fully_known_shape(s):
return s.rank is not None and all([d is not None for d in s])
if code_gen_optional and (
all([is_fully_known_shape(s)
for s in concrete_function_flat_tf.output_shapes])):
result_avals = [
# We convert to JAX type, and canonicalize to 32-bit if necessary
core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
concrete_function_flat_tf.output_shapes)]
return None, result_avals
msg = ("Error compiling TensorFlow function. call_tf can used " +
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
"compileable functions. " +
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
raise ValueError(msg) from e
xla_comp = xla_client.XlaComputation(func_tf_hlo)
# Check that the function does not have compile-time constant inputs that
# have been inlined in the compiled code.
xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
# Add the captured_inputs to args_flat_sig_tf
expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
expected_parameter_shapes = [
xla_client.Shape.array_shape(
xla_client.dtype_to_etype(arg_sig.dtype.as_numpy_dtype),
arg_sig.shape.as_list()).with_major_to_minor_layout_if_absent()
for arg_sig in expected_args_flat_sig_tf]
if xla_comp_parameter_shapes != expected_parameter_shapes:
msg = ("Compiled TensorFlow function has unexpected parameter types " +
f"{xla_comp_parameter_shapes}, while the expected types are " +
f"{expected_parameter_shapes}. Perhaps the TensorFlow function " +
"has shape-influencing inputs, and thus needs to be recompiled " +
"for each value of some inputs. " +
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
raise ValueError(msg)
# Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
res_dtype = res_shape.numpy_dtype()
jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)
result_shape = xla_comp.program_shape().result_shape()
if not result_shape.is_tuple():
# TF does not wrap singletons as tuples, but JAX expects tuples because
# call_tf is a multiple_results primitive.
result_shapes = (result_shape,)
else:
result_shapes = result_shape.tuple_shapes()
result_avals = tuple(map(canonical_res_aval, result_shapes)) # type: ignore
def code_gen(builder: xla.XlaComputationBuilder, args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
captured_ops = [xops.ConstantLiteral(builder, np.asarray(inp))
for inp in captured_inputs]
res_tf = xops.Call(builder, xla_comp, args_op + tuple(captured_ops)) # type: ignore
def post_process_result(idx: int, res_aval: core.ShapedArray, res_shape: xla.XlaShape):
res_op = res_tf
if result_shape.is_tuple():
res_op = xops.GetTupleElement(res_tf, idx)
if res_aval.dtype != res_shape.numpy_dtype():
res_op = xops.ConvertElementType(
res_op,
new_element_type=xla_client.dtype_to_etype(res_aval.dtype))
return res_op
results = [
post_process_result(i, res_aval, res_shape)
for i, (res_aval, res_shape) in enumerate(zip(result_avals,
result_shapes))]
return xops.Tuple(builder, results)
return code_gen, result_avals
xla.translations[call_tf_p] = _call_tf_translation_rule
TfVal = jax2tf_internal.TfVal
def _jax2tf_call_tf(*args: TfVal,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray,
callable_flat_tf: Callable,
**_) -> TfVal:
res_tf_flat = callable_flat_tf(*args)
return res_tf_flat
jax2tf_internal.tf_impl_with_avals[call_tf_p] = _jax2tf_call_tf
jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf

View File

@ -769,6 +769,7 @@ class TensorFlowTracer(core.Tracer):
if config.jax_enable_checks:
assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
assert len(self._aval.shape) == len(val_shape), f"_aval.shape={self._aval.shape} different rank than val_shape={val_shape}"
for aval_dim, val_dim in zip(
self._aval.shape, val_shape): # type: ignore[attr-defined]
if val_dim is None:
@ -853,6 +854,9 @@ class TensorFlowTrace(core.Trace):
params) -> TensorFlowTracer:
impl, impl_needs_avals = self.get_primitive_impl(primitive)
args_avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
# This is a bit conservative, doing abstract_eval even in op-by-op execution,
# but we need it for, e.g., shape polymorphism, where only JAX's
# abstract evaluation rules can properly track polymorphic shapes.
out_aval = primitive.abstract_eval(*args_avals, **params)
args_tf: Sequence[TfVal] = [t.val for t in tracers]
def invoke_impl() -> TfVal:
@ -2475,7 +2479,8 @@ def _slice(operand, start_indices, limit_indices, strides, _in_avals,
_eval_shape(strides)))
out = operand[slices]
# TODO(b/184503314): improve shape inference for __getitem__
out.set_shape(_aval_to_tf_shape(_out_aval))
#out.set_shape(_aval_to_tf_shape(_out_aval))
#assert False, f"start_indices={start_indices}, limit_indices={limit_indices}, strides={strides}, out={out}"
return out

View File

@ -76,7 +76,7 @@ class CallTfTest(jtu.JaxTestCase):
self.assertAllClose(4., res, check_dtypes=False)
@parameterized_jit
def test_eval_numpy_arg(self, with_jit=False):
def test_eval_numpy_arg(self, with_jit=True):
x = np.ones((2, 3), dtype=np.float32)
res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
self.assertAllClose(jnp.sin(x), res)
@ -207,21 +207,23 @@ class CallTfTest(jtu.JaxTestCase):
self.assertAllClose(
np.array([True, False, False, False], dtype=np.bool_), res)
def test_x64_input(self):
@parameterized_jit
def test_x64_input(self, with_jit=True):
def f_tf(x):
return tf.math.sin(x)
x = 5. # TF interprets this as f64
res_call_tf = jax2tf.call_tf(f_tf)(x)
res_call_tf = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
res_jax = jnp.sin(x)
self.assertAllClose(res_call_tf, res_jax)
def test_x64_output(self):
@parameterized_jit
def test_x64_output(self, with_jit=True):
def f_tf(x):
return (tf.constant(3., tf.float64), x)
x = np.float32(5.)
res_call_tf = jax2tf.call_tf(f_tf)(x)
res_call_tf = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
res_jax = (3., x)
self.assertAllClose(res_call_tf, res_jax)
@ -242,6 +244,20 @@ class CallTfTest(jtu.JaxTestCase):
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
@parameterized_jit
def test_with_var_read_x64(self, with_jit=True):
if jtu.device_under_test() == "gpu":
raise unittest.SkipTest("Test fails on GPU")
outer_var_array = np.array([3., 4.], dtype=np.float64)
outer_var = tf.Variable(outer_var_array)
def fun_tf(x):
return x * tf.cast(outer_var, x.dtype) + 1.
x = np.array([2., 5.,], dtype=np.float32)
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
def test_with_var_different_shape(self):
# See https://github.com/google/jax/issues/6050
if jtu.device_under_test() == "gpu":
@ -283,6 +299,28 @@ class CallTfTest(jtu.JaxTestCase):
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
@parameterized_jit
def test_with_tensor_capture_x64(self, with_jit=True):
outer_tensor = tf.constant(3., dtype=np.float64)
def fun_tf(x):
return x * tf.cast(outer_tensor * 3.14, tf.float32) + 1.
x = np.float32(2.)
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
self.assertAllClose(x * 3. * 3.14 + 1., res, check_dtypes=False)
@parameterized_jit
def test_with_value_capture(self, with_jit=True):
outer_val = np.array(3., dtype=np.float32)
def fun_tf(x):
return x * outer_val + 1.
x = np.float32(2.)
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
@parameterized_jit
def test_with_multiple_capture(self, with_jit=True):
if jtu.device_under_test() == "gpu":
@ -475,137 +513,108 @@ class CallTfTest(jtu.JaxTestCase):
res = jax.pmap(fun_jax)(x)
self.assertAllClose(np.float32(3. * (x + 2)), res)
def test_round_trip(self):
f_jax = jnp.sin
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
x = np.float32(0.7)
self.assertAllClose(f_jax(x), f_jax_rt(x))
def test_function_compile_time_constant_inputs(self):
# Call a function for which shape inference does not give an output
# shape.
x = np.array([1, 2, 3], dtype=np.int32)
def fun_tf(x): # x:i32[3]
# Indexing with a dynamic slice makes the TF shape inference return
# a partially known shape.
end_idx = x[1]
res = x[0:end_idx]
return res
def test_round_trip_pytree(self):
def f_jax(x): # x: dict(a=f32, b=f32)
return dict(a=x["a"]+1., b=x)
x = dict(a=0.7, b=0.8)
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
self.assertAllClose(f_jax(x), f_jax_rt(x))
# Call in eager mode. Should work!
res1 = jax2tf.call_tf(fun_tf)(x)
self.assertAllClose(x[0:x[1]], res1)
def test_round_trip_custom_grad(self):
@jax.custom_vjp
def f(x):
return x * x
# Now under jit, should fail because the function is not compileable
with self.assertRaisesRegex(ValueError,
"Compiled TensorFlow function has unexpected parameter types"):
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
fun_jax(x)
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), np.float32(3.) * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * ct_b,
def test_experimental_get_compiler_ir_design_doc(self):
# Not a test of call_tf, but more of how experimental_get_compiler_ir works.
# Examples are from the design doc.
f.defvjp(f_fwd, f_bwd)
# Constant slice. This is the common case.
x = np.zeros((10,), dtype=np.int32)
f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
x = np.float32(0.7)
self.assertAllClose(f(x), f_rt(x))
self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))
def fun_tf(x):
begin = 0
return x[begin:5] # x must be a compile-time constant
def test_round_trip_shape_poly(self):
f_jax = jnp.sin
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax,
polymorphic_shapes=["(b, ...)"]))
x = np.array([0.7, 0.8], dtype=np.float32)
self.assertAllClose(f_jax(x), f_jax_rt(x))
hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
self.assertIn("(arg0.1: s32[10]) -> s32[5]", hlo)
def test_round_trip_saved_model_shape_poly(self):
tracing_count = 0
def f_jax(x):
nonlocal tracing_count
tracing_count += 1
return jnp.sin(x)
# Non-constant slice, but compile-time constant depending only on values.
x = np.zeros((10,), dtype=np.int32)
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
x = np.array([0.7, 0.8], dtype=np.float32)
res_jax = f_jax(x)
self.assertEqual(1, tracing_count)
# Will trace twice, it seems. Once to get the result signature, and once again
# for the actual saving.
restored_f = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
self.assertGreaterEqual(tracing_count, 2)
tracing_count = 0
f_jax_rt = jax2tf.call_tf(restored_f)
self.assertAllClose(res_jax, f_jax_rt(x))
# Ensure that restored_f works at other batch size as well
y = np.concatenate([x, x])
self.assertEqual(0, tracing_count)
res_jax_y = f_jax(y)
self.assertEqual(1, tracing_count)
# No more tracing for f_jax_rt
self.assertAllClose(res_jax_y, f_jax_rt(y))
self.assertEqual(1, tracing_count)
def fun_tf(x):
begin = x[0]
return x[begin:5] # x must be a compile-time constant
def test_round_trip_custom_grad_saved_model(self):
@jax.custom_vjp
def f(x):
return x * x
hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
self.assertIn("() -> s32[5]", hlo)
x = np.ones((10,), dtype=np.int32)
hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
self.assertIn("() -> s32[4]", hlo)
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), np.float32(3.) * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * ct_b,
# Non-constant slice, but compile-time constant depending only on shapes.
x = np.zeros((10,), dtype=np.int32)
f.defvjp(f_fwd, f_bwd)
def g(x):
return jnp.sum(f(x))
def fun_tf(x):
begin = tf.shape(x)[0] - 2 # begin is a compile-time constant, even if x is not
return x[begin:]
g_tf = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(g, with_gradient=True, polymorphic_shapes=["b, ..."]),
[tf.TensorSpec([None], dtype=tf.float32)])
g_rt = jax2tf.call_tf(g_tf)
x = np.array([0.7], dtype=np.float32)
self.assertAllClose(g(x), g_rt(x))
self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))
hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
self.assertIn("(arg0.1: s32[10]) -> s32[2]", hlo)
def test_round_trip_without_gradient_saved_model(self):
# Explicitly with_gradient=False
f_jax = jnp.sum
# Capture a variable
outer_var = tf.Variable(np.array([3.], dtype=np.float32))
x = np.array([2., 3., 4.], dtype=np.float32)
x = np.array([0.7, 0.8], dtype=np.float32)
f_tf = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(f_jax, with_gradient=False),
[tf.TensorSpec(x.shape, dtype=x.dtype)])
f_rt = jax2tf.call_tf(f_tf)
def fun_tf(x):
return x * tf.broadcast_to(outer_var, x.shape) + 1.
self.assertAllClose(f_jax(x), f_rt(x))
with self.assertRaisesRegex(Exception,
"Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"):
jax.grad(f_rt)(x)
hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
self.assertIn("(arg0.1: f32[3], arg1.2: f32[1]) -> f32[3]", hlo)
def test_round_trip_saved_model_no_gradients(self):
# Save without gradients
f_jax = jnp.sum
# Capture a constant
outer_ct = np.array([3.], dtype=np.float32)
x = np.array([2., 3., 4.], dtype=np.float32)
x = np.array([0.7, 0.8], dtype=np.float32)
f_tf = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(f_jax, with_gradient=True),
[tf.TensorSpec(x.shape, dtype=x.dtype)],
save_gradients=False)
f_rt = jax2tf.call_tf(f_tf)
def fun_tf(x):
return x * tf.broadcast_to(outer_ct, x.shape) + 1.
self.assertAllClose(f_jax(x), f_rt(x))
# TODO: clean this up b/191117111: it should fail with a clear error
# The following results in a confusing error:
# TypeError: An op outside of the function building code is being passed
# a "Graph" tensor. It is possible to have Graph tensors
# leak out of the function building context by including a
# tf.init_scope in your function building code.
# For example, the following function will fail:
# @tf.function
# def has_init_scope():
# my_constant = tf.constant(1.)
# with tf.init_scope():
# added = my_constant * 2
# The graph tensor has name: args_0:0
# g = jax.grad(f_rt)(x)
hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
self.assertIn("(arg0.1: f32[3]) -> f32[3]", hlo)
# Call get_compiler_ir in a function context
x = np.array([2., 3., 4.], dtype=np.float32)
# TODO(b/193754660)
# def fun_tf_outer(x):
# x_const = tf.constant(0, shape=x.shape, dtype=x.dtype)
# _ = tf.function(tf.math.sin, jit_compile=True).experimental_get_compiler_ir(x_const)()
# with self.assertRaisesRegex(
# TypeError, "An op outside of the function building code is being passed"):
# tf.function(fun_tf_outer)(x)
#
# with self.assertRaisesRegex(
# TypeError, "An op outside of the function building code is being passed"):
# tf.function(fun_tf_outer, jit_compile=True)(x)
# Call get_concrete_function in a graph context
def fun_tf_outer(x):
_ = tf.function(tf.math.sin, jit_compile=True).get_concrete_function(tf.TensorSpec(x.shape, x.dtype))
return x
# Outside of a function context, this works.
_ = tf.function(fun_tf_outer)(x)
_ = tf.function(fun_tf_outer, jit_compile=True)(x)
def test_module_documentation(self):
def cos_tf(x):
@ -630,12 +639,289 @@ class CallTfTest(jtu.JaxTestCase):
print(jax.make_jaxpr(cos_tf_sin_jax)(x))
print(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())
def test_round_trip_reverse(self):
f_tf = tf.math.sin
f_tf_rt = jax2tf.convert(jax2tf.call_tf(f_tf))
x = np.float32(0.7)
self.assertAllClose(f_tf(x).numpy(), f_tf_rt(x).numpy())
class RoundTripToJaxTest(jtu.JaxTestCase):
"Reloading output of jax2tf into JAX with call_tf"
def setUp(self):
if tf is None:
raise unittest.SkipTest("Test requires tensorflow")
# TODO(b/171320191): this line works around a missing context initialization
# bug in TensorFlow.
_ = tf.add(1, 1)
super().setUp()
def test_simple(self):
f_jax = jnp.sin
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
x = np.float32(0.7)
self.assertAllClose(f_jax(x), f_jax_rt(x))
def test_pytree(self):
def f_jax(x): # x: dict(a=f32, b=f32)
return dict(a=x["a"]+1., b=x)
x = dict(a=0.7, b=0.8)
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
self.assertAllClose(f_jax(x), f_jax_rt(x))
def test_custom_grad(self):
@jax.custom_vjp
def f(x):
return x * x
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), np.float32(3.) * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * ct_b,
f.defvjp(f_fwd, f_bwd)
f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
x = np.float32(0.7)
self.assertAllClose(f(x), f_rt(x))
self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))
def test_shape_poly(self):
f_jax = jnp.sin
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax,
polymorphic_shapes=["(b, ...)"]))
x = np.array([0.7, 0.8], dtype=np.float32)
self.assertAllClose(f_jax(x), f_jax_rt(x))
def test_saved_model_simple(self):
def f_jax(x):
return jnp.sin(x)
x = np.array([0.7, 0.8], dtype=np.float32)
f_tf = jax2tf.convert(f_jax)
restored_tf, _ = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec(x.shape, x.dtype)])
restored_jax = jax2tf.call_tf(restored_tf)
self.assertAllClose(f_jax(x), restored_jax(x))
def test_saved_model_variables(self):
x = np.array([0.7, 0.8], dtype=np.float32)
param = np.array([1., 2.], dtype=np.float32)
def f_jax(param, x):
return jnp.sin(x) + jnp.cos(param)
param_v = tf.Variable(param)
f_tf = jax2tf.convert(f_jax)
_, restored_model = tf_test_util.SaveAndLoadFunction(
lambda x: f_tf(param_v, x),
[tf.TensorSpec(x.shape, x.dtype)],
variables=(param_v,))
restored_jax = jax2tf.call_tf(restored_model.f)
self.assertAllClose(f_jax(param, x), restored_jax(x))
self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x))
def test_saved_model_shape_poly(self):
tracing_count = 0
def f_jax(x):
nonlocal tracing_count
tracing_count += 1
return jnp.sin(x)
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
x = np.array([0.7, 0.8], dtype=np.float32)
res_jax = f_jax(x)
self.assertEqual(1, tracing_count)
# Will trace twice, it seems. Once to get the result signature, and once again
# for the actual saving.
restored_f, _ = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
self.assertGreaterEqual(tracing_count, 2)
tracing_count = 0
f_jax_rt = jax2tf.call_tf(restored_f)
self.assertAllClose(res_jax, f_jax_rt(x))
# Ensure that restored_f works at other batch size as well
y = np.concatenate([x, x])
self.assertEqual(0, tracing_count)
res_jax_y = f_jax(y)
self.assertEqual(1, tracing_count)
# No more tracing for f_jax_rt
self.assertAllClose(res_jax_y, f_jax_rt(y))
self.assertEqual(1, tracing_count)
def test_custom_grad_saved_model(self):
@jax.custom_vjp
def f(x):
return x * x
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), np.float32(3.) * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * ct_b,
f.defvjp(f_fwd, f_bwd)
def g(x):
return jnp.sum(f(x))
g_tf, _ = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(g, with_gradient=True, polymorphic_shapes=["b, ..."]),
[tf.TensorSpec([None], dtype=tf.float32)])
g_rt = jax2tf.call_tf(g_tf)
x = np.array([0.7], dtype=np.float32)
self.assertAllClose(g(x), g_rt(x))
self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))
def test_without_gradient_saved_model(self):
# Explicitly with_gradient=False
f_jax = jnp.sum
x = np.array([0.7, 0.8], dtype=np.float32)
f_tf, _ = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(f_jax, with_gradient=False),
[tf.TensorSpec(x.shape, dtype=x.dtype)])
f_rt = jax2tf.call_tf(f_tf)
self.assertAllClose(f_jax(x), f_rt(x))
with self.assertRaisesRegex(Exception,
"Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"):
jax.grad(f_rt)(x)
def test_saved_model_no_gradients(self):
# Save without gradients
f_jax = jnp.sum
x = np.array([0.7, 0.8], dtype=np.float32)
f_tf, _ = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(f_jax, with_gradient=True),
[tf.TensorSpec(x.shape, dtype=x.dtype)],
save_gradients=False)
f_rt = jax2tf.call_tf(f_tf)
self.assertAllClose(f_jax(x), f_rt(x))
# TODO: clean this up b/191117111: it should fail with a clear error
# The following results in a confusing error:
# TypeError: An op outside of the function building code is being passed
# a "Graph" tensor. It is possible to have Graph tensors
# leak out of the function building context by including a
# tf.init_scope in your function building code.
# For example, the following function will fail:
# @tf.function
# def has_init_scope():
# my_constant = tf.constant(1.)
# with tf.init_scope():
# added = my_constant * 2
# The graph tensor has name: args_0:0
# g = jax.grad(f_rt)(x)
class RoundTripToTfTest(jtu.JaxTestCase):
"Reloading output of call_tf into TF with jax2tf."
def setUp(self):
if tf is None:
raise unittest.SkipTest("Test requires tensorflow")
# TODO(b/171320191): this line works around a missing context initialization
# bug in TensorFlow.
_ = tf.add(1, 1)
super().setUp()
def test_alternate(self):
# Alternate sin/cos with sin in TF and cos in JAX
f_tf_inner = tf.math.sin
def f_jax(x_jax):
y_jax = jnp.cos(x_jax)
z_jax = jax2tf.call_tf(f_tf_inner)(y_jax)
return jnp.cos(z_jax)
def f_tf_outer(x_tf):
y_tf = tf.math.sin(x_tf)
z_tf = jax2tf.convert(f_jax)(y_tf)
return tf.math.sin(z_tf)
x = np.float32(0.7)
self.assertAllClose(np.sin(np.cos(np.sin(np.cos(np.sin(x))))),
f_tf_outer(x).numpy())
xv = tf.Variable(x)
with tf.GradientTape() as tape:
res = f_tf_outer(xv)
g_tf = tape.gradient(res, xv)
# Eager
expected_res = np.sin(np.cos(np.sin(np.cos(np.sin(x)))))
self.assertAllClose(expected_res, f_tf_outer(x).numpy())
# Gradient
expected_grad = (np.cos(np.cos(np.sin(np.cos(np.sin(x))))) *
np.sin(np.sin(np.cos(np.sin(x)))) *
np.cos(np.cos(np.sin(x))) *
np.sin(np.sin(x)) *
np.cos(x))
self.assertAllClose(expected_grad, g_tf.numpy())
# Graph
self.assertAllClose(expected_res,
tf.function(f_tf_outer, autograph=False)(x).numpy())
# Compiled
self.assertAllClose(expected_res,
tf.function(f_tf_outer, autograph=False,
jit_compile=True)(x).numpy())
def test_saved_model(self):
x = np.array([.7, .8], dtype=np.float32)
def fun_tf(x):
return tf.math.sin(x)
def fun_jax(x):
return jax2tf.call_tf(fun_tf)(x)
# Now convert and save to SavedModel
fun_tf_rt = jax2tf.convert(fun_jax)
res = fun_tf_rt(x)
self.assertAllClose(np.sin(x), res.numpy())
# TODO(b/193754660)
res = tf.function(fun_tf_rt)(x)
self.assertAllClose(np.sin(x), res.numpy())
res = tf.function(fun_tf_rt, jit_compile=True)(x)
self.assertAllClose(np.sin(x), res.numpy())
reloaded_f, _ = tf_test_util.SaveAndLoadFunction(
fun_tf_rt, input_signature=[tf.TensorSpec(x.shape, x.dtype)])
res = reloaded_f(x)
self.assertAllClose(np.sin(x), res.numpy())
def test_function_dynamic_shape(self):
# Call a function for which shape inference does not give an output
# shape.
x = np.array([-1, 0, 1], dtype=np.int32)
def fun_tf(x): # x:i32[3]
# The shape depends on the value of x
res = tf.where(x >= 0)
return res
# Call in eager mode. Should work!
res1 = jax2tf.call_tf(fun_tf)(x)
expected = np.array([[1], [2]])
self.assertAllClose(expected, res1, check_dtypes=False)
# Now under jit, should fail because the function is not compileable
with self.assertRaisesRegex(ValueError,
"Error compiling TensorFlow function. call_tf can used in a staged context"):
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
fun_jax(x)
# TODO(necula): this should work in op-by-op mode, but it fails because
# jax2tf.convert does abstract evaluation.
with self.assertRaisesRegex(ValueError,
"Error compiling TensorFlow function. call_tf can used in a staged context"):
fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
fun_tf_rt(x)
def test_shape_polymorphism_error(self):
x = np.array([.7, .8], dtype=np.float32)
def fun_tf(x):
return tf.math.sin(x)
fun_jax = jax2tf.call_tf(fun_tf)
fun_tf_rt = jax2tf.convert(fun_jax,
polymorphic_shapes=["b, ..."])
with self.assertRaisesRegex(
ValueError,
"call_tf cannot be applies to shape-polymorphic arguments"):
fun_tf_rt(x)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -161,7 +161,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
f_tf = jax2tf.convert(f_jax)
res = f_tf(*args)
input_signature = list(tf.TensorSpec(a.shape, a.dtype) for a in args)
restored_f = tf_test_util.SaveAndLoadFunction(f_tf, input_signature)
restored_f, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_signature)
res_restored = restored_f(*args)
self.assertAllClose(res, res_restored)
@ -208,7 +208,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
jnp.sin(np.array([3.14, 2.78], dtype=np.float16)))
# Save and restore SavedModel
restored_f = tf_test_util.SaveAndLoadFunction(composed_fn,
restored_f, _ = tf_test_util.SaveAndLoadFunction(composed_fn,
[tf.TensorSpec((2,), dtype=tf.string)])
res_tf_restored = restored_f(x_str)
self.assertAllClose(res_tf_restored.numpy(), res_tf.numpy())

View File

@ -721,7 +721,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
f_jax = jnp.sin
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
x = np.array([0.7, 0.8], dtype=np.float32)
restored_f = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
restored_f, _ = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
self.assertAllClose(f_jax(x), restored_f(x))
# Ensure that restored_f works at other batch size as well
y = np.concatenate([x, x])
@ -1369,7 +1369,8 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# test will be called "test_prim_xxx_...".
# If you want to run this test for only one harness that includes "foo"
# in the name, add parameter `one_containing="foo"` to parameterized below.
@primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES)
@primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES,
one_containing="dynamic_slice_enable_xla=True_poly_axes=[0]")
def test_prim(self, harness: Harness):
args = harness.dyn_args_maker(self.rng())
poly_axes = harness.params["poly_axes"] # type: Sequence[Sequence[int]]

View File

@ -17,7 +17,7 @@ import dataclasses
import logging
import os
from typing import Any, Callable, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence, Tuple
from absl.testing import absltest
import jax
@ -87,14 +87,17 @@ def SaveAndLoadModel(model: tf.Module,
def SaveAndLoadFunction(f_tf: Callable,
input_signature: Sequence[tf.TensorSpec],
save_gradients=True) -> Callable:
# Roundtrip through saved model on disk
model = tf.Module()
variables: Sequence[tf.Variable] = (),
save_gradients=True) -> Tuple[Callable, tf.train.Checkpoint]:
# Roundtrip through saved model on disk. Return the Checkpoint also
# for the cases when there are variables.
model = tf.train.Checkpoint()
model.f = tf.function(f_tf,
autograph=False,
input_signature=input_signature)
model.variables = variables
restored = SaveAndLoadModel(model, save_gradients=save_gradients)
return restored.f
return restored.f, restored
class JaxToTfTestCase(jtu.JaxTestCase):