[jax2tf] Fix the round-trip call_tf(convert)
Also cleaned the handling of global state in jax2tf.
parent 3d1a6a308e
commit 1994f6df4a
@@ -8,10 +8,16 @@ Remember to align the itemized text with the first line of an item within a list
 PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 -->

-## jax 0.2.14 (unreleased)
+## jax 0.2.15 (unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...master).
+
+* New features:
+
+* Breaking changes:
+
+* Bug fixes:
+  * Fixed bug that prevented round-tripping from JAX to TF and back:
+    `jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`#6947`).

 ## jax 0.2.14 (June 10 2021)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
 * New features:
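For illustration, a minimal sketch of the round trip this entry refers to (it mirrors the new `test_round_trip` test added later in this diff; the sketch itself is not part of the commit):

    import numpy as np
    from jax import numpy as jnp
    from jax.experimental import jax2tf

    f_jax = jnp.sin
    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))  # JAX -> TF -> JAX
    x = np.float32(0.7)
    assert np.allclose(f_jax(x), f_jax_rt(x))  # this round trip previously failed (#6947)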
@@ -772,7 +772,7 @@ As a trivial example, consider computing ``sin(cos(1.))`` with ``sin`` done in J
 # It should return a similar result. This function will be called using
 # TensorFlow eager mode if called from outside JAX staged contexts (`jit`,
 # `pmap`, or control-flow primitives), and will be called using TensorFlow
-# graph mode otherwise. In the latter case, the function must be compileable
+# compiled mode otherwise. In the latter case, the function must be compileable
 # with XLA (`tf.function(func, jit_compile=True)`)
 def cos_tf(x):
   return tf.math.cos(x)
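As a usage note, a hedged sketch of how the documented `cos_tf` composes with JAX code (the name `cos_tf_sin_jax` follows the `test_module_documentation` test further down in this diff; the exact composition shown here is an assumption, not a quote from the README):

    import jax
    from jax import numpy as jnp
    from jax.experimental import jax2tf
    import tensorflow as tf

    def cos_tf(x):
      return tf.math.cos(x)

    # Outside jit this calls cos_tf eagerly; under jit it is compiled with XLA.
    cos_tf_sin_jax = jax.jit(lambda x: jnp.sin(jax2tf.call_tf(cos_tf)(x)))
    print(cos_tf_sin_jax(1.))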
@@ -23,18 +23,18 @@ https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#call

 """
 import logging
-from typing import Callable
+from typing import Callable, Sequence

 import jax
 from jax import core
 from jax import dlpack
-from jax import dtypes
 from jax import numpy as jnp
 from jax import tree_util
 from jax._src import util
 from jax.interpreters import xla
 from jax.lib import xla_bridge
 from jax.lib import xla_client
+from . import jax2tf as jax2tf_internal

 import numpy as np
 import tensorflow as tf  # type: ignore[import]

@@ -86,19 +86,25 @@ def call_tf(func_tf: Callable) -> Callable:

     args_jax_flat, args_jax_treedef = tree_util.tree_flatten(args_jax)
     args_tf_sig_flat = [
-        tf.TensorSpec(np.shape(a_jax), _to_tf_dtype(_dtype(a_jax)))
+        tf.TensorSpec(np.shape(a_jax), jax2tf_internal._to_tf_dtype(_dtype(a_jax)))
         for a_jax in args_jax_flat
     ]
     args_tf_sig = tf.nest.map_structure(
         lambda a_jax: tf.TensorSpec(
-            np.shape(a_jax), _to_tf_dtype(_dtype(a_jax))), args_jax)
-    func_tf_concrete = tf.function(func_tf).get_concrete_function(*args_tf_sig)
+            np.shape(a_jax), jax2tf_internal._to_tf_dtype(_dtype(a_jax))), args_jax)
+
+    # Trace once through the function to get the result shape
+    with jax2tf_internal.inside_call_tf():
+      func_tf_concrete = tf.function(func_tf).get_concrete_function(*args_tf_sig)
+
     res_tf_sig_flat, res_treedef = tree_util.tree_flatten(
         func_tf_concrete.structured_outputs)

     res_jax_flat = call_tf_p.bind(
         *args_jax_flat,
+        # Carry the actual function such that op-by-op call can call in TF eager mode.
         func_tf=func_tf,
+        func_tf_concrete=func_tf_concrete,
         args_treedef=args_jax_treedef,
         args_tf_sig_flat=args_tf_sig_flat,
         res_treedef=res_treedef,
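Aside (a standalone TensorFlow sketch, not part of the diff): the hunk above traces the wrapped function once with `tf.TensorSpec` arguments so that result shapes and dtypes are known without executing it. The trick in isolation:

    import tensorflow as tf

    def f(x):
      return tf.math.sin(x) * 2.0

    # Tracing with a TensorSpec builds a ConcreteFunction without running f;
    # its structured_outputs are symbolic tensors carrying shape and dtype.
    cf = tf.function(f).get_concrete_function(tf.TensorSpec([3], tf.float32))
    print(cf.structured_outputs)  # shape (3,), dtype float32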
@@ -167,7 +173,9 @@ def _call_tf_impl(*args_jax_flat, args_treedef, func_tf, **_):
     return tf.constant(np.asarray(arg_jax))

   args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
-  res_tf = func_tf(*args_treedef.unflatten(args_tf_flat))
+  with jax2tf_internal.inside_call_tf():
+    # Call in TF eager mode
+    res_tf = func_tf(*args_treedef.unflatten(args_tf_flat))
   res_tf_flat, _ = tree_util.tree_flatten(res_tf)
   # TODO(necula): check the result for tree and aval

@@ -190,7 +198,7 @@ call_tf_p.def_impl(_call_tf_impl)

 def _call_tf_abstract_eval(*_, res_tf_sig_flat, **__):
   return tuple([
-      core.ShapedArray(np.shape(r), _to_jax_dtype(r.dtype))
+      core.ShapedArray(np.shape(r), jax2tf_internal._to_jax_dtype(r.dtype))
       for r in res_tf_sig_flat
   ])

@@ -198,7 +206,7 @@ def _call_tf_abstract_eval(*_, res_tf_sig_flat, **__):
 call_tf_p.def_abstract_eval(_call_tf_abstract_eval)


-def _call_tf_translation_rule(builder, *args_op, func_tf,
+def _call_tf_translation_rule(builder, *args_op, func_tf, func_tf_concrete,
                               args_treedef, args_tf_sig_flat, res_tf_sig_flat,
                               **_):
   # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?

@@ -209,7 +217,7 @@ def _call_tf_translation_rule(builder, *args_op, func_tf,
   ]
   args_tf = args_treedef.unflatten(args_tf_flat)
   func_tf = tf.function(func_tf, jit_compile=True)
-  func_tf_concrete = func_tf.get_concrete_function(*args_tf)
+  #func_tf_concrete = func_tf.get_concrete_function(*args_tf)
   captured_ops = []  # Same order as captured_inputs
   if func_tf_concrete.captured_inputs:
     # The function uses either captured variables or tensors.

@@ -248,13 +256,15 @@ def _call_tf_translation_rule(builder, *args_op, func_tf,

 xla.translations[call_tf_p] = _call_tf_translation_rule

-def _to_tf_dtype(jax_dtype):
-  if jax_dtype == dtypes.float0:
-    return tf.float32
-  else:
-    return tf.dtypes.as_dtype(jax_dtype)
-
-
-def _to_jax_dtype(tf_dtype):
-  return tf_dtype.as_numpy_dtype
+TfVal = jax2tf_internal.TfVal
+def _jax2tf_call_tf(*args: TfVal,
+                    _in_avals: Sequence[core.ShapedArray],
+                    _out_aval: core.ShapedArray,
+                    func_tf: Callable,
+                    **kwargs) -> TfVal:
+  res_tf = func_tf(*args)
+  res_tf_flat = tf.nest.flatten(res_tf)
+  # TODO: check that the return values have the right signature
+  return res_tf_flat
+
+jax2tf_internal.tf_impl_with_avals[call_tf_p] = _jax2tf_call_tf
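For illustration, registering `_jax2tf_call_tf` under `tf_impl_with_avals` is what makes the reverse composition convertible; a minimal sketch mirroring the new `test_round_trip_reverse` test below (the sketch itself is not part of the commit):

    import numpy as np
    import tensorflow as tf
    from jax.experimental import jax2tf

    f_tf = tf.math.sin
    f_tf_rt = jax2tf.convert(jax2tf.call_tf(f_tf))  # TF -> JAX -> TF
    x = np.float32(0.7)
    assert np.allclose(f_tf(x).numpy(), f_tf_rt(x).numpy())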
@@ -81,6 +81,10 @@ TfVal = Any
 DType = Any
 PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision

+# A dimension environment maps dimension variables to TF expressions that
+# compute the value of the dimension. These expressions refer to the TF
+# function arguments.
+_ShapeEnv = Dict[str, TfVal]

 def _is_tfval(v: TfVal) -> bool:
   if isinstance(v, (tf.Tensor, tf.Variable)):

@@ -111,12 +115,6 @@ tf_impl: Dict[core.Primitive, Callable[..., Any]] = {}
 # core.AbstractValue, or a tuple thereof when primitive.multiple_results).
 tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {}

-# XLA is not linked in all environments; when converting a primitive, if this
-# variable is disabled, we try harder to use only standard TF ops if they are
-# applicable to the concrete use case; if the resulting conversion path ends up
-# requiring a TFXLA operation, an exception is thrown instead.
-_enable_xla = True
-
 # In order to ensure that JAX picks up the proper user-frame for source
 # locations we will register the TensorFlow source path as an internal
 # path with source_info_util. The typical stack when a JAX primitive

@@ -132,17 +130,48 @@ _enable_xla = True
 # also.
 # We register the TensorFlow source path lazily
 _has_registered_tf_source_path = False
-# Whether to actually include XLA op metadata
-_include_xla_op_metadata = True
+
+class _ThreadLocalState(threading.local):
+  def __init__(self):
+    self.name_stack = ""
+    # XLA is not linked in all environments; when converting a primitive, if this
+    # variable is disabled, we try harder to use only standard TF ops if they are
+    # applicable to the concrete use case; if the resulting conversion path ends up
+    # requiring a TFXLA operation, an exception is thrown instead.
+    self.enable_xla = True
+
+    # Keep track if we are inside a call_tf. In that context we disable the
+    # safety check that we are not inside JAX transformations.
+    self.inside_call_tf = False
+
+    # Maps dimension variables to TF expressions
+    self.shape_env: _ShapeEnv = {}
+
+    # Whether to actually include XLA op metadata in the generated TF ops
+    self.include_xla_op_metadata = True
+
+_thread_local_state = _ThreadLocalState()
+
+def _get_current_name_stack():
+  return _thread_local_state.name_stack

 def _xla_disabled_error(primitive_name: str,
                         extra_msg: Optional[str] = None) -> Exception:
-  assert not _enable_xla
+  assert not _thread_local_state.enable_xla
   msg = f"Call to {primitive_name} cannot be converted with enable_xla=False."
   if extra_msg:
     msg += f" {extra_msg}"
   return NotImplementedError(msg)

+@contextlib.contextmanager
+def inside_call_tf():
+  # Set the inside_call_tf flag for a context.
+  prev = _thread_local_state.inside_call_tf
+  _thread_local_state.inside_call_tf = True
+  try:
+    yield
+  finally:
+    _thread_local_state.inside_call_tf = prev
+
 @partial(api_util.api_hook, tag="jax2tf_convert")
 def convert(fun: Callable,
             *,
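Aside (a minimal sketch under stated assumptions, not taken from the diff): the `convert` changes below apply the usual `threading.local` save/restore discipline to these fields instead of mutating module globals. The pattern in isolation, with hypothetical names `_State` and `_with_enable_xla`:

    import threading

    class _State(threading.local):
      def __init__(self):
        self.enable_xla = True  # each thread starts from its own copy of the defaults

    _state = _State()

    def _with_enable_xla(flag, fn, *args):
      prev = _state.enable_xla
      _state.enable_xla = flag
      try:
        return fn(*args)
      finally:
        _state.enable_xla = prev  # restore even if fn raises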
@@ -214,9 +243,10 @@ def convert(fun: Callable,
   name_stack = util.extend_name_stack(util.wrap_name(fun_name, "jax2tf"))
   def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal:
     # TODO: is there a better way to check if we are inside a transformation?
-    if not core.trace_state_clean():
+    if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
+      # It is Ok to nest convert when we are inside a call_tf
       raise ValueError("convert must be used outside all JAX transformations." +
-                       f"Trace state: {core.thread_local_state.trace_state}")
+                       f"Trace state: {core.thread_local_state.trace_state.trace_stack}")

     def check_arg(a):
       if not _is_tfval(a):

@@ -308,16 +338,15 @@ def convert(fun: Callable,
       return in_cts

     try:
-      global _shape_env
-      assert not _shape_env, f"Unexpected shape environment {_shape_env}"
-      global _enable_xla
-      prev_enable_xla = _enable_xla
-      _enable_xla = enable_xla
-      global _include_xla_op_metadata
-      prev_include_xla_op_metadata = _include_xla_op_metadata
-      _include_xla_op_metadata = False
-
-      _shape_env = shapeenv
+      assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"
+
+      prev_enable_xla = _thread_local_state.enable_xla
+      _thread_local_state.enable_xla = enable_xla
+
+      prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
+      _thread_local_state.include_xla_op_metadata = False
+
+      _thread_local_state.shape_env = shapeenv
       global _has_registered_tf_source_path
       if not _has_registered_tf_source_path:
         source_info_util.register_exclusion(os.path.dirname(tf.__file__))

@@ -345,9 +374,9 @@ def convert(fun: Callable,
           for o, _ in out_flat_raw
       ]
     finally:
-      _shape_env = {}
-      _enable_xla = prev_enable_xla
-      _include_xla_op_metadata = prev_include_xla_op_metadata
+      _thread_local_state.shape_env = {}
+      _thread_local_state.enable_xla = prev_enable_xla
+      _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata

     out_flat = [tf.identity(x, "jax2tf_out") for x in out_flat]
     out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)

@@ -371,15 +400,6 @@ def dtype_of_val(val: TfVal) -> DType:

 # Internals

-# TODO: add all globals here
-class _ThreadLocalState(threading.local):
-  def __init__(self):
-    self.name_stack = ""
-
-_thread_local_state = _ThreadLocalState()
-
-def _get_current_name_stack():
-  return _thread_local_state.name_stack

 @contextlib.contextmanager
 def _extended_name_stack(extra_name_stack: Optional[str]):
   prev_name_stack = _thread_local_state.name_stack

@@ -526,10 +546,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
     return tf.convert_to_tensor(val, dtype=conversion_dtype), jax_dtype


-# A dimension environment maps dimension variables to TF expressions that
-# compute the value of the dimension. These expressions refer to the TF
-# function arguments.
-_ShapeEnv = Dict[str, TfVal]
 def _args_to_avals_and_env(
     args: Sequence[TfVal],
     arg_jax_dtypes: Sequence[DType],

@@ -573,14 +589,10 @@ def _args_to_avals_and_env(
   return avals, shapeenv


-# A shape environment maps shape variables to TfVal.
-_shape_env = {}  # type: _ShapeEnv
-
-
 def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
   assert all(map(lambda x: x is not None, shape)), (
       f"Argument shape should be a valid JAX shape but got {shape}")
-  return shape_poly.eval_shape(shape, _shape_env)
+  return shape_poly.eval_shape(shape, _thread_local_state.shape_env)


 def shape_as_value(x):

@@ -800,7 +812,7 @@ class TensorFlowTrace(core.Trace):
       else:
         return impl(*args_tf, **params)

-    if _include_xla_op_metadata:
+    if _thread_local_state.include_xla_op_metadata:
       op_metadata = xla.make_op_metadata(primitive, params,
                                          name_stack=_get_current_name_stack(),
                                          source_info=source_info_util.current())

@@ -953,7 +965,6 @@ tf_not_yet_impl = [
     "lu_pivots_to_permutation",
     "rng_bit_generator",
     "xla_pmap",
-    "call_tf",
 ]

 tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
@@ -1521,7 +1532,7 @@ def _conv_general_dilated(lhs, rhs, *,
                           _out_aval: core.AbstractValue):
   """Implementation of lax.conv_general_dilated_p using XlaConv."""
   out_tf_shape = _aval_to_tf_shape(_out_aval)
-  if not _enable_xla:
+  if not _thread_local_state.enable_xla:
     return _try_tf_conv(
         lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
         dimension_numbers, feature_group_count, batch_group_count,

@@ -1580,7 +1591,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
-  if _enable_xla:
+  if _thread_local_state.enable_xla:
     dnums_proto = xla_data_pb2.DotDimensionNumbers()
     dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
     dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)

@@ -1723,7 +1734,7 @@ def _pad(operand, padding_value, *, padding_config,
          _out_aval: core.AbstractValue):
   del _in_avals
   low, high, interior = util.unzip3(padding_config)
-  if _enable_xla:
+  if _thread_local_state.enable_xla:
     out = tfxla.pad(operand, padding_value, low, high, interior)
     return out

@@ -1911,7 +1922,7 @@ def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions,
   """
   assert len(consts) == 0, "Reduction computation cannot have constants"

-  if not _enable_xla:
+  if not _thread_local_state.enable_xla:
     raise _xla_disabled_error("reduce_window")

   def reducer(arg1: TfVal, arg2: TfVal) -> TfVal:

@@ -2029,7 +2040,7 @@ def _specialized_reduce_window(reducer,
   Returns:
     The reduced operand.
   """
-  if not _enable_xla and name in ["reduce_window_max", "reduce_window_sum"]:
+  if not _thread_local_state.enable_xla and name in ["reduce_window_max", "reduce_window_sum"]:
     return _try_tf_pool(name, operand, window_dimensions, window_strides,
                         padding, base_dilation, window_dilation)

@@ -2123,7 +2134,7 @@ tf_impl[lax.select_and_scatter_p] = _select_and_scatter
 @partial(bool_to_int8, argnums=(0, 1))
 def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
                             window_strides, padding, _in_avals, _out_aval):
-  if not _enable_xla:
+  if not _thread_local_state.enable_xla:
     raise _xla_disabled_error("select_and_scatter_add")
   init_value = tf.zeros((), operand.dtype)
   select_fn = (

@@ -2173,7 +2184,7 @@ def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
             _in_avals, _out_aval):
   """Tensorflow implementation of gather."""
   del _in_avals, unique_indices
-  if not _enable_xla:
+  if not _thread_local_state.enable_xla:
     raise _xla_disabled_error("gather")
   proto = _gather_dimensions_proto(start_indices.shape, dimension_numbers)
   slice_sizes_tf = _eval_shape(slice_sizes)

@@ -2209,7 +2220,7 @@ def _dynamic_slice(operand, *start_indices, slice_sizes,
   start_indices = tf.stack(start_indices)
   slice_sizes = _eval_shape(slice_sizes)

-  if _enable_xla:
+  if _thread_local_state.enable_xla:
     res = tfxla.dynamic_slice(operand, start_indices, size_indices=slice_sizes)
     # TODO: implement shape inference for XlaDynamicSlice
     res.set_shape(_aval_to_tf_shape(_out_aval))

@@ -2259,7 +2270,7 @@ def _scatter(operand, scatter_indices, updates, *, update_jaxpr, update_consts,
   del unique_indices, _in_avals
   assert len(update_consts) == 0, "Update computation cannot have constants"

-  if not _enable_xla:
+  if not _thread_local_state.enable_xla:
     raise _xla_disabled_error("scatter")

   proto = _scatter_dimensions_proto(scatter_indices.shape, dimension_numbers)

@@ -2293,7 +2304,7 @@ tf_impl_with_avals[lax.scatter_add_p] = _scatter


 def _dynamic_update_slice(operand, update, *start_indices):
-  if not _enable_xla:
+  if not _thread_local_state.enable_xla:
     raise _xla_disabled_error("dynamic_update_slice")
   return tfxla.dynamic_update_slice(operand, update, tf.stack(start_indices))

@@ -2428,7 +2439,7 @@ tf_impl[lax.top_k_p] = _top_k

 def _sort(*operands: TfVal, dimension: int, is_stable: bool,
           num_keys: int) -> Tuple[TfVal, ...]:
-  if not _enable_xla:
+  if not _thread_local_state.enable_xla:
     raise _xla_disabled_error("sort")
   assert 1 <= num_keys <= len(operands)
   assert 0 <= dimension < len(
@@ -462,7 +462,7 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
   spec_ = spec.replace(" ", "")
   if spec_[0] == "(":
     if spec_[-1] != ")":
-      raise ValueError(spec)
+      raise ValueError(f"PolyShape '{spec}' has invalid syntax")
     spec_ = spec_[1:-1]
   spec_ = spec_.rstrip(",")
   if not spec_:
@@ -25,6 +25,7 @@ from jax import numpy as jnp
 from jax import test_util as jtu
 from jax.config import config
 from jax.experimental import jax2tf
+from jax.experimental.jax2tf.tests import tf_test_util

 import numpy as np

@@ -58,10 +59,12 @@ class CallTfTest(jtu.JaxTestCase):
     _ = tf.add(1, 1)
     super().setUp()

-  @parameterized_jit
-  def test_eval_scalar_arg(self, with_jit=False):
+  #@parameterized_jit
+  def test_eval_scalar_arg(self, with_jit=True):
+    def f_tf(x):
+      return tf.math.sin(x)
     x = 3.
-    res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
+    res = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
     self.assertAllClose(jnp.sin(x), res, check_dtypes=False)

   @parameterized_jit

@@ -119,6 +122,16 @@ class CallTfTest(jtu.JaxTestCase):
     res = fun_jax(x, y)
     self.assertAllClose((np.float32(12.), np.float64(11.)), res)

+  def test_eval_non_compileable(self):
+    # Check that in op-by-op we call a function in eager mode.
+    def f_tf_non_compileable(x):
+      return tf.strings.length(tf.strings.format("Hello {}!", [x]))
+
+    f_jax = jax2tf.call_tf(f_tf_non_compileable)
+    x = np.float32(0.7)
+    self.assertAllClose(f_tf_non_compileable(x).numpy(), f_jax(x))
+
+
   @parameterized_jit
   def test_control_flow(self, with_jit=True):

@@ -319,6 +332,89 @@ class CallTfTest(jtu.JaxTestCase):
     res = jax.pmap(fun_jax)(x)
     self.assertAllClose(np.float32(3. * (x + 2)), res)

+  def test_round_trip(self):
+    f_jax = jnp.sin
+    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
+    x = np.float32(0.7)
+    self.assertAllClose(f_jax(x), f_jax_rt(x))
+
+  def test_round_trip_custom_grad(self):
+    @jax.custom_vjp
+    def f(x):
+      return x * x
+
+    # f_fwd: a -> (b, residual)
+    def f_fwd(x):
+      return f(x), np.float32(3.) * x
+    # f_bwd: (residual, CT b) -> [CT a]
+    def f_bwd(residual, ct_b):
+      return residual * ct_b,
+
+    f.defvjp(f_fwd, f_bwd)
+
+    f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
+    x = np.float32(0.7)
+    self.assertAllClose(f(x), f_rt(x))
+    self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))
+
+  def test_round_trip_shape_poly(self):
+    f_jax = jnp.sin
+    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax,
+                                             polymorphic_shapes=["(b, ...)"]))
+    x = np.array([0.7, 0.8], dtype=np.float32)
+    self.assertAllClose(f_jax(x), f_jax_rt(x))
+
+  def test_round_trip_saved_model_shape_poly(self):
+    tracing_count = 0
+    def f_jax(x):
+      nonlocal tracing_count
+      tracing_count += 1
+      return jnp.sin(x)
+
+    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
+    x = np.array([0.7, 0.8], dtype=np.float32)
+    res_jax = f_jax(x)
+    self.assertEqual(1, tracing_count)
+    # Will trace twice, it seems. Once to get the result signature, and once again
+    # for the actual saving.
+    restored_f = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
+    self.assertGreaterEqual(tracing_count, 2)
+    tracing_count = 0
+    f_jax_rt = jax2tf.call_tf(restored_f)
+    self.assertAllClose(res_jax, f_jax_rt(x))
+    # Ensure that restored_f works at other batch size as well
+    y = np.concatenate([x, x])
+    self.assertEqual(0, tracing_count)
+    res_jax_y = f_jax(y)
+    self.assertEqual(1, tracing_count)
+    # No more tracing for f_jax_rt
+    self.assertAllClose(res_jax_y, f_jax_rt(y))
+    self.assertEqual(1, tracing_count)
+
+  def test_round_trip_custom_grad_saved_model(self):
+    @jax.custom_vjp
+    def f(x):
+      return x * x
+
+    # f_fwd: a -> (b, residual)
+    def f_fwd(x):
+      return f(x), np.float32(3.) * x
+    # f_bwd: (residual, CT b) -> [CT a]
+    def f_bwd(residual, ct_b):
+      return residual * ct_b,
+
+    f.defvjp(f_fwd, f_bwd)
+    def g(x):
+      return jnp.sum(f(x))
+
+    g_tf = tf_test_util.SaveAndLoadFunction(
+        jax2tf.convert(g, with_gradient=True, polymorphic_shapes=["b, ..."]),
+        [tf.TensorSpec([None], dtype=tf.float32)])
+    g_rt = jax2tf.call_tf(g_tf)
+    x = np.array([0.7], dtype=np.float32)
+    self.assertAllClose(g(x), g_rt(x))
+    self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))
+
   def test_module_documentation(self):
     def cos_tf(x):
       return tf.math.cos(x)

@@ -342,6 +438,12 @@ class CallTfTest(jtu.JaxTestCase):
     print(jax.make_jaxpr(cos_tf_sin_jax)(x))
     print(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())

+  def test_round_trip_reverse(self):
+    f_tf = tf.math.sin
+    f_tf_rt = jax2tf.convert(jax2tf.call_tf(f_tf))
+    x = np.float32(0.7)
+    self.assertAllClose(f_tf(x).numpy(), f_tf_rt(x).numpy())
+

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
-
 from absl.testing import absltest

 import jax

@@ -32,15 +30,6 @@ config.parse_flags_with_absl()

 class SavedModelTest(tf_test_util.JaxToTfTestCase):

-  def save_and_load_model(self, model: tf.Module) -> tf.Module:
-    # Roundtrip through saved model on disk.
-    model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
-    tf.saved_model.save(
-        model, model_dir,
-        options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
-    restored_model = tf.saved_model.load(model_dir)
-    return restored_model
-
   def test_eval(self):
     f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
     model = tf.Module()

@@ -50,7 +39,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
     )
     x = np.array(0.7, dtype=jnp.float32)
     self.assertAllClose(model.f(x), f_jax(x))
-    restored_model = self.save_and_load_model(model)
+    restored_model = tf_test_util.SaveAndLoadModel(model)
     self.assertAllClose(restored_model.f(x), f_jax(x))

   def test_gradient_disabled(self):

@@ -62,7 +51,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
         input_signature=[tf.TensorSpec([], tf.float32)])
     x = np.array(0.7, dtype=jnp.float32)
     self.assertAllClose(model.f(x), f_jax(x))
-    restored_model = self.save_and_load_model(model)
+    restored_model = tf_test_util.SaveAndLoadModel(model)
     xv = tf.Variable(0.7, dtype=jnp.float32)
     self.assertAllClose(restored_model.f(x), f_jax(x))

@@ -92,7 +81,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
         input_signature=[tf.TensorSpec([], tf.float32)])
     x = np.array(0.7, dtype=jnp.float32)
     self.assertAllClose(model.f(x), f_jax(x))
-    restored_model = self.save_and_load_model(model)
+    restored_model = tf_test_util.SaveAndLoadModel(model)
     xv = tf.Variable(0.7, dtype=jnp.float32)
     self.assertAllClose(restored_model.f(x), f_jax(x))
     with tf.GradientTape() as tape:

@@ -106,14 +95,9 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
     # JAX. We check that this information is preserved through a savedmodel
     f_tf = jax2tf.convert(f_jax)
     res = f_tf(*args)
-
-    model = tf.Module()
     input_signature = list(tf.TensorSpec(a.shape, a.dtype) for a in args)
-    model.f = tf.function(f_tf,
-                          autograph=False,
-                          input_signature=input_signature)
-    restored_model = self.save_and_load_model(model)
-    res_restored = restored_model.f(*args)
+    restored_f = tf_test_util.SaveAndLoadFunction(f_tf, input_signature)
+    res_restored = restored_f(*args)
     self.assertAllClose(res, res_restored)

   def test_xla_context_preserved_slice(self):

@@ -159,12 +143,9 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
                         jnp.sin(np.array([3.14, 2.78], dtype=np.float16)))

     # Save and restore SavedModel
-    model = tf.Module()
-    model.f = tf.function(
-        composed_fn,
-        input_signature=[tf.TensorSpec((2,), dtype=tf.string)])
-    restored_model = self.save_and_load_model(model)
-    res_tf_restored = restored_model.f(x_str)
+    restored_f = tf_test_util.SaveAndLoadFunction(composed_fn,
+                                                  [tf.TensorSpec((2,), dtype=tf.string)])
+    res_tf_restored = restored_f(x_str)
     self.assertAllClose(res_tf_restored.numpy(), res_tf.numpy())

@@ -609,6 +609,16 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
         res_jax,
         jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))

+  def test_saved_model_shape_poly(self):
+    f_jax = jnp.sin
+    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
+    x = np.array([0.7, 0.8], dtype=np.float32)
+    restored_f = tf_test_util.SaveAndLoadFunction(f_tf, [tf.TensorSpec([None], x.dtype)])
+    self.assertAllClose(f_jax(x), restored_f(x))
+    # Ensure that restored_f works at other batch size as well
+    y = np.concatenate([x, x])
+    self.assertAllClose(f_jax(y), restored_f(y))
+
   def test_readme_example(self):
     """Some of the examples from the README."""
     def image_mask_jax(images, mask):
@@ -15,10 +15,11 @@
 import contextlib
 import dataclasses
 import logging
+import os

 from typing import Any, Callable, List, Optional, Sequence

+from absl.testing import absltest
 import jax
 from jax import dtypes
 from jax import numpy as jnp

@@ -74,6 +75,26 @@ class OpMetadataGraph:
   source_line: str


+def SaveAndLoadModel(model: tf.Module) -> tf.Module:
+  # Roundtrip through saved model on disk.
+  model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
+  tf.saved_model.save(
+      model, model_dir,
+      options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
+  restored_model = tf.saved_model.load(model_dir)
+  return restored_model
+
+def SaveAndLoadFunction(f_tf: Callable,
+                        input_signature: Sequence[tf.TensorSpec]) -> Callable:
+  # Roundtrip through saved model on disk
+  model = tf.Module()
+  model.f = tf.function(f_tf,
+                        autograph=False,
+                        input_signature=input_signature)
+  restored = SaveAndLoadModel(model)
+  return restored.f
+
+
 class JaxToTfTestCase(jtu.JaxTestCase):

   def setUp(self):

@@ -344,7 +365,6 @@ class JaxToTfTestCase(jtu.JaxTestCase):

     return tree_util.tree_multimap(polymorphic_shape_to_tensorspec, polymorphic_shapes)

-
   def CheckOpMetadata(self, jax_fun, x,
                       expected: Sequence[OpMetadataGraph],
                       include_xla_op_metadata=True):