[jax2tf] Fix the round-trip call_tf(convert)

Also cleaned up the handling of global state in jax2tf by moving it into a single thread-local state object.
George Necula 2021-06-10 17:01:22 +02:00
parent 3d1a6a308e
commit 1994f6df4a
9 changed files with 247 additions and 108 deletions


@ -8,9 +8,15 @@ Remember to align the itemized text with the first line of an item within a list
PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
-->
## jax 0.2.14 (unreleased)
## jax 0.2.15 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...master).
* New features:
* Breaking changes:
* Bug fixes:
* Fixed bug that prevented round-tripping from JAX to TF and back:
`jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`#6947`).
## jax 0.2.14 (June 10 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
@ -23,7 +29,6 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* The {func}`jax2tf.convert` supports shape polymorphism even when the
unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
({jax-issue}`#6827`).
* The {func}`jax2tf.convert` generates custom attributes with location information
in TF ops. The code that XLA generates after jax2tf
has the same location information as JAX/XLA.
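
An illustrative sketch of the round trip that this fix enables, mirroring the new `test_round_trip` added further below (module paths and values taken from the tests; the assertion tolerance is an assumption):

import numpy as np
from jax import numpy as jnp
from jax.experimental import jax2tf

# Convert a JAX function to TF, then call the resulting TF function back
# from JAX: jax2tf.call_tf(jax2tf.convert(f)).
f_jax = jnp.sin
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))

x = np.float32(0.7)
np.testing.assert_allclose(f_jax(x), f_jax_rt(x), rtol=1e-6)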


@ -772,7 +772,7 @@ As a trivial example, consider computing ``sin(cos(1.))`` with ``sin`` done in J
# It should return a similar result. This function will be called using
# TensorFlow eager mode if called from outside JAX staged contexts (`jit`,
# `pmap`, or control-flow primitives), and will be called using TensorFlow
# graph mode otherwise. In the latter case, the function must be compileable
# compiled mode otherwise. In the latter case, the function must be compileable
# with XLA (`tf.function(func, jit_compile=True)`)
def cos_tf(x):
return tf.math.cos(x)
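
An illustrative sketch of how `cos_tf` is used via `call_tf`, following the README example of computing sin(cos(1.)) with sin done in JAX and cos delegated to TF (the composed function name matches the one used in `test_module_documentation` below):

import jax
import numpy as np
import tensorflow as tf
from jax import numpy as jnp
from jax.experimental import jax2tf

def cos_tf(x):
  return tf.math.cos(x)

# sin computed in JAX, cos delegated to TF via call_tf.
def cos_tf_sin_jax(x):
  return jnp.sin(jax2tf.call_tf(cos_tf)(x))

x = np.float32(1.)
print(cos_tf_sin_jax(x))
print(jax.make_jaxpr(cos_tf_sin_jax)(x))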


@ -23,18 +23,18 @@ https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#call
"""
import logging
from typing import Callable
from typing import Callable, Sequence
import jax
from jax import core
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import util
from jax.interpreters import xla
from jax.lib import xla_bridge
from jax.lib import xla_client
from . import jax2tf as jax2tf_internal
import numpy as np
import tensorflow as tf # type: ignore[import]
@ -86,19 +86,25 @@ def call_tf(func_tf: Callable) -> Callable:
args_jax_flat, args_jax_treedef = tree_util.tree_flatten(args_jax)
args_tf_sig_flat = [
tf.TensorSpec(np.shape(a_jax), _to_tf_dtype(_dtype(a_jax)))
tf.TensorSpec(np.shape(a_jax), jax2tf_internal._to_tf_dtype(_dtype(a_jax)))
for a_jax in args_jax_flat
]
args_tf_sig = tf.nest.map_structure(
lambda a_jax: tf.TensorSpec(
np.shape(a_jax), _to_tf_dtype(_dtype(a_jax))), args_jax)
np.shape(a_jax), jax2tf_internal._to_tf_dtype(_dtype(a_jax))), args_jax)
# Trace once through the function to get the result shape
with jax2tf_internal.inside_call_tf():
func_tf_concrete = tf.function(func_tf).get_concrete_function(*args_tf_sig)
res_tf_sig_flat, res_treedef = tree_util.tree_flatten(
func_tf_concrete.structured_outputs)
res_jax_flat = call_tf_p.bind(
*args_jax_flat,
# Carry the actual function such that op-by-op call can call in TF eager mode.
func_tf=func_tf,
func_tf_concrete=func_tf_concrete,
args_treedef=args_jax_treedef,
args_tf_sig_flat=args_tf_sig_flat,
res_treedef=res_treedef,
@ -167,6 +173,8 @@ def _call_tf_impl(*args_jax_flat, args_treedef, func_tf, **_):
return tf.constant(np.asarray(arg_jax))
args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
with jax2tf_internal.inside_call_tf():
# Call in TF eager mode
res_tf = func_tf(*args_treedef.unflatten(args_tf_flat))
res_tf_flat, _ = tree_util.tree_flatten(res_tf)
# TODO(necula): check the result for tree and aval
@ -190,7 +198,7 @@ call_tf_p.def_impl(_call_tf_impl)
def _call_tf_abstract_eval(*_, res_tf_sig_flat, **__):
return tuple([
core.ShapedArray(np.shape(r), _to_jax_dtype(r.dtype))
core.ShapedArray(np.shape(r), jax2tf_internal._to_jax_dtype(r.dtype))
for r in res_tf_sig_flat
])
@ -198,7 +206,7 @@ def _call_tf_abstract_eval(*_, res_tf_sig_flat, **__):
call_tf_p.def_abstract_eval(_call_tf_abstract_eval)
def _call_tf_translation_rule(builder, *args_op, func_tf,
def _call_tf_translation_rule(builder, *args_op, func_tf, func_tf_concrete,
args_treedef, args_tf_sig_flat, res_tf_sig_flat,
**_):
# TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
@ -209,7 +217,7 @@ def _call_tf_translation_rule(builder, *args_op, func_tf,
]
args_tf = args_treedef.unflatten(args_tf_flat)
func_tf = tf.function(func_tf, jit_compile=True)
func_tf_concrete = func_tf.get_concrete_function(*args_tf)
#func_tf_concrete = func_tf.get_concrete_function(*args_tf)
captured_ops = [] # Same order as captured_inputs
if func_tf_concrete.captured_inputs:
# The function uses either captured variables or tensors.
@ -248,13 +256,15 @@ def _call_tf_translation_rule(builder, *args_op, func_tf,
xla.translations[call_tf_p] = _call_tf_translation_rule
TfVal = jax2tf_internal.TfVal
def _jax2tf_call_tf(*args: TfVal,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray,
func_tf: Callable,
**kwargs) -> TfVal:
res_tf = func_tf(*args)
res_tf_flat = tf.nest.flatten(res_tf)
# TODO: check that the return values have the right signature
return res_tf_flat
def _to_tf_dtype(jax_dtype):
if jax_dtype == dtypes.float0:
return tf.float32
else:
return tf.dtypes.as_dtype(jax_dtype)
def _to_jax_dtype(tf_dtype):
return tf_dtype.as_numpy_dtype
jax2tf_internal.tf_impl_with_avals[call_tf_p] = _jax2tf_call_tf
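
An illustrative sketch of the two call paths handled above, using only the public `jax2tf.call_tf` API: outside JAX staging contexts the wrapped TF function runs in TF eager mode (the `_call_tf_impl` path), while under `jit` it must be compilable with XLA (the translation-rule path):

import jax
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def f_tf(x):
  return tf.math.sin(x)

f_jax = jax2tf.call_tf(f_tf)
x = np.float32(0.7)

# Op-by-op: called outside jit/pmap, f_tf is invoked in TF eager mode.
print(f_jax(x))

# Staged: under jit, f_tf is compiled, roughly as
# tf.function(f_tf, jit_compile=True).
print(jax.jit(f_jax)(x))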


@ -81,6 +81,10 @@ TfVal = Any
DType = Any
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
# A dimension environment maps dimension variables to TF expressions that
# compute the value of the dimension. These expressions refer to the TF
# function arguments.
_ShapeEnv = Dict[str, TfVal]
def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
@ -111,12 +115,6 @@ tf_impl: Dict[core.Primitive, Callable[..., Any]] = {}
# core.AbstractValue, or a tuple thereof when primitive.multiple_results).
tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {}
# XLA is not linked in all environments; when converting a primitive, if this
# variable is disabled, we try harder to use only standard TF ops if they are
# applicable to the concrete use case; if the resulting conversion path ends up
# requiring a TFXLA operation, an exception is thrown instead.
_enable_xla = True
# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
@ -132,17 +130,48 @@ _enable_xla = True
# also.
# We register the TensorFlow source path lazily
_has_registered_tf_source_path = False
# Whether to actually include XLA op metadata
_include_xla_op_metadata = True
class _ThreadLocalState(threading.local):
def __init__(self):
self.name_stack = ""
# XLA is not linked in all environments; when converting a primitive, if this
# variable is disabled, we try harder to use only standard TF ops if they are
# applicable to the concrete use case; if the resulting conversion path ends up
# requiring a TFXLA operation, an exception is thrown instead.
self.enable_xla = True
# Keep track if we are inside a call_tf. In that context we disable the
# safety check that we are not inside JAX transformations.
self.inside_call_tf = False
# Maps dimension variables to TF expressions
self.shape_env: _ShapeEnv = {}
# Whether to actually include XLA op metadata in the generated TF ops
self.include_xla_op_metadata = True
_thread_local_state = _ThreadLocalState()
def _get_current_name_stack():
return _thread_local_state.name_stack
def _xla_disabled_error(primitive_name: str,
extra_msg: Optional[str] = None) -> Exception:
assert not _enable_xla
assert not _thread_local_state.enable_xla
msg = f"Call to {primitive_name} cannot be converted with enable_xla=False."
if extra_msg:
msg += f" {extra_msg}"
return NotImplementedError(msg)
@contextlib.contextmanager
def inside_call_tf():
# Set the inside_call_tf flag for a context.
prev = _thread_local_state.inside_call_tf
_thread_local_state.inside_call_tf = True
try:
yield
finally:
_thread_local_state.inside_call_tf = prev
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun: Callable,
*,
@ -214,9 +243,10 @@ def convert(fun: Callable,
name_stack = util.extend_name_stack(util.wrap_name(fun_name, "jax2tf"))
def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal:
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean():
if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
# It is Ok to nest convert when we are inside a call_tf
raise ValueError("convert must be used outside all JAX transformations." +
f"Trace state: {core.thread_local_state.trace_state}")
f"Trace state: {core.thread_local_state.trace_state.trace_stack}")
def check_arg(a):
if not _is_tfval(a):
@ -308,16 +338,15 @@ def convert(fun: Callable,
return in_cts
try:
global _shape_env
assert not _shape_env, f"Unexpected shape environment {_shape_env}"
global _enable_xla
prev_enable_xla = _enable_xla
_enable_xla = enable_xla
global _include_xla_op_metadata
prev_include_xla_op_metadata = _include_xla_op_metadata
_include_xla_op_metadata = False
assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"
_shape_env = shapeenv
prev_enable_xla = _thread_local_state.enable_xla
_thread_local_state.enable_xla = enable_xla
prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
_thread_local_state.include_xla_op_metadata = False
_thread_local_state.shape_env = shapeenv
global _has_registered_tf_source_path
if not _has_registered_tf_source_path:
source_info_util.register_exclusion(os.path.dirname(tf.__file__))
@ -345,9 +374,9 @@ def convert(fun: Callable,
for o, _ in out_flat_raw
]
finally:
_shape_env = {}
_enable_xla = prev_enable_xla
_include_xla_op_metadata = prev_include_xla_op_metadata
_thread_local_state.shape_env = {}
_thread_local_state.enable_xla = prev_enable_xla
_thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata
out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
@ -371,15 +400,6 @@ def dtype_of_val(val: TfVal) -> DType:
# Internals
# TODO: add all globals here
class _ThreadLocalState(threading.local):
def __init__(self):
self.name_stack = ""
_thread_local_state = _ThreadLocalState()
def _get_current_name_stack():
return _thread_local_state.name_stack
@contextlib.contextmanager
def _extended_name_stack(extra_name_stack: Optional[str]):
prev_name_stack = _thread_local_state.name_stack
@ -526,10 +546,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
return tf.convert_to_tensor(val, dtype=conversion_dtype), jax_dtype
# A dimension environment maps dimension variables to TF expressions that
# compute the value of the dimension. These expressions refer to the TF
# function arguments.
_ShapeEnv = Dict[str, TfVal]
def _args_to_avals_and_env(
args: Sequence[TfVal],
arg_jax_dtypes: Sequence[DType],
@ -573,14 +589,10 @@ def _args_to_avals_and_env(
return avals, shapeenv
# A shape environment maps shape variables to TfVal.
_shape_env = {} # type: _ShapeEnv
def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
assert all(map(lambda x: x is not None, shape)), (
f"Argument shape should be a valid JAX shape but got {shape}")
return shape_poly.eval_shape(shape, _shape_env)
return shape_poly.eval_shape(shape, _thread_local_state.shape_env)
def shape_as_value(x):
@ -800,7 +812,7 @@ class TensorFlowTrace(core.Trace):
else:
return impl(*args_tf, **params)
if _include_xla_op_metadata:
if _thread_local_state.include_xla_op_metadata:
op_metadata = xla.make_op_metadata(primitive, params,
name_stack=_get_current_name_stack(),
source_info=source_info_util.current())
@ -953,7 +965,6 @@ tf_not_yet_impl = [
"lu_pivots_to_permutation",
"rng_bit_generator",
"xla_pmap",
"call_tf",
]
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
@ -1521,7 +1532,7 @@ def _conv_general_dilated(lhs, rhs, *,
_out_aval: core.AbstractValue):
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
out_tf_shape = _aval_to_tf_shape(_out_aval)
if not _enable_xla:
if not _thread_local_state.enable_xla:
return _try_tf_conv(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
@ -1580,7 +1591,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
if _enable_xla:
if _thread_local_state.enable_xla:
dnums_proto = xla_data_pb2.DotDimensionNumbers()
dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
@ -1723,7 +1734,7 @@ def _pad(operand, padding_value, *, padding_config,
_out_aval: core.AbstractValue):
del _in_avals
low, high, interior = util.unzip3(padding_config)
if _enable_xla:
if _thread_local_state.enable_xla:
out = tfxla.pad(operand, padding_value, low, high, interior)
return out
@ -1911,7 +1922,7 @@ def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
"""
assert len(consts) == 0, "Reduction computation cannot have constants"
if not _enable_xla:
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("reduce_window")
def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
@ -2029,7 +2040,7 @@ def _specialized_reduce_window(reducer,
Returns:
The reduced operand.
"""
if not _enable_xla and name in ["reduce_window_max", "reduce_window_sum"]:
if not _thread_local_state.enable_xla and name in ["reduce_window_max", "reduce_window_sum"]:
return _try_tf_pool(name, operand, window_dimensions, window_strides,
padding, base_dilation, window_dilation)
@ -2123,7 +2134,7 @@ tf_impl[lax.select_and_scatter_p] = _select_and_scatter
@partial(bool_to_int8, argnums=(0, 1))
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
window_strides, padding, _in_avals, _out_aval):
if not _enable_xla:
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("select_and_scatter_add")
init_value = tf.zeros((), operand.dtype)
select_fn = (
@ -2173,7 +2184,7 @@ def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
_in_avals, _out_aval):
"""Tensorflow implementation of gather."""
del _in_avals, unique_indices
if not _enable_xla:
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("gather")
proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
slice_sizes_tf = _eval_shape(slice_sizes)
@ -2209,7 +2220,7 @@ def _dynamic_slice(operand, *start_indices, slice_sizes,
start_indices = tf.stack(start_indices)
slice_sizes = _eval_shape(slice_sizes)
if _enable_xla:
if _thread_local_state.enable_xla:
res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes)
# TODO: implement shape inference for XlaDynamicSlice
res.set_shape(_aval_to_tf_shape(_out_aval))
@ -2259,7 +2270,7 @@ def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts,
del unique_indices, _in_avals
assert len(update_consts) == 0, "Update computation cannot have constants"
if not _enable_xla:
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("scatter")
proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers)
@ -2293,7 +2304,7 @@ tf_impl_with_avals[lax.scatter_add_p] = _scatter
def _dynamic_update_slice(operand, update, *start_indices):
if not _enable_xla:
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("dynamic_update_slice")
return tfxla.dynamic_update_slice(operand, update, tf.stack(start_indices))
@ -2428,7 +2439,7 @@ tf_impl[lax.top_k_p] = _top_k
def _sort(*operands: TfVal, dimension: int, is_stable: bool,
num_keys: int) -> Tuple[TfVal, ...]:
if not _enable_xla:
if not _thread_local_state.enable_xla:
raise _xla_disabled_error("sort")
assert 1 <= num_keys <= len(operands)
assert 0 <= dimension < len(


@ -462,7 +462,7 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
spec_ = spec.replace(" ", "")
if spec_[0] == "(":
if spec_[-1] != ")":
raise ValueError(spec)
raise ValueError(f"PolyShape '{spec}' has invalid syntax")
spec_ = spec_[1:-1]
spec_ = spec_.rstrip(",")
if not spec_:
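
An illustrative sketch of the improved error message, assuming the spec is parsed when the converted function is applied to arguments (the valid specs are taken from the tests; the unbalanced spec below is hypothetical):

import numpy as np
from jax import numpy as jnp
from jax.experimental import jax2tf

x = np.array([0.7, 0.8], dtype=np.float32)

# Valid polymorphic shape specs, with or without surrounding parentheses.
jax2tf.convert(jnp.sin, polymorphic_shapes=["(b, ...)"])(x)
jax2tf.convert(jnp.sin, polymorphic_shapes=["b, ..."])(x)

# An unbalanced spec now reports the offending string, roughly:
#   ValueError: PolyShape '(b, ...' has invalid syntax
try:
  jax2tf.convert(jnp.sin, polymorphic_shapes=["(b, ..."])(x)
except ValueError as e:
  print(e)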


@ -25,6 +25,7 @@ from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
import numpy as np
@ -58,10 +59,12 @@ class CallTfTest(jtu.JaxTestCase):
_ = tf.add(1, 1)
super().setUp()
@parameterized_jit
def test_eval_scalar_arg(self, with_jit=False):
#@parameterized_jit
def test_eval_scalar_arg(self, with_jit=True):
def f_tf(x):
return tf.math.sin(x)
x = 3.
res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
res = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
self.assertAllClose(jnp.sin(x), res, check_dtypes=False)
@parameterized_jit
@ -119,6 +122,16 @@ class CallTfTest(jtu.JaxTestCase):
res = fun_jax(x, y)
self.assertAllClose((np.float32(12.), np.float64(11.)), res)
def test_eval_non_compileable(self):
# Check that in op-by-op we call a function in eager mode.
def f_tf_non_compileable(x):
return tf.strings.length(tf.strings.format("Hello {}!", [x]))
f_jax = jax2tf.call_tf(f_tf_non_compileable)
x = np.float32(0.7)
self.assertAllClose(f_tf_non_compileable(x).numpy(), f_jax(x))
@parameterized_jit
def test_control_flow(self, with_jit=True):
@ -319,6 +332,89 @@ class CallTfTest(jtu.JaxTestCase):
res = jax.pmap(fun_jax)(x)
self.assertAllClose(np.float32(3. * (x + 2)), res)
def test_round_trip(self):
f_jax = jnp.sin
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
x = np.float32(0.7)
self.assertAllClose(f_jax(x), f_jax_rt(x))
def test_round_trip_custom_grad(self):
@jax.custom_vjp
def f(x):
return x * x
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), np.float32(3.) * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * ct_b,
f.defvjp(f_fwd, f_bwd)
f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
x = np.float32(0.7)
self.assertAllClose(f(x), f_rt(x))
self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))
def test_round_trip_shape_poly(self):
f_jax = jnp.sin
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax,
polymorphic_shapes=["(b, ...)"]))
x = np.array([0.7, 0.8], dtype=np.float32)
self.assertAllClose(f_jax(x), f_jax_rt(x))
def test_round_trip_saved_model_shape_poly(self):
tracing_count = 0
def f_jax(x):
nonlocal tracing_count
tracing_count += 1
return jnp.sin(x)
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
x = np.array([0.7, 0.8], dtype=np.float32)
res_jax = f_jax(x)
self.assertEqual(1, tracing_count)
# Will trace twice, it seems. Once to get the result signature, and once again
# for the actual saving.
restored_f = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
self.assertGreaterEqual(tracing_count, 2)
tracing_count = 0
f_jax_rt = jax2tf.call_tf(restored_f)
self.assertAllClose(res_jax, f_jax_rt(x))
# Ensure that restored_f works at other batch size as well
y = np.concatenate([x, x])
self.assertEqual(0, tracing_count)
res_jax_y = f_jax(y)
self.assertEqual(1, tracing_count)
# No more tracing for f_jax_rt
self.assertAllClose(res_jax_y, f_jax_rt(y))
self.assertEqual(1, tracing_count)
def test_round_trip_custom_grad_saved_model(self):
@jax.custom_vjp
def f(x):
return x * x
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), np.float32(3.) * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * ct_b,
f.defvjp(f_fwd, f_bwd)
def g(x):
return jnp.sum(f(x))
g_tf = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(g, with_gradient=True, polymorphic_shapes=["b, ..."]),
[tf.TensorSpec([None], dtype=tf.float32)])
g_rt = jax2tf.call_tf(g_tf)
x = np.array([0.7], dtype=np.float32)
self.assertAllClose(g(x), g_rt(x))
self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))
def test_module_documentation(self):
def cos_tf(x):
return tf.math.cos(x)
@ -342,6 +438,12 @@ class CallTfTest(jtu.JaxTestCase):
print(jax.make_jaxpr(cos_tf_sin_jax)(x))
print(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())
def test_round_trip_reverse(self):
f_tf = tf.math.sin
f_tf_rt = jax2tf.convert(jax2tf.call_tf(f_tf))
x = np.float32(0.7)
self.assertAllClose(f_tf(x).numpy(), f_tf_rt(x).numpy())
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())


@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl.testing import absltest
import jax
@ -32,15 +30,6 @@ config.parse_flags_with_absl()
class SavedModelTest(tf_test_util.JaxToTfTestCase):
def save_and_load_model(self, model: tf.Module) -> tf.Module:
# Roundtrip through saved model on disk.
model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
tf.saved_model.save(
model, model_dir,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
restored_model = tf.saved_model.load(model_dir)
return restored_model
def test_eval(self):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
model = tf.Module()
@ -50,7 +39,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
)
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
restored_model = self.save_and_load_model(model)
restored_model = tf_test_util.SaveAndLoadModel(model)
self.assertAllClose(restored_model.f(x), f_jax(x))
def test_gradient_disabled(self):
@ -62,7 +51,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
input_signature=[tf.TensorSpec([], tf.float32)])
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
restored_model = self.save_and_load_model(model)
restored_model = tf_test_util.SaveAndLoadModel(model)
xv = tf.Variable(0.7, dtype=jnp.float32)
self.assertAllClose(restored_model.f(x), f_jax(x))
@ -92,7 +81,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
input_signature=[tf.TensorSpec([], tf.float32)])
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
restored_model = self.save_and_load_model(model)
restored_model = tf_test_util.SaveAndLoadModel(model)
xv = tf.Variable(0.7, dtype=jnp.float32)
self.assertAllClose(restored_model.f(x), f_jax(x))
with tf.GradientTape() as tape:
@ -106,14 +95,9 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
# JAX. We check that this information is preserved through a savedmodel
f_tf = jax2tf.convert(f_jax)
res = f_tf(*args)
model = tf.Module()
input_signature = list(tf.TensorSpec(a.shape, a.dtype) for a in args)
model.f = tf.function(f_tf,
autograph=False,
input_signature=input_signature)
restored_model = self.save_and_load_model(model)
res_restored = restored_model.f(*args)
restored_f = tf_test_util.SaveAndLoadFunction(f_tf, input_signature)
res_restored = restored_f(*args)
self.assertAllClose(res, res_restored)
def test_xla_context_preserved_slice(self):
@ -159,12 +143,9 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
jnp.sin(np.array([3.14, 2.78], dtype=np.float16)))
# Save and restore SavedModel
model = tf.Module()
model.f = tf.function(
composed_fn,
input_signature=[tf.TensorSpec((2,), dtype=tf.string)])
restored_model = self.save_and_load_model(model)
res_tf_restored = restored_model.f(x_str)
restored_f = tf_test_util.SaveAndLoadFunction(composed_fn,
[tf.TensorSpec((2,), dtype=tf.string)])
res_tf_restored = restored_f(x_str)
self.assertAllClose(res_tf_restored.numpy(), res_tf.numpy())


@ -609,6 +609,16 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
res_jax,
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
def test_saved_model_shape_poly(self):
f_jax = jnp.sin
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
x = np.array([0.7, 0.8], dtype=np.float32)
restored_f = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
self.assertAllClose(f_jax(x), restored_f(x))
# Ensure that restored_f works at other batch size as well
y = np.concatenate([x, x])
self.assertAllClose(f_jax(y), restored_f(y))
def test_readme_example(self):
"""Some of the examples from the README."""
def image_mask_jax(images, mask):


@ -15,10 +15,11 @@
import contextlib
import dataclasses
import logging
import os
from typing import Any, Callable, List, Optional, Sequence
from absl.testing import absltest
import jax
from jax import dtypes
from jax import numpy as jnp
@ -74,6 +75,26 @@ class OpMetadataGraph:
source_line: str
def SaveAndLoadModel(model: tf.Module) -> tf.Module:
# Roundtrip through saved model on disk.
model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
tf.saved_model.save(
model, model_dir,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
restored_model = tf.saved_model.load(model_dir)
return restored_model
def SaveAndLoadFunction(f_tf: Callable,
input_signature: Sequence[tf.TensorSpec]) -> Callable:
# Roundtrip through saved model on disk
model = tf.Module()
model.f = tf.function(f_tf,
autograph=False,
input_signature=input_signature)
restored = SaveAndLoadModel(model)
return restored.f
class JaxToTfTestCase(jtu.JaxTestCase):
def setUp(self):
@ -344,7 +365,6 @@ class JaxToTfTestCase(jtu.JaxTestCase):
return tree_util.tree_multimap(polymorphic_shape_to_tensorspec, polymorphic_shapes)
def CheckOpMetadata(self, jax_fun, x,
expected: Sequence[OpMetadataGraph],
include_xla_op_metadata=True):
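
An illustrative usage sketch of the new `SaveAndLoadFunction` helper, mirroring `test_saved_model_shape_poly` above; it assumes the jax2tf test utilities are importable and that the SavedModel is written to the test temporary directory:

import numpy as np
import tensorflow as tf
from jax import numpy as jnp
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util

# Convert with a polymorphic batch dimension, round-trip through a SavedModel
# on disk, then check the restored function at a different batch size.
f_tf = jax2tf.convert(jnp.sin, polymorphic_shapes=["(b, ...)"])
restored_f = tf_test_util.SaveAndLoadFunction(
    f_tf, [tf.TensorSpec([None], tf.float32)])

x = np.array([0.7, 0.8], dtype=np.float32)
np.testing.assert_allclose(jnp.sin(x), restored_f(x).numpy(), rtol=1e-6)
y = np.concatenate([x, x])
np.testing.assert_allclose(jnp.sin(y), restored_f(y).numpy(), rtol=1e-6)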