[jax2tf] Improved error checking for call_tf.
Cleaned up the abstract evaluation for call_tf to work around a bug in TF whereby experimental_get_compiler_ir cannot be used in a tf.function context. Added more error messages, e.g., for the case when the TF function has shape-influencing inputs.
This commit is contained in:
parent 36d06dbb61
commit a966157548
@@ -906,30 +906,66 @@ custom gradients,
see the test `test_round_trip_custom_grad_saved_model`
in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py).

## Notes:

All the metadata inserted by TF during tracing and compilation, e.g.,
source location information and op names, is carried through to the
JAX XLA computation.

The TF custom gradients are respected, since it is TF that generates the
gradient computation.

In op-by-op mode, when we call TensorFlow in eager mode, we use
DLPack to try to avoid copying the data. This works for CPU (for
DeviceArray data or for np.ndarray that are aligned on 16-byte
boundaries) and on GPU (for DeviceArray).
The zero-copy does not yet work on TPU.
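To make the op-by-op path concrete, here is a minimal sketch (not taken from the README itself) of calling a TF function eagerly through `call_tf`; in this mode the TF function executes eagerly and, where supported, the argument data is shared via DLPack rather than copied:

```python
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

# Op-by-op (eager) call: no jax.jit involved, so the TF function runs eagerly.
sin_jax = jax2tf.call_tf(tf.math.sin)
x = np.float32(0.7)
print(sin_jax(x))  # approximately 0.644
```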
### Limitations of call_tf

The TF function must be compileable (`tf.function(func, jit_compile=True)`)
when used in a JAX staging context, e.g., `jax.jit`, `lax.scan`, `lax.cond`,
but not when used in JAX op-by-op mode. For example, the following
function uses string operations that are not supported by XLA:

```python
def f_tf_non_compileable(x):
  return tf.strings.length(tf.strings.format("Hello {}!", [x]))

f_jax = jax2tf.call_tf(f_tf_non_compileable)
# Works in op-by-op mode
f_jax(np.float32(42.))

# Fails in jit mode
jax.jit(f_jax)(np.float32(42.))
```

Another similar situation is when a function uses input values in
place of shapes. In this case TF actually does compile the function,
but re-compiles it for each distinct value of the input. This is
not allowed when used from JAX:

```python
def f_tf_dynamic_shape(x):
  return x[x[0]:5]
x = np.array([1, 2], dtype=np.int32)

f_jax = jax2tf.call_tf(f_tf_dynamic_shape)
# Works in op-by-op mode
f_jax(x)

# Fails in jit mode
jax.jit(f_jax)(x)
```

``call_tf`` works best with pure TF functions that do not capture
``tf.Variable``s or tensors from the environment; all such
context should be passed in explicitly through arguments, and if variables
are modified, the resulting values should be passed out through results.
There is a best-effort mechanism that can handle variable capture
and variable updates, except in the case of a function that modifies
``tf.Variable``s and is used in a JAX jitted context. Calling the
``impure_func_tf`` will give an error:

```python
var1 = tf.Variable(1.)
@@ -940,7 +976,7 @@ in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/ja
jax.jit(jax2tf.call_tf(impure_func_tf))(tf.constant(2.))  # Fails in jit mode
```

The error can be avoided by passing the variable explicitly:

```python
def pure_func_tf(x, var1):
@@ -948,25 +984,35 @@ in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/ja
  return x + new_var1, new_var1
```

This use case is likely to be revisited.
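For illustration, here is a small self-contained sketch (with hypothetical function and variable names, not taken from the README) of the recommended pattern: pass captured state in explicitly and return the updated values, instead of mutating a `tf.Variable` inside the TF function:

```python
import numpy as np
import tensorflow as tf
import jax
from jax.experimental import jax2tf

def step_tf(x, state):       # hypothetical names, for illustration only
  new_state = state + 1.0    # compute the updated value instead of writing a tf.Variable
  return x * new_state, new_state

step_jax = jax2tf.call_tf(step_tf)
x = np.float32(2.)
state = np.float32(3.)
# Works under jax.jit because the state is threaded explicitly through
# arguments and results.
y, state = jax.jit(step_jax)(x, state)
```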
A TF function wrapped with `call_tf` cannot be applied to inputs whose
shapes are not constants. This may arise when you try to apply
`jax2tf.convert` with polymorphic shapes on the result of
`call_tf`:

```python
def fun_jax(x):
  return jax2tf.call_tf(tf.math.sin)(x)

# The following will throw an error.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```

This is unsatisfying, because the result of the above conversion
could be simply `tf.math.sin`, which is batch polymorphic. But
JAX cannot keep track of shapes through a `call_tf` call, and it
cannot be sure that the shape-polymorphic conversion is safe.

# Misc notes

## TensorFlow versions supported

Both ``jax2tf.convert`` and `call_tf` require very recent versions of TensorFlow.
As of today, the tests are run using `tf_nightly==2.7.0.dev20210715`.

## Running on GPU

To run jax2tf on GPU, both jaxlib and TensorFlow must be installed with support
for CUDA. One must be mindful to install a version of CUDA that is compatible
with both [jaxlib](https://github.com/google/jax/blob/main/README.md#pip-installation) and
@@ -24,7 +24,7 @@ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#callin
"""
import functools
import logging
from typing import Any, Callable, Optional, Sequence, Tuple

import jax
from jax import core
@@ -97,10 +97,16 @@ def call_tf(callable_tf: Callable) -> Callable:
    return v

  args_flat_jax = tuple(map(canonical_arg, args_flat_jax))
  def make_tensorspec(a_jax):
    a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
    if any(not core.is_constant_dim(d) for d in a_jax.shape):
      msg = ("call_tf cannot be applied to shape-polymorphic arguments. "
             f"Found argument shape: {a_jax.shape}. "
             "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
      raise ValueError(msg)

    return tf.TensorSpec(a_jax.shape, a_tf_dtype)
  args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

  res_treedef = None  # We'll store here the result treedef
  # The function below will be called at least once, either in eager
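For context, a minimal sketch of the situation this new check rejects (mirroring the README's shape-polymorphism example; `fun_jax` is illustrative):

```python
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def fun_jax(x):
  return jax2tf.call_tf(tf.math.sin)(x)

x = np.array([0.7, 0.8], dtype=np.float32)
# With a polymorphic batch dimension the argument shape is not constant,
# so make_tensorspec raises the ValueError shown above.
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])(x)
```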
@@ -116,7 +122,7 @@ def call_tf(callable_tf: Callable) -> Callable:

  # Prepare a tf.function ahead of time, to cache the concrete functions. This
  # won't be used in op-by-op execution mode.
  function_flat_tf = tf.function(callable_flat_tf, autograph=False, jit_compile=True)

  res_jax_flat = call_tf_p.bind(
      *args_flat_jax,
@@ -222,29 +228,11 @@ call_tf_p.def_impl(_call_tf_impl)

def _call_tf_abstract_eval(*_,
                           function_flat_tf,
                           args_flat_sig_tf, **__):
  # See the comments in _code_generator_and_avals for why we go as far as a
  # full compilation only to get the abstract avals.
  _, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,
                                              code_gen_optional=True)
  return tuple(result_avals)

call_tf_p.def_abstract_eval(_call_tf_abstract_eval)
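As a quick way to observe the result of this abstract evaluation, one can form a jaxpr for a function that uses `call_tf` (a minimal sketch based on the module-documentation example that appears elsewhere in this change):

```python
import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

def cos_tf_sin_jax(x):
  # The shape/dtype of the call_tf result is supplied by _call_tf_abstract_eval.
  return jnp.sin(jax2tf.call_tf(tf.math.cos)(x))

print(jax.make_jaxpr(cos_tf_sin_jax)(np.float32(1.)))
```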
@@ -253,19 +241,56 @@ def _call_tf_translation_rule(builder: xla.XlaComputationBuilder, *args_op,
                              function_flat_tf,
                              args_flat_sig_tf,
                              **_):
  # This will most likely hit the cache, because we used it for abstract_eval
  code_gen, _ = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,  # type: ignore
                                          code_gen_optional=False)
  assert code_gen is not None
  return code_gen(builder, args_op)


@functools.lru_cache(maxsize=128)
def _code_generator_and_avals(
    function_flat_tf,
    args_flat_sig_tf,
    code_gen_optional=False
) -> Tuple[Optional[Callable[[xla.XlaComputationBuilder, Sequence[xla.XlaOp]], xla.XlaOp]],
           Sequence[core.ShapedArray]]:
  # Returns and caches a code generator (taking a builder and the
  # XlaOps for the arguments) and a sequence of result abstract shapes.

  # It turns out that both for abstract evaluation and for actual compilation
  # it is useful to actually generate the HLO. This is true because in some
  # cases just TF-level shape inference is not precise enough to recover the
  # output shapes (e.g., b/128924522), even in situations where XLA can compile
  # the code, from which we can get the shapes.

  # Due to bugs like b/193754660, the compilation may fail. To work around this
  # issue we pass `code_gen_optional` when in an abstract evaluation context,
  # in which case we fall back on TF shape inference.

  # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
  # We know of one case when TF is sensitive to the values of the tensors that
  # affect shapes in the computation. In those cases, however, those tensors
  # are inlined in the computation, which we detect below.
  args_tf_flat = [
      tf.constant((0 if a.dtype != tf.bool else False),
                  shape=a.shape,
                  dtype=a.dtype) for a in args_flat_sig_tf]

  # TODO(necula): For unoptimized HLO, does it make a difference which device we use?
  tf_device_name = "/device:CPU:0"
  with jax2tf_internal.inside_call_tf():
    concrete_function_flat_tf = function_flat_tf.get_concrete_function(*args_tf_flat)

  captured_inputs = []
  if concrete_function_flat_tf.captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion. "
        f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
    logging.warning(msg)

    next_var_idx = 0
    for inp in concrete_function_flat_tf.captured_inputs:
      if inp.dtype == tf.resource:  # A variable; assume the next variable
@@ -273,75 +298,105 @@ def _call_tf_translation_rule(builder: xla.XlaComputationBuilder, *args_op,
        # TODO(necula): better checking that we are picking the right variable
        var = concrete_function_flat_tf.variables[next_var_idx]
        next_var_idx += 1
        captured_inputs.append(var)
      else:
        captured_inputs.append(inp)

  try:
    func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(*args_tf_flat)(
        stage="hlo_serialized", device_name=tf_device_name)
  except Exception as e:
    if type(e) is TypeError and "An op outside of the function building code" in str(e):
      # TODO(b/193754660): this may happen if we are in a function context.
      # Try to salvage the situation if we are just doing abstract_eval, maybe
      # for jax2tf.convert. We can do that if all the output_shapes are known.
      def is_fully_known_shape(s):
        return s.rank is not None and all([d is not None for d in s])
      if code_gen_optional and (
          all([is_fully_known_shape(s)
               for s in concrete_function_flat_tf.output_shapes])):
        result_avals = [
            # We convert to the JAX type, and canonicalize to 32-bit if necessary
            core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
            for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                                    concrete_function_flat_tf.output_shapes)]
        return None, result_avals
    msg = ("Error compiling TensorFlow function. call_tf can be used " +
           "in a staged context (under jax.jit, lax.scan, etc.) only with " +
           "compileable functions. " +
           "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
    raise ValueError(msg) from e

  xla_comp = xla_client.XlaComputation(func_tf_hlo)
  # Check that the function does not have compile-time constant inputs that
  # have been inlined in the compiled code.
  xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
  # Add the captured_inputs to args_flat_sig_tf
  expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
  expected_parameter_shapes = [
      xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(arg_sig.dtype.as_numpy_dtype),
          arg_sig.shape.as_list()).with_major_to_minor_layout_if_absent()
      for arg_sig in expected_args_flat_sig_tf]
  if xla_comp_parameter_shapes != expected_parameter_shapes:
    msg = ("Compiled TensorFlow function has unexpected parameter types " +
           f"{xla_comp_parameter_shapes}, while the expected types are " +
           f"{expected_parameter_shapes}. Perhaps the TensorFlow function " +
           "has shape-influencing inputs, and thus needs to be recompiled " +
           "for each value of some inputs. " +
           "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
    raise ValueError(msg)

  # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
  def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
    res_dtype = res_shape.numpy_dtype()
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)

  result_shape = xla_comp.program_shape().result_shape()
  if not result_shape.is_tuple():
    # TF does not wrap singletons as tuples, but JAX expects tuples because
    # call_tf is a multiple_results primitive.
    result_shapes = (result_shape,)
  else:
    result_shapes = result_shape.tuple_shapes()

  result_avals = tuple(map(canonical_res_aval, result_shapes))  # type: ignore

  def code_gen(builder: xla.XlaComputationBuilder, args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
    captured_ops = [xops.ConstantLiteral(builder, np.asarray(inp))
                    for inp in captured_inputs]

    res_tf = xops.Call(builder, xla_comp, args_op + tuple(captured_ops))  # type: ignore
    def post_process_result(idx: int, res_aval: core.ShapedArray, res_shape: xla.XlaShape):
      res_op = res_tf
      if result_shape.is_tuple():
        res_op = xops.GetTupleElement(res_tf, idx)
      if res_aval.dtype != res_shape.numpy_dtype():
        res_op = xops.ConvertElementType(
            res_op,
            new_element_type=xla_client.dtype_to_etype(res_aval.dtype))
      return res_op

    results = [
        post_process_result(i, res_aval, res_shape)
        for i, (res_aval, res_shape) in enumerate(zip(result_avals,
                                                      result_shapes))]
    return xops.Tuple(builder, results)

  return code_gen, result_avals

xla.translations[call_tf_p] = _call_tf_translation_rule

TfVal = jax2tf_internal.TfVal
def _jax2tf_call_tf(*args: TfVal,
                    callable_flat_tf: Callable,
                    **_) -> TfVal:
  res_tf_flat = callable_flat_tf(*args)
  return res_tf_flat

jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf
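To see the new parameter-shape check fire, here is a condensed sketch of the scenario exercised by `test_function_compile_time_constant_inputs` further below:

```python
import numpy as np
import tensorflow as tf
import jax
from jax.experimental import jax2tf

def fun_tf(x):          # x: i32[3]
  return x[0:x[1]]      # the output shape depends on the *value* of x

x = np.array([1, 2, 3], dtype=np.int32)
jax2tf.call_tf(fun_tf)(x)            # eager (op-by-op) mode works
jax.jit(jax2tf.call_tf(fun_tf))(x)   # raises: unexpected parameter types
```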
@@ -769,6 +769,7 @@ class TensorFlowTracer(core.Tracer):

    if config.jax_enable_checks:
      assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
      assert len(self._aval.shape) == len(val_shape), f"_aval.shape={self._aval.shape} different rank than val_shape={val_shape}"
      for aval_dim, val_dim in zip(
          self._aval.shape, val_shape):  # type: ignore[attr-defined]
        if val_dim is None:
@@ -853,6 +854,9 @@ class TensorFlowTrace(core.Trace):
                        params) -> TensorFlowTracer:
    impl, impl_needs_avals = self.get_primitive_impl(primitive)
    args_avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
    # This is a bit conservative, doing abstract_eval even in op-by-op execution,
    # but we need it for, e.g., shape polymorphism, where only JAX's
    # abstract evaluation rules can properly track polymorphic shapes.
    out_aval = primitive.abstract_eval(*args_avals, **params)
    args_tf: Sequence[TfVal] = [t.val for t in tracers]
    def invoke_impl() -> TfVal:
@@ -2475,7 +2479,8 @@ def _slice(operand, start_indices, limit_indices, strides, _in_avals,
                                _eval_shape(strides)))
  out = operand[slices]
  # TODO(b/184503314): improve shape inference for __getitem__
  #out.set_shape(_aval_to_tf_shape(_out_aval))
  #assert False, f"start_indices={start_indices}, limit_indices={limit_indices}, strides={strides}, out={out}"
  return out
@@ -76,7 +76,7 @@ class CallTfTest(jtu.JaxTestCase):
    self.assertAllClose(4., res, check_dtypes=False)

  @parameterized_jit
  def test_eval_numpy_arg(self, with_jit=True):
    x = np.ones((2, 3), dtype=np.float32)
    res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
    self.assertAllClose(jnp.sin(x), res)
@@ -207,21 +207,23 @@ class CallTfTest(jtu.JaxTestCase):
    self.assertAllClose(
        np.array([True, False, False, False], dtype=np.bool_), res)

  @parameterized_jit
  def test_x64_input(self, with_jit=True):
    def f_tf(x):
      return tf.math.sin(x)

    x = 5.  # TF interprets this as f64
    res_call_tf = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
    res_jax = jnp.sin(x)
    self.assertAllClose(res_call_tf, res_jax)

  @parameterized_jit
  def test_x64_output(self, with_jit=True):
    def f_tf(x):
      return (tf.constant(3., tf.float64), x)

    x = np.float32(5.)
    res_call_tf = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
    res_jax = (3., x)
    self.assertAllClose(res_call_tf, res_jax)
@@ -242,6 +244,20 @@ class CallTfTest(jtu.JaxTestCase):
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)

  @parameterized_jit
  def test_with_var_read_x64(self, with_jit=True):
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    outer_var_array = np.array([3., 4.], dtype=np.float64)
    outer_var = tf.Variable(outer_var_array)

    def fun_tf(x):
      return x * tf.cast(outer_var, x.dtype) + 1.

    x = np.array([2., 5.], dtype=np.float32)
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)

  def test_with_var_different_shape(self):
    # See https://github.com/google/jax/issues/6050
    if jtu.device_under_test() == "gpu":
@@ -283,6 +299,28 @@ class CallTfTest(jtu.JaxTestCase):
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * 3. + 1., res, check_dtypes=False)

  @parameterized_jit
  def test_with_tensor_capture_x64(self, with_jit=True):
    outer_tensor = tf.constant(3., dtype=np.float64)

    def fun_tf(x):
      return x * tf.cast(outer_tensor * 3.14, tf.float32) + 1.

    x = np.float32(2.)
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * 3. * 3.14 + 1., res, check_dtypes=False)

  @parameterized_jit
  def test_with_value_capture(self, with_jit=True):
    outer_val = np.array(3., dtype=np.float32)

    def fun_tf(x):
      return x * outer_val + 1.

    x = np.float32(2.)
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * 3. + 1., res, check_dtypes=False)

  @parameterized_jit
  def test_with_multiple_capture(self, with_jit=True):
    if jtu.device_under_test() == "gpu":
@@ -475,137 +513,108 @@ class CallTfTest(jtu.JaxTestCase):
    res = jax.pmap(fun_jax)(x)
    self.assertAllClose(np.float32(3. * (x + 2)), res)

  def test_function_compile_time_constant_inputs(self):
    # Call a function for which shape inference does not give an output
    # shape.
    x = np.array([1, 2, 3], dtype=np.int32)
    def fun_tf(x):  # x: i32[3]
      # Indexing with a dynamic slice makes the TF shape inference return
      # a partially known shape.
      end_idx = x[1]
      res = x[0:end_idx]
      return res

    # Call in eager mode. Should work!
    res1 = jax2tf.call_tf(fun_tf)(x)
    self.assertAllClose(x[0:x[1]], res1)

    # Now under jit, should fail because the function is not compileable
    with self.assertRaisesRegex(ValueError,
                                "Compiled TensorFlow function has unexpected parameter types"):
      fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
      fun_jax(x)

  def test_experimental_get_compiler_ir_design_doc(self):
    # Not a test of call_tf, but more of how experimental_get_compiler_ir works.
    # Examples are from the design doc.

    # Constant slice. This is the common case.
    x = np.zeros((10,), dtype=np.int32)

    def fun_tf(x):
      begin = 0
      return x[begin:5]  # x must be a compile-time constant

    hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
    self.assertIn("(arg0.1: s32[10]) -> s32[5]", hlo)

    # Non-constant slice, but compile-time constant depending only on values.
    x = np.zeros((10,), dtype=np.int32)

    def fun_tf(x):
      begin = x[0]
      return x[begin:5]  # x must be a compile-time constant

    hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
    self.assertIn("() -> s32[5]", hlo)
    x = np.ones((10,), dtype=np.int32)
    hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
    self.assertIn("() -> s32[4]", hlo)

    # Non-constant slice, but compile-time constant depending only on shapes.
    x = np.zeros((10,), dtype=np.int32)

    def fun_tf(x):
      begin = tf.shape(x)[0] - 2  # begin is a compile-time constant, even if x is not
      return x[begin:]

    hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
    self.assertIn("(arg0.1: s32[10]) -> s32[2]", hlo)

    # Capture a variable
    outer_var = tf.Variable(np.array([3.], dtype=np.float32))
    x = np.array([2., 3., 4.], dtype=np.float32)

    def fun_tf(x):
      return x * tf.broadcast_to(outer_var, x.shape) + 1.

    hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
    self.assertIn("(arg0.1: f32[3], arg1.2: f32[1]) -> f32[3]", hlo)

    # Capture a constant
    outer_ct = np.array([3.], dtype=np.float32)
    x = np.array([2., 3., 4.], dtype=np.float32)

    def fun_tf(x):
      return x * tf.broadcast_to(outer_ct, x.shape) + 1.

    hlo = tf.function(fun_tf, jit_compile=True).experimental_get_compiler_ir(x)()
    self.assertIn("(arg0.1: f32[3]) -> f32[3]", hlo)

    # Call get_compiler_ir in a function context
    x = np.array([2., 3., 4.], dtype=np.float32)

    # TODO(b/193754660)
    # def fun_tf_outer(x):
    #   x_const = tf.constant(0, shape=x.shape, dtype=x.dtype)
    #   _ = tf.function(tf.math.sin, jit_compile=True).experimental_get_compiler_ir(x_const)()

    # with self.assertRaisesRegex(
    #     TypeError, "An op outside of the function building code is being passed"):
    #   tf.function(fun_tf_outer)(x)
    #
    # with self.assertRaisesRegex(
    #     TypeError, "An op outside of the function building code is being passed"):
    #   tf.function(fun_tf_outer, jit_compile=True)(x)

    # Call get_concrete_function in a graph context
    def fun_tf_outer(x):
      _ = tf.function(tf.math.sin, jit_compile=True).get_concrete_function(tf.TensorSpec(x.shape, x.dtype))
      return x

    # Outside of a function context, this works.
    _ = tf.function(fun_tf_outer)(x)
    _ = tf.function(fun_tf_outer, jit_compile=True)(x)

  def test_module_documentation(self):
    def cos_tf(x):
@@ -630,12 +639,289 @@ class CallTfTest(jtu.JaxTestCase):
    print(jax.make_jaxpr(cos_tf_sin_jax)(x))
    print(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())


class RoundTripToJaxTest(jtu.JaxTestCase):
  "Reloading output of jax2tf into JAX with call_tf."

  def setUp(self):
    if tf is None:
      raise unittest.SkipTest("Test requires tensorflow")
    # TODO(b/171320191): this line works around a missing context initialization
    # bug in TensorFlow.
    _ = tf.add(1, 1)
    super().setUp()

  def test_simple(self):
    f_jax = jnp.sin
    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
    x = np.float32(0.7)
    self.assertAllClose(f_jax(x), f_jax_rt(x))

  def test_pytree(self):
    def f_jax(x):  # x: dict(a=f32, b=f32)
      return dict(a=x["a"] + 1., b=x)
    x = dict(a=0.7, b=0.8)
    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
    self.assertAllClose(f_jax(x), f_jax_rt(x))

  def test_custom_grad(self):
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), np.float32(3.) * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)

    f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
    x = np.float32(0.7)
    self.assertAllClose(f(x), f_rt(x))
    self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))

  def test_shape_poly(self):
    f_jax = jnp.sin
    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax,
                                             polymorphic_shapes=["(b, ...)"]))
    x = np.array([0.7, 0.8], dtype=np.float32)
    self.assertAllClose(f_jax(x), f_jax_rt(x))

  def test_saved_model_simple(self):
    def f_jax(x):
      return jnp.sin(x)
    x = np.array([0.7, 0.8], dtype=np.float32)
    f_tf = jax2tf.convert(f_jax)
    restored_tf, _ = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec(x.shape, x.dtype)])
    restored_jax = jax2tf.call_tf(restored_tf)
    self.assertAllClose(f_jax(x), restored_jax(x))

  def test_saved_model_variables(self):
    x = np.array([0.7, 0.8], dtype=np.float32)
    param = np.array([1., 2.], dtype=np.float32)
    def f_jax(param, x):
      return jnp.sin(x) + jnp.cos(param)

    param_v = tf.Variable(param)
    f_tf = jax2tf.convert(f_jax)
    _, restored_model = tf_test_util.SaveAndLoadFunction(
        lambda x: f_tf(param_v, x),
        [tf.TensorSpec(x.shape, x.dtype)],
        variables=(param_v,))
    restored_jax = jax2tf.call_tf(restored_model.f)
    self.assertAllClose(f_jax(param, x), restored_jax(x))
    self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x))

  def test_saved_model_shape_poly(self):
    tracing_count = 0
    def f_jax(x):
      nonlocal tracing_count
      tracing_count += 1
      return jnp.sin(x)

    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
    x = np.array([0.7, 0.8], dtype=np.float32)
    res_jax = f_jax(x)
    self.assertEqual(1, tracing_count)
    # Will trace twice, it seems. Once to get the result signature, and once again
    # for the actual saving.
    restored_f, _ = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
    self.assertGreaterEqual(tracing_count, 2)
    tracing_count = 0
    f_jax_rt = jax2tf.call_tf(restored_f)
    self.assertAllClose(res_jax, f_jax_rt(x))
    # Ensure that restored_f works at other batch size as well
    y = np.concatenate([x, x])
    self.assertEqual(0, tracing_count)
    res_jax_y = f_jax(y)
    self.assertEqual(1, tracing_count)
    # No more tracing for f_jax_rt
    self.assertAllClose(res_jax_y, f_jax_rt(y))
    self.assertEqual(1, tracing_count)

  def test_custom_grad_saved_model(self):
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), np.float32(3.) * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)
    def g(x):
      return jnp.sum(f(x))

    g_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(g, with_gradient=True, polymorphic_shapes=["b, ..."]),
        [tf.TensorSpec([None], dtype=tf.float32)])
    g_rt = jax2tf.call_tf(g_tf)
    x = np.array([0.7], dtype=np.float32)
    self.assertAllClose(g(x), g_rt(x))
    self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))

  def test_without_gradient_saved_model(self):
    # Explicitly with_gradient=False
    f_jax = jnp.sum

    x = np.array([0.7, 0.8], dtype=np.float32)
    f_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax, with_gradient=False),
        [tf.TensorSpec(x.shape, dtype=x.dtype)])
    f_rt = jax2tf.call_tf(f_tf)

    self.assertAllClose(f_jax(x), f_rt(x))
    with self.assertRaisesRegex(Exception,
                                "Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"):
      jax.grad(f_rt)(x)

  def test_saved_model_no_gradients(self):
    # Save without gradients
    f_jax = jnp.sum

    x = np.array([0.7, 0.8], dtype=np.float32)
    f_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax, with_gradient=True),
        [tf.TensorSpec(x.shape, dtype=x.dtype)],
        save_gradients=False)
    f_rt = jax2tf.call_tf(f_tf)

    self.assertAllClose(f_jax(x), f_rt(x))
    # TODO: clean this up b/191117111: it should fail with a clear error
    # The following results in a confusing error:
    # TypeError: An op outside of the function building code is being passed
    # a "Graph" tensor. It is possible to have Graph tensors
    # leak out of the function building context by including a
    # tf.init_scope in your function building code.
    # For example, the following function will fail:
    # @tf.function
    # def has_init_scope():
    #   my_constant = tf.constant(1.)
    #   with tf.init_scope():
    #     added = my_constant * 2
    # The graph tensor has name: args_0:0
    # g = jax.grad(f_rt)(x)


class RoundTripToTfTest(jtu.JaxTestCase):
  "Reloading output of call_tf into TF with jax2tf."

  def setUp(self):
    if tf is None:
      raise unittest.SkipTest("Test requires tensorflow")
    # TODO(b/171320191): this line works around a missing context initialization
    # bug in TensorFlow.
    _ = tf.add(1, 1)
    super().setUp()

  def test_alternate(self):
    # Alternate sin/cos with sin in TF and cos in JAX
    f_tf_inner = tf.math.sin
    def f_jax(x_jax):
      y_jax = jnp.cos(x_jax)
      z_jax = jax2tf.call_tf(f_tf_inner)(y_jax)
      return jnp.cos(z_jax)
    def f_tf_outer(x_tf):
      y_tf = tf.math.sin(x_tf)
      z_tf = jax2tf.convert(f_jax)(y_tf)
      return tf.math.sin(z_tf)

    x = np.float32(0.7)

    self.assertAllClose(np.sin(np.cos(np.sin(np.cos(np.sin(x))))),
                        f_tf_outer(x).numpy())
    xv = tf.Variable(x)
    with tf.GradientTape() as tape:
      res = f_tf_outer(xv)
    g_tf = tape.gradient(res, xv)

    # Eager
    expected_res = np.sin(np.cos(np.sin(np.cos(np.sin(x)))))
    self.assertAllClose(expected_res, f_tf_outer(x).numpy())

    # Gradient
    expected_grad = (np.cos(np.cos(np.sin(np.cos(np.sin(x))))) *
                     np.sin(np.sin(np.cos(np.sin(x)))) *
                     np.cos(np.cos(np.sin(x))) *
                     np.sin(np.sin(x)) *
                     np.cos(x))
    self.assertAllClose(expected_grad, g_tf.numpy())

    # Graph
    self.assertAllClose(expected_res,
                        tf.function(f_tf_outer, autograph=False)(x).numpy())

    # Compiled
    self.assertAllClose(expected_res,
                        tf.function(f_tf_outer, autograph=False,
                                    jit_compile=True)(x).numpy())

  def test_saved_model(self):
    x = np.array([.7, .8], dtype=np.float32)
    def fun_tf(x):
      return tf.math.sin(x)
    def fun_jax(x):
      return jax2tf.call_tf(fun_tf)(x)

    # Now convert and save to SavedModel
    fun_tf_rt = jax2tf.convert(fun_jax)
    res = fun_tf_rt(x)
    self.assertAllClose(np.sin(x), res.numpy())

    # TODO(b/193754660)
    res = tf.function(fun_tf_rt)(x)
    self.assertAllClose(np.sin(x), res.numpy())

    res = tf.function(fun_tf_rt, jit_compile=True)(x)
    self.assertAllClose(np.sin(x), res.numpy())

    reloaded_f, _ = tf_test_util.SaveAndLoadFunction(
        fun_tf_rt, input_signature=[tf.TensorSpec(x.shape, x.dtype)])
    res = reloaded_f(x)
    self.assertAllClose(np.sin(x), res.numpy())

  def test_function_dynamic_shape(self):
    # Call a function for which shape inference does not give an output
    # shape.
    x = np.array([-1, 0, 1], dtype=np.int32)
    def fun_tf(x):  # x: i32[3]
      # The shape depends on the value of x
      res = tf.where(x >= 0)
      return res

    # Call in eager mode. Should work!
    res1 = jax2tf.call_tf(fun_tf)(x)
    expected = np.array([[1], [2]])
    self.assertAllClose(expected, res1, check_dtypes=False)

    # Now under jit, should fail because the function is not compileable
    with self.assertRaisesRegex(ValueError,
                                "Error compiling TensorFlow function. call_tf can be used in a staged context"):
      fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
      fun_jax(x)

    # TODO(necula): this should work in op-by-op mode, but it fails because
    # jax2tf.convert does abstract evaluation.
    with self.assertRaisesRegex(ValueError,
                                "Error compiling TensorFlow function. call_tf can be used in a staged context"):
      fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
      fun_tf_rt(x)

  def test_shape_polymorphism_error(self):
    x = np.array([.7, .8], dtype=np.float32)
    def fun_tf(x):
      return tf.math.sin(x)

    fun_jax = jax2tf.call_tf(fun_tf)

    fun_tf_rt = jax2tf.convert(fun_jax,
                               polymorphic_shapes=["b, ..."])
    with self.assertRaisesRegex(
        ValueError,
        "call_tf cannot be applied to shape-polymorphic arguments"):
      fun_tf_rt(x)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
@@ -161,7 +161,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
    f_tf = jax2tf.convert(f_jax)
    res = f_tf(*args)
    input_signature = list(tf.TensorSpec(a.shape, a.dtype) for a in args)
    restored_f, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_signature)
    res_restored = restored_f(*args)
    self.assertAllClose(res, res_restored)
@@ -208,7 +208,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
                        jnp.sin(np.array([3.14, 2.78], dtype=np.float16)))

    # Save and restore SavedModel
    restored_f, _ = tf_test_util.SaveAndLoadFunction(composed_fn,
                                                     [tf.TensorSpec((2,), dtype=tf.string)])
    res_tf_restored = restored_f(x_str)
    self.assertAllClose(res_tf_restored.numpy(), res_tf.numpy())
@@ -721,7 +721,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
    f_jax = jnp.sin
    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
    x = np.array([0.7, 0.8], dtype=np.float32)
    restored_f, _ = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
    self.assertAllClose(f_jax(x), restored_f(x))
    # Ensure that restored_f works at other batch size as well
    y = np.concatenate([x, x])
@@ -1369,7 +1369,8 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
  # test will be called "test_prim_xxx_...".
  # If you want to run this test for only one harness that includes "foo"
  # in the name, add parameter `one_containing="foo"` to parameterized below.
  @primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES,
                                   one_containing="dynamic_slice_enable_xla=True_poly_axes=[0]")
  def test_prim(self, harness: Harness):
    args = harness.dyn_args_maker(self.rng())
    poly_axes = harness.params["poly_axes"]  # type: Sequence[Sequence[int]]
@@ -17,7 +17,7 @@ import dataclasses
import logging
import os

from typing import Any, Callable, List, Optional, Sequence, Tuple

from absl.testing import absltest
import jax
@@ -87,14 +87,17 @@ def SaveAndLoadModel(model: tf.Module,

def SaveAndLoadFunction(f_tf: Callable,
                        input_signature: Sequence[tf.TensorSpec],
                        variables: Sequence[tf.Variable] = (),
                        save_gradients=True) -> Tuple[Callable, tf.train.Checkpoint]:
  # Roundtrip through saved model on disk. Return the Checkpoint also
  # for the cases when there are variables.
  model = tf.train.Checkpoint()
  model.f = tf.function(f_tf,
                        autograph=False,
                        input_signature=input_signature)
  model.variables = variables
  restored = SaveAndLoadModel(model, save_gradients=save_gradients)
  return restored.f, restored


class JaxToTfTestCase(jtu.JaxTestCase):
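For reference, a short sketch of how callers consume the new return value (assuming the test-utility import path used by these tests; the converted function here is illustrative):

```python
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util

x = np.array([0.7, 0.8], dtype=np.float32)
f_tf = jax2tf.convert(lambda a: a * 2.)

# SaveAndLoadFunction now returns (restored_function, checkpoint); callers
# that do not need the saved variables simply discard the second element.
restored_f, _ = tf_test_util.SaveAndLoadFunction(
    f_tf, [tf.TensorSpec(x.shape, x.dtype)])
print(restored_f(x))
```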