Merge pull request #13980 from gnecula:call_tf_effects2

PiperOrigin-RevId: 502346453
jax authors 2023-01-16 03:54:11 -08:00
commit 35820ef1e4
3 changed files with 27 additions and 7 deletions

CHANGELOG.md

@@ -11,6 +11,11 @@ Remember to align the itemized text with the first line of an item within a list
 * Breaking changes
   * Deleted `jax.experimental.callback`
 
+* Changes
+  * {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`)
+    that can be used to declare whether an instance can be removed or replicated
+    by JAX optimizations such as dead-code elimination ({jax-issue}`#13980`).
+
 ## jaxlib 0.4.2
 
 ## jax 0.4.1 (Dec 13, 2022)
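
As a usage illustration (not part of this commit), here is a minimal sketch of the new parameter; the `tf.math.sin` example and the variable names below are assumptions:

import numpy as np
import tensorflow as tf
import jax
from jax.experimental import jax2tf

x = np.ones((3,), dtype=np.float32)

# Default (has_side_effects=True): the call is treated as effectful, so JAX
# will not remove or replicate it.
sin_effectful = jax2tf.call_tf(tf.math.sin)

# Opting out: a call declared side-effect free may be removed by dead-code
# elimination or replicated by JAX optimizations. (Illustrative choice; sin
# itself has no side effects.)
sin_pure = jax2tf.call_tf(tf.math.sin, has_side_effects=False)

print(jax.jit(sin_effectful)(x))
print(jax.jit(sin_pure)(x))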

jax/experimental/jax2tf/call_tf.py

@@ -59,7 +59,7 @@ TfConcreteFunction = Any
 # DLPack, if we are careful.
 _DLPACK_PLATFORMS = ("gpu",)
 
-def call_tf(callable_tf: Callable) -> Callable:
+def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
   """Calls a TensorFlow function from JAX, with support for reverse autodiff.
 
   The ``callable_tf`` will be called with TensorFlow-compatible arguments (
@@ -85,6 +85,9 @@ def call_tf(callable_tf: Callable) -> Callable:
   Args:
     callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
       arguments.
+    has_side_effects: if True then it ensures that instances of this primitive
+      are not removed or replicated by JAX optimizations such as dead-code
+      elimination.
 
   Returns: a JAX callable that can be invoked with JAX pytree arguments, in
     op-by-op mode or in a staged context. This callable can be used with
@@ -153,7 +156,8 @@ def call_tf(callable_tf: Callable) -> Callable:
         # Carry the actual function such that op-by-op call can call in TF eager mode.
         callable_flat_tf=callable_flat_tf,
         function_flat_tf=function_flat_tf,
-        args_flat_sig_tf=args_flat_sig_tf)
+        args_flat_sig_tf=args_flat_sig_tf,
+        has_side_effects=has_side_effects)
     return res_treedef.unflatten(res_jax_flat)
 
   # Define the fwd and bwd custom_vjp functions
@@ -254,6 +258,7 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc
   return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
 
 
+# Mark the effectful instances of call_tf
 CallTfEffect = enum.Enum('CallTfEffect', ['EFFECT'])
 
 mlir.lowerable_effects.add(CallTfEffect.EFFECT)
@@ -264,7 +269,8 @@ custom_derivatives.allowed_effects.add(CallTfEffect.EFFECT)
 
 
 def _call_tf_abstract_eval(*_,
                            function_flat_tf,
-                           args_flat_sig_tf, **__):
+                           args_flat_sig_tf,
+                           has_side_effects, **__):
   # Called only when we form a Jaxpr, i.e., under jit, scan, etc.
   concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
@@ -272,6 +278,7 @@ def _call_tf_abstract_eval(*_,
 
   def is_fully_known_shape(s):
     return s.rank is not None and all([d is not None for d in s])
+  effects = {CallTfEffect.EFFECT} if has_side_effects else set()
   if all([is_fully_known_shape(s)
           for s in concrete_function_flat_tf.output_shapes]):
     return (
@@ -281,7 +288,7 @@ def _call_tf_abstract_eval(*_,
             for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                                     concrete_function_flat_tf.output_shapes)
         ]),
-        {CallTfEffect.EFFECT})
+        effects)
 
   # There are some cases when TF shape inference is not powerful enough to
   # figure out the output shapes (e.g., b/128924522), even in situations where
@@ -292,9 +299,7 @@ def _call_tf_abstract_eval(*_,
   # it should not matter which platform we use.
   _, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,
                                               "CPU")
-  # Add an effect to the abstract eval rule of call_tf so that JAX's DCE pass
-  # doesn't prune args passed to call_tf.
-  return tuple(result_avals), {CallTfEffect.EFFECT}
+  return tuple(result_avals), effects
 
 
 call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
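
To make the consequence of the effect concrete, a hedged sketch (not from this commit; the TF variable, `count_and_sin`, and `f` below are illustrative assumptions) of a call whose result is discarded under `jax.jit`:

import numpy as np
import tensorflow as tf
import jax
from jax.experimental import jax2tf

# A TF function with an observable side effect (the counter update).
# These names are made up for illustration only.
counter = tf.Variable(0, dtype=tf.int32)

def count_and_sin(x):
  counter.assign_add(1)
  return tf.math.sin(x)

@jax.jit
def f(x):
  _ = jax2tf.call_tf(count_and_sin)(x)  # result unused
  return x + 1.0

# With the default has_side_effects=True, the abstract-eval rule attaches
# CallTfEffect.EFFECT to the call, so dead-code elimination must keep it even
# though its result is discarded. With has_side_effects=False the call could
# legally be pruned or replicated.
f(np.ones((3,), dtype=np.float32))

The new test below checks the same property by inspecting the unordered effects recorded in the lowering.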

jax/experimental/jax2tf/tests/call_tf_test.py

@@ -624,6 +624,16 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
     res = tf.function(f_tf2, autograph=False)(x)
     self.assertAllClose(res.numpy(), f_jax(x))
 
+  def test_effectful(self):
+    if not config.jax_array:
+      raise unittest.SkipTest("Test not intended to work without jax.Array")
+    x = np.ones((3,), dtype=np.float32)
+    lower_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=True)).lower(x)
+    self.assertNotEmpty(lower_effect._lowering.compile_args["unordered_effects"])
+
+    lower_no_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=False)).lower(x)
+    self.assertEmpty(lower_no_effect._lowering.compile_args["unordered_effects"])
+
   def test_module_documentation(self):
     def cos_tf(x):