[jax2tf] Fix the round-trip call_tf(convert)
Also cleaned up the handling of global state in jax2tf.
This commit is contained in:
parent 3d1a6a308e
commit 1994f6df4a
@@ -8,9 +8,15 @@ Remember to align the itemized text with the first line of an item within a list
PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
-->

## jax 0.2.14 (unreleased)
## jax 0.2.15 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...master).
* New features:

* Breaking changes:

* Bug fixes:
  * Fixed bug that prevented round-tripping from JAX to TF and back:
    `jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`#6947`).

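A minimal sketch of the round trip that this entry refers to, using only the public `jax2tf` API (it mirrors the `test_round_trip` test added further down in this commit):

```python
import numpy as np
from jax import numpy as jnp
from jax.experimental import jax2tf

f_jax = jnp.sin                      # original JAX function
f_tf = jax2tf.convert(f_jax)         # JAX -> TF
f_rt = jax2tf.call_tf(f_tf)          # TF -> JAX: the round trip fixed here

x = np.float32(0.7)
np.testing.assert_allclose(f_rt(x), f_jax(x))
```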
## jax 0.2.14 (June 10 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
@@ -23,7 +29,6 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
  * The {func}`jax2tf.convert` supports shape polymorphism even when the
    unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
    ({jax-issue}`#6827`).

  * The {func}`jax2tf.convert` generates custom attributes with location information
    in TF ops. The code that XLA generates after jax2tf
    has the same location information as JAX/XLA.

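For illustration, a hedged sketch of the shape-polymorphism feature mentioned in the 0.2.14 notes above; `polymorphic_shapes` is the public parameter of `jax2tf.convert`, the shapes and values are only examples:

```python
import numpy as np
import tensorflow as tf
from jax import numpy as jnp
from jax.experimental import jax2tf

# One TF trace serves every batch size "b", even though the body reshapes
# using the unknown dimension.
f_tf = tf.function(
    jax2tf.convert(lambda x: jnp.reshape(x, (-1,)), polymorphic_shapes=["(b, 4)"]),
    autograph=False,
    input_signature=[tf.TensorSpec([None, 4], tf.float32)])

print(f_tf(np.ones((3, 4), np.float32)).shape)  # (12,)
print(f_tf(np.ones((5, 4), np.float32)).shape)  # (20,)
```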
@@ -772,7 +772,7 @@ As a trivial example, consider computing ``sin(cos(1.))`` with ``sin`` done in J
# It should return a similar result. This function will be called using
# TensorFlow eager mode if called from outside JAX staged contexts (`jit`,
# `pmap`, or control-flow primitives), and will be called using TensorFlow
# graph mode otherwise. In the latter case, the function must be compileable
# compiled mode otherwise. In the latter case, the function must be compileable
# with XLA (`tf.function(func, jit_compile=True)`)
def cos_tf(x):
  return tf.math.cos(x)
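To make the README excerpt above concrete, a small hedged usage sketch; it mirrors the ``sin(cos(1.))`` example and the `test_module_documentation` test later in this commit:

```python
import jax
from jax import numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def cos_tf(x):
  return tf.math.cos(x)

# Compose the TF function with JAX code.
def cos_tf_sin_jax(x):
  return jnp.sin(jax2tf.call_tf(cos_tf)(x))

print(cos_tf_sin_jax(1.))           # op-by-op: cos_tf runs in TF eager mode
print(jax.jit(cos_tf_sin_jax)(1.))  # staged: cos_tf must be compileable with XLA
```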
@@ -23,18 +23,18 @@ https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#call

"""
import logging
from typing import Callable
from typing import Callable, Sequence

import jax
from jax import core
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import util
from jax.interpreters import xla
from jax.lib import xla_bridge
from jax.lib import xla_client
from . import jax2tf as jax2tf_internal

import numpy as np
import tensorflow as tf # type: ignore[import]
@@ -86,19 +86,25 @@ def call_tf(func_tf: Callable) -> Callable:

    args_jax_flat, args_jax_treedef = tree_util.tree_flatten(args_jax)
    args_tf_sig_flat = [
        tf.TensorSpec(np.shape(a_jax), _to_tf_dtype(_dtype(a_jax)))
        tf.TensorSpec(np.shape(a_jax), jax2tf_internal._to_tf_dtype(_dtype(a_jax)))
        for a_jax in args_jax_flat
    ]
    args_tf_sig = tf.nest.map_structure(
        lambda a_jax: tf.TensorSpec(
            np.shape(a_jax), _to_tf_dtype(_dtype(a_jax))), args_jax)
            np.shape(a_jax), jax2tf_internal._to_tf_dtype(_dtype(a_jax))), args_jax)

    # Trace once through the function to get the result shape
    with jax2tf_internal.inside_call_tf():
      func_tf_concrete = tf.function(func_tf).get_concrete_function(*args_tf_sig)

    res_tf_sig_flat, res_treedef = tree_util.tree_flatten(
        func_tf_concrete.structured_outputs)

    res_jax_flat = call_tf_p.bind(
        *args_jax_flat,
        # Carry the actual function such that op-by-op call can call in TF eager mode.
        func_tf=func_tf,
        func_tf_concrete=func_tf_concrete,
        args_treedef=args_jax_treedef,
        args_tf_sig_flat=args_tf_sig_flat,
        res_treedef=res_treedef,
@@ -167,6 +173,8 @@ def _call_tf_impl(*args_jax_flat, args_treedef, func_tf, **_):
    return tf.constant(np.asarray(arg_jax))

  args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
  with jax2tf_internal.inside_call_tf():
    # Call in TF eager mode
    res_tf = func_tf(*args_treedef.unflatten(args_tf_flat))
  res_tf_flat, _ = tree_util.tree_flatten(res_tf)
  # TODO(necula): check the result for tree and aval
@@ -190,7 +198,7 @@ call_tf_p.def_impl(_call_tf_impl)

def _call_tf_abstract_eval(*_, res_tf_sig_flat, **__):
  return tuple([
      core.ShapedArray(np.shape(r), _to_jax_dtype(r.dtype))
      core.ShapedArray(np.shape(r), jax2tf_internal._to_jax_dtype(r.dtype))
      for r in res_tf_sig_flat
  ])

@@ -198,7 +206,7 @@ def _call_tf_abstract_eval(*_, res_tf_sig_flat, **__):
call_tf_p.def_abstract_eval(_call_tf_abstract_eval)


def _call_tf_translation_rule(builder, *args_op, func_tf,
def _call_tf_translation_rule(builder, *args_op, func_tf, func_tf_concrete,
                              args_treedef, args_tf_sig_flat, res_tf_sig_flat,
                              **_):
  # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
@@ -209,7 +217,7 @@ def _call_tf_translation_rule(builder, *args_op, func_tf,
  ]
  args_tf = args_treedef.unflatten(args_tf_flat)
  func_tf = tf.function(func_tf, jit_compile=True)
  func_tf_concrete = func_tf.get_concrete_function(*args_tf)
  #func_tf_concrete = func_tf.get_concrete_function(*args_tf)
  captured_ops = [] # Same order as captured_inputs
  if func_tf_concrete.captured_inputs:
    # The function uses either captured variables or tensors.
@@ -248,13 +256,15 @@ def _call_tf_translation_rule(builder, *args_op, func_tf,

xla.translations[call_tf_p] = _call_tf_translation_rule

TfVal = jax2tf_internal.TfVal
def _jax2tf_call_tf(*args: TfVal,
                    _in_avals: Sequence[core.ShapedArray],
                    _out_aval: core.ShapedArray,
                    func_tf: Callable,
                    **kwargs) -> TfVal:
  res_tf = func_tf(*args)
  res_tf_flat = tf.nest.flatten(res_tf)
  # TODO: check that the return values have the right signature
  return res_tf_flat

def _to_tf_dtype(jax_dtype):
  if jax_dtype == dtypes.float0:
    return tf.float32
  else:
    return tf.dtypes.as_dtype(jax_dtype)


def _to_jax_dtype(tf_dtype):
  return tf_dtype.as_numpy_dtype
jax2tf_internal.tf_impl_with_avals[call_tf_p] = _jax2tf_call_tf
@@ -81,6 +81,10 @@ TfVal = Any
DType = Any
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision

# A dimension environment maps dimension variables to TF expressions that
# compute the value of the dimension. These expressions refer to the TF
# function arguments.
_ShapeEnv = Dict[str, TfVal]

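For intuition, a hypothetical illustration of what such an environment holds; the concrete expression below is an assumption for illustration only, not the real construction (the actual entries are built elsewhere in this module from the TF arguments):

```python
import tensorflow as tf

# Hypothetical: for an argument declared with polymorphic shape "(b, 4)", the
# dimension variable "b" is bound to a TF expression over that argument, and
# evaluating a shape substitutes the expression for the variable.
@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def show_shape_env(arg):
  shape_env = {"b": tf.shape(arg)[0]}   # what a _ShapeEnv entry looks like
  return tf.stack([shape_env["b"], 4])  # the evaluated shape (b, 4)

print(show_shape_env(tf.ones((3, 4))))  # -> [3 4]
```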
def _is_tfval(v: TfVal) -> bool:
  if isinstance(v, (tf.Tensor, tf.Variable)):
@@ -111,12 +115,6 @@ tf_impl: Dict[core.Primitive, Callable[..., Any]] = {}
# core.AbstractValue, or a tuple thereof when primitive.multiple_results).
tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {}

# XLA is not linked in all environments; when converting a primitive, if this
# variable is disabled, we try harder to use only standard TF ops if they are
# applicable to the concrete use case; if the resulting conversion path ends up
# requiring a TFXLA operation, an exception is thrown instead.
_enable_xla = True

# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
@@ -132,17 +130,48 @@ _enable_xla = True
# also.
# We register the TensorFlow source path lazily
_has_registered_tf_source_path = False
# Whether to actually include XLA op metadata
_include_xla_op_metadata = True

class _ThreadLocalState(threading.local):
  def __init__(self):
    self.name_stack = ""
    # XLA is not linked in all environments; when converting a primitive, if this
    # variable is disabled, we try harder to use only standard TF ops if they are
    # applicable to the concrete use case; if the resulting conversion path ends up
    # requiring a TFXLA operation, an exception is thrown instead.
    self.enable_xla = True

    # Keep track if we are inside a call_tf. In that context we disable the
    # safety check that we are not inside JAX transformations.
    self.inside_call_tf = False

    # Maps dimension variables to TF expressions
    self.shape_env: _ShapeEnv = {}

    # Whether to actually include XLA op metadata in the generated TF ops
    self.include_xla_op_metadata = True

_thread_local_state = _ThreadLocalState()

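A brief sketch of the behavior that `enable_xla` (now carried in this thread-local state) controls, assuming only the public `jax2tf.convert` API; the error text follows `_xla_disabled_error` below:

```python
import numpy as np
from jax import numpy as jnp
from jax.experimental import jax2tf

# Primitives with a pure-TF lowering still convert with enable_xla=False ...
print(jax2tf.convert(jnp.sin, enable_xla=False)(np.float32(0.7)))

# ... but primitives that would need a TFXLA op raise NotImplementedError,
# e.g. lax.sort (see _sort below).
try:
  jax2tf.convert(jnp.sort, enable_xla=False)(np.arange(4, dtype=np.float32))
except NotImplementedError as e:
  print(e)  # "Call to sort cannot be converted with enable_xla=False."
```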
def _get_current_name_stack():
  return _thread_local_state.name_stack
def _xla_disabled_error(primitive_name: str,
                        extra_msg: Optional[str] = None) -> Exception:
  assert not _enable_xla
  assert not _thread_local_state.enable_xla
  msg = f"Call to {primitive_name} cannot be converted with enable_xla=False."
  if extra_msg:
    msg += f" {extra_msg}"
  return NotImplementedError(msg)

@contextlib.contextmanager
def inside_call_tf():
  # Set the inside_call_tf flag for a context.
  prev = _thread_local_state.inside_call_tf
  _thread_local_state.inside_call_tf = True
  try:
    yield
  finally:
    _thread_local_state.inside_call_tf = prev

@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun: Callable,
            *,
@@ -214,9 +243,10 @@ def convert(fun: Callable,
  name_stack = util.extend_name_stack(util.wrap_name(fun_name, "jax2tf"))
  def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal:
    # TODO: is there a better way to check if we are inside a transformation?
    if not core.trace_state_clean():
    if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
      # It is Ok to nest convert when we are inside a call_tf
      raise ValueError("convert must be used outside all JAX transformations." +
                       f"Trace state: {core.thread_local_state.trace_state}")
                       f"Trace state: {core.thread_local_state.trace_state.trace_stack}")

    def check_arg(a):
      if not _is_tfval(a):
@@ -308,16 +338,15 @@ def convert(fun: Callable,
      return in_cts

    try:
      global _shape_env
      assert not _shape_env, f"Unexpected shape environment {_shape_env}"
      global _enable_xla
      prev_enable_xla = _enable_xla
      _enable_xla = enable_xla
      global _include_xla_op_metadata
      prev_include_xla_op_metadata = _include_xla_op_metadata
      _include_xla_op_metadata = False
      assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"

      _shape_env = shapeenv
      prev_enable_xla = _thread_local_state.enable_xla
      _thread_local_state.enable_xla = enable_xla

      prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
      _thread_local_state.include_xla_op_metadata = False

      _thread_local_state.shape_env = shapeenv
      global _has_registered_tf_source_path
      if not _has_registered_tf_source_path:
        source_info_util.register_exclusion(os.path.dirname(tf.__file__))
@@ -345,9 +374,9 @@ def convert(fun: Callable,
          for o, _ in out_flat_raw
      ]
    finally:
      _shape_env = {}
      _enable_xla = prev_enable_xla
      _include_xla_op_metadata = prev_include_xla_op_metadata
      _thread_local_state.shape_env = {}
      _thread_local_state.enable_xla = prev_enable_xla
      _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata

    out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
    out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
@@ -371,15 +400,6 @@ def dtype_of_val(val: TfVal) -> DType:

# Internals

# TODO: add all globals here
class _ThreadLocalState(threading.local):
  def __init__(self):
    self.name_stack = ""
_thread_local_state = _ThreadLocalState()

def _get_current_name_stack():
  return _thread_local_state.name_stack

@contextlib.contextmanager
def _extended_name_stack(extra_name_stack: Optional[str]):
  prev_name_stack = _thread_local_state.name_stack
@@ -526,10 +546,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
    return tf.convert_to_tensor(val, dtype=conversion_dtype), jax_dtype


# A dimension environment maps dimension variables to TF expressions that
# compute the value of the dimension. These expressions refer to the TF
# function arguments.
_ShapeEnv = Dict[str, TfVal]
def _args_to_avals_and_env(
    args: Sequence[TfVal],
    arg_jax_dtypes: Sequence[DType],
@@ -573,14 +589,10 @@ def _args_to_avals_and_env(
  return avals, shapeenv


# A shape environment maps shape variables to TfVal.
_shape_env = {} # type: _ShapeEnv


def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
  assert all(map(lambda x: x is not None, shape)), (
      f"Argument shape should be a valid JAX shape but got {shape}")
  return shape_poly.eval_shape(shape, _shape_env)
  return shape_poly.eval_shape(shape, _thread_local_state.shape_env)


def shape_as_value(x):
@@ -800,7 +812,7 @@ class TensorFlowTrace(core.Trace):
      else:
        return impl(*args_tf, **params)

    if _include_xla_op_metadata:
    if _thread_local_state.include_xla_op_metadata:
      op_metadata = xla.make_op_metadata(primitive, params,
                                         name_stack=_get_current_name_stack(),
                                         source_info=source_info_util.current())
@@ -953,7 +965,6 @@ tf_not_yet_impl = [
    "lu_pivots_to_permutation",
    "rng_bit_generator",
    "xla_pmap",
    "call_tf",
]

tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
@@ -1521,7 +1532,7 @@ def _conv_general_dilated(lhs, rhs, *,
                          _out_aval: core.AbstractValue):
  """Implementation of lax.conv_general_dilated_p using XlaConv."""
  out_tf_shape = _aval_to_tf_shape(_out_aval)
  if not _enable_xla:
  if not _thread_local_state.enable_xla:
    return _try_tf_conv(
        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, batch_group_count,
@@ -1580,7 +1591,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
  """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
  if _enable_xla:
  if _thread_local_state.enable_xla:
    dnums_proto = xla_data_pb2.DotDimensionNumbers()
    dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
    dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
@@ -1723,7 +1734,7 @@ def _pad(operand, padding_value, *, padding_config,
         _out_aval: core.AbstractValue):
  del _in_avals
  low, high, interior = util.unzip3(padding_config)
  if _enable_xla:
  if _thread_local_state.enable_xla:
    out = tfxla.pad(operand, padding_value, low, high, interior)
    return out

@@ -1911,7 +1922,7 @@ def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
  """
  assert len(consts) == 0, "Reduction computation cannot have constants"

  if not _enable_xla:
  if not _thread_local_state.enable_xla:
    raise _xla_disabled_error("reduce_window")

  def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:
@@ -2029,7 +2040,7 @@ def _specialized_reduce_window(reducer,
  Returns:
    The reduced operand.
  """
  if not _enable_xla and name in ["reduce_window_max", "reduce_window_sum"]:
  if not _thread_local_state.enable_xla and name in ["reduce_window_max", "reduce_window_sum"]:
    return _try_tf_pool(name, operand, window_dimensions, window_strides,
                        padding, base_dilation, window_dilation)

@@ -2123,7 +2134,7 @@ tf_impl[lax.select_and_scatter_p] = _select_and_scatter
@partial(bool_to_int8, argnums=(0, 1))
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
                            window_strides, padding, _in_avals, _out_aval):
  if not _enable_xla:
  if not _thread_local_state.enable_xla:
    raise _xla_disabled_error("select_and_scatter_add")
  init_value = tf.zeros((), operand.dtype)
  select_fn = (
@@ -2173,7 +2184,7 @@ def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
            _in_avals, _out_aval):
  """Tensorflow implementation of gather."""
  del _in_avals, unique_indices
  if not _enable_xla:
  if not _thread_local_state.enable_xla:
    raise _xla_disabled_error("gather")
  proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
  slice_sizes_tf = _eval_shape(slice_sizes)
@@ -2209,7 +2220,7 @@ def _dynamic_slice(operand, *start_indices, slice_sizes,
  start_indices = tf.stack(start_indices)
  slice_sizes = _eval_shape(slice_sizes)

  if _enable_xla:
  if _thread_local_state.enable_xla:
    res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes)
    # TODO: implement shape inference for XlaDynamicSlice
    res.set_shape(_aval_to_tf_shape(_out_aval))
@@ -2259,7 +2270,7 @@ def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts,
  del unique_indices, _in_avals
  assert len(update_consts) == 0, "Update computation cannot have constants"

  if not _enable_xla:
  if not _thread_local_state.enable_xla:
    raise _xla_disabled_error("scatter")

  proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers)
@@ -2293,7 +2304,7 @@ tf_impl_with_avals[lax.scatter_add_p] = _scatter


def _dynamic_update_slice(operand, update, *start_indices):
  if not _enable_xla:
  if not _thread_local_state.enable_xla:
    raise _xla_disabled_error("dynamic_update_slice")
  return tfxla.dynamic_update_slice(operand, update, tf.stack(start_indices))

@@ -2428,7 +2439,7 @@ tf_impl[lax.top_k_p] = _top_k

def _sort(*operands: TfVal, dimension: int, is_stable: bool,
          num_keys: int) -> Tuple[TfVal, ...]:
  if not _enable_xla:
  if not _thread_local_state.enable_xla:
    raise _xla_disabled_error("sort")
  assert 1 <= num_keys <= len(operands)
  assert 0 <= dimension < len(
@@ -462,7 +462,7 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
    spec_ = spec.replace(" ", "")
    if spec_[0] == "(":
      if spec_[-1] != ")":
        raise ValueError(spec)
        raise ValueError(f"PolyShape '{spec}' has invalid syntax")
      spec_ = spec_[1:-1]
    spec_ = spec_.rstrip(",")
    if not spec_:
@@ -25,6 +25,7 @@ from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util

import numpy as np

@@ -58,10 +59,12 @@ class CallTfTest(jtu.JaxTestCase):
    _ = tf.add(1, 1)
    super().setUp()

  @parameterized_jit
  def test_eval_scalar_arg(self, with_jit=False):
  #@parameterized_jit
  def test_eval_scalar_arg(self, with_jit=True):
    def f_tf(x):
      return tf.math.sin(x)
    x = 3.
    res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
    res = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
    self.assertAllClose(jnp.sin(x), res, check_dtypes=False)

  @parameterized_jit
@@ -119,6 +122,16 @@ class CallTfTest(jtu.JaxTestCase):
    res = fun_jax(x, y)
    self.assertAllClose((np.float32(12.), np.float64(11.)), res)

  def test_eval_non_compileable(self):
    # Check that in op-by-op we call a function in eager mode.
    def f_tf_non_compileable(x):
      return tf.strings.length(tf.strings.format("Hello {}!", [x]))

    f_jax = jax2tf.call_tf(f_tf_non_compileable)
    x = np.float32(0.7)
    self.assertAllClose(f_tf_non_compileable(x).numpy(), f_jax(x))


  @parameterized_jit
  def test_control_flow(self, with_jit=True):

@@ -319,6 +332,89 @@ class CallTfTest(jtu.JaxTestCase):
    res = jax.pmap(fun_jax)(x)
    self.assertAllClose(np.float32(3. * (x + 2)), res)

  def test_round_trip(self):
    f_jax = jnp.sin
    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
    x = np.float32(0.7)
    self.assertAllClose(f_jax(x), f_jax_rt(x))

  def test_round_trip_custom_grad(self):
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), np.float32(3.) * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)

    f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
    x = np.float32(0.7)
    self.assertAllClose(f(x), f_rt(x))
    self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))

  def test_round_trip_shape_poly(self):
    f_jax = jnp.sin
    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax,
                                             polymorphic_shapes=["(b, ...)"]))
    x = np.array([0.7, 0.8], dtype=np.float32)
    self.assertAllClose(f_jax(x), f_jax_rt(x))

  def test_round_trip_saved_model_shape_poly(self):
    tracing_count = 0
    def f_jax(x):
      nonlocal tracing_count
      tracing_count += 1
      return jnp.sin(x)

    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
    x = np.array([0.7, 0.8], dtype=np.float32)
    res_jax = f_jax(x)
    self.assertEqual(1, tracing_count)
    # Will trace twice, it seems. Once to get the result signature, and once again
    # for the actual saving.
    restored_f = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
    self.assertGreaterEqual(tracing_count, 2)
    tracing_count = 0
    f_jax_rt = jax2tf.call_tf(restored_f)
    self.assertAllClose(res_jax, f_jax_rt(x))
    # Ensure that restored_f works at other batch size as well
    y = np.concatenate([x, x])
    self.assertEqual(0, tracing_count)
    res_jax_y = f_jax(y)
    self.assertEqual(1, tracing_count)
    # No more tracing for f_jax_rt
    self.assertAllClose(res_jax_y, f_jax_rt(y))
    self.assertEqual(1, tracing_count)

  def test_round_trip_custom_grad_saved_model(self):
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), np.float32(3.) * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)
    def g(x):
      return jnp.sum(f(x))

    g_tf = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(g, with_gradient=True, polymorphic_shapes=["b, ..."]),
        [tf.TensorSpec([None], dtype=tf.float32)])
    g_rt = jax2tf.call_tf(g_tf)
    x = np.array([0.7], dtype=np.float32)
    self.assertAllClose(g(x), g_rt(x))
    self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))

  def test_module_documentation(self):
    def cos_tf(x):
      return tf.math.cos(x)
@@ -342,6 +438,12 @@ class CallTfTest(jtu.JaxTestCase):
    print(jax.make_jaxpr(cos_tf_sin_jax)(x))
    print(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())

  def test_round_trip_reverse(self):
    f_tf = tf.math.sin
    f_tf_rt = jax2tf.convert(jax2tf.call_tf(f_tf))
    x = np.float32(0.7)
    self.assertAllClose(f_tf(x).numpy(), f_tf_rt(x).numpy())


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from absl.testing import absltest

import jax
@@ -32,15 +30,6 @@ config.parse_flags_with_absl()

class SavedModelTest(tf_test_util.JaxToTfTestCase):

  def save_and_load_model(self, model: tf.Module) -> tf.Module:
    # Roundtrip through saved model on disk.
    model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
    tf.saved_model.save(
        model, model_dir,
        options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
    restored_model = tf.saved_model.load(model_dir)
    return restored_model

  def test_eval(self):
    f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
    model = tf.Module()
@@ -50,7 +39,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
    )
    x = np.array(0.7, dtype=jnp.float32)
    self.assertAllClose(model.f(x), f_jax(x))
    restored_model = self.save_and_load_model(model)
    restored_model = tf_test_util.SaveAndLoadModel(model)
    self.assertAllClose(restored_model.f(x), f_jax(x))

  def test_gradient_disabled(self):
@@ -62,7 +51,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
        input_signature=[tf.TensorSpec([], tf.float32)])
    x = np.array(0.7, dtype=jnp.float32)
    self.assertAllClose(model.f(x), f_jax(x))
    restored_model = self.save_and_load_model(model)
    restored_model = tf_test_util.SaveAndLoadModel(model)
    xv = tf.Variable(0.7, dtype=jnp.float32)
    self.assertAllClose(restored_model.f(x), f_jax(x))

@@ -92,7 +81,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
        input_signature=[tf.TensorSpec([], tf.float32)])
    x = np.array(0.7, dtype=jnp.float32)
    self.assertAllClose(model.f(x), f_jax(x))
    restored_model = self.save_and_load_model(model)
    restored_model = tf_test_util.SaveAndLoadModel(model)
    xv = tf.Variable(0.7, dtype=jnp.float32)
    self.assertAllClose(restored_model.f(x), f_jax(x))
    with tf.GradientTape() as tape:
@@ -106,14 +95,9 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
    # JAX. We check that this information is preserved through a savedmodel
    f_tf = jax2tf.convert(f_jax)
    res = f_tf(*args)

    model = tf.Module()
    input_signature = list(tf.TensorSpec(a.shape, a.dtype) for a in args)
    model.f = tf.function(f_tf,
                          autograph=False,
                          input_signature=input_signature)
    restored_model = self.save_and_load_model(model)
    res_restored = restored_model.f(*args)
    restored_f = tf_test_util.SaveAndLoadFunction(f_tf, input_signature)
    res_restored = restored_f(*args)
    self.assertAllClose(res, res_restored)

  def test_xla_context_preserved_slice(self):
@@ -159,12 +143,9 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
        jnp.sin(np.array([3.14, 2.78], dtype=np.float16)))

    # Save and restore SavedModel
    model = tf.Module()
    model.f = tf.function(
        composed_fn,
        input_signature=[tf.TensorSpec((2,), dtype=tf.string)])
    restored_model = self.save_and_load_model(model)
    res_tf_restored = restored_model.f(x_str)
    restored_f = tf_test_util.SaveAndLoadFunction(composed_fn,
                                                  [tf.TensorSpec((2,), dtype=tf.string)])
    res_tf_restored = restored_f(x_str)
    self.assertAllClose(res_tf_restored.numpy(), res_tf.numpy())

@@ -609,6 +609,16 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
        res_jax,
        jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))

  def test_saved_model_shape_poly(self):
    f_jax = jnp.sin
    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
    x = np.array([0.7, 0.8], dtype=np.float32)
    restored_f = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
    self.assertAllClose(f_jax(x), restored_f(x))
    # Ensure that restored_f works at other batch size as well
    y = np.concatenate([x, x])
    self.assertAllClose(f_jax(y), restored_f(y))

  def test_readme_example(self):
    """Some of the examples from the README."""
    def image_mask_jax(images, mask):
@@ -15,10 +15,11 @@
import contextlib
import dataclasses
import logging
import os

from typing import Any, Callable, List, Optional, Sequence


from absl.testing import absltest
import jax
from jax import dtypes
from jax import numpy as jnp
@@ -74,6 +75,26 @@ class OpMetadataGraph:
  source_line: str


def SaveAndLoadModel(model: tf.Module) -> tf.Module:
  # Roundtrip through saved model on disk.
  model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
  tf.saved_model.save(
      model, model_dir,
      options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
  restored_model = tf.saved_model.load(model_dir)
  return restored_model

def SaveAndLoadFunction(f_tf: Callable,
                        input_signature: Sequence[tf.TensorSpec]) -> Callable:
  # Roundtrip through saved model on disk
  model = tf.Module()
  model.f = tf.function(f_tf,
                        autograph=False,
                        input_signature=input_signature)
  restored = SaveAndLoadModel(model)
  return restored.f

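A hedged usage sketch for the helper above; it simply mirrors how the tests in this commit call it (e.g. `test_saved_model_shape_poly`):

```python
import numpy as np
import tensorflow as tf
from jax import numpy as jnp
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util

f_tf = jax2tf.convert(jnp.sin, polymorphic_shapes=["(b, ...)"])
restored_f = tf_test_util.SaveAndLoadFunction(
    f_tf, [tf.TensorSpec([None], tf.float32)])
x = np.array([0.7, 0.8], dtype=np.float32)
np.testing.assert_allclose(restored_f(x).numpy(), np.sin(x), rtol=1e-6)
```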
class JaxToTfTestCase(jtu.JaxTestCase):

  def setUp(self):
@@ -344,7 +365,6 @@ class JaxToTfTestCase(jtu.JaxTestCase):

    return tree_util.tree_multimap(polymorphic_shape_to_tensorspec, polymorphic_shapes)


  def CheckOpMetadata(self, jax_fun, x,
                      expected: Sequence[OpMetadataGraph],
                      include_xla_op_metadata=True):