mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13980 from gnecula:call_tf_effects2
PiperOrigin-RevId: 502346453
This commit is contained in:
commit
35820ef1e4
@ -11,6 +11,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Breaking changes
|
||||
* Deleted `jax.experimental.callback`
|
||||
|
||||
* Changes
|
||||
* {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`)
|
||||
that can be used to declare whether an instance can be removed or replicated
|
||||
by JAX optimizations such as dead-code elimination ({jax-issue}`#13980`).
|
||||
|
||||
## jaxlib 0.4.2
|
||||
|
||||
## jax 0.4.1 (Dec 13, 2022)
|
||||
|
@ -59,7 +59,7 @@ TfConcreteFunction = Any
|
||||
# DLPack, if we are careful.
|
||||
_DLPACK_PLATFORMS = ("gpu",)
|
||||
|
||||
def call_tf(callable_tf: Callable) -> Callable:
|
||||
def call_tf(callable_tf: Callable, has_side_effects=True) -> Callable:
|
||||
"""Calls a TensorFlow function from JAX, with support for reverse autodiff.
|
||||
|
||||
The ``callable_tf`` will be called with TensorFlow-compatible arguments (
|
||||
@ -85,6 +85,9 @@ def call_tf(callable_tf: Callable) -> Callable:
|
||||
Args:
|
||||
callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
|
||||
arguments.
|
||||
has_side_effects: if True then it ensures that instances of this primitive
|
||||
are not removed or replicated by JAX optimizations such as dead-code
|
||||
elimination.
|
||||
|
||||
Returns: a JAX callable that can be invoked with JAX pytree arguments, in
|
||||
op-by-op mode or in a staged context. This callable can be used with
|
||||
@ -153,7 +156,8 @@ def call_tf(callable_tf: Callable) -> Callable:
|
||||
# Carry the actual function such that op-by-op call can call in TF eager mode.
|
||||
callable_flat_tf=callable_flat_tf,
|
||||
function_flat_tf=function_flat_tf,
|
||||
args_flat_sig_tf=args_flat_sig_tf)
|
||||
args_flat_sig_tf=args_flat_sig_tf,
|
||||
has_side_effects=has_side_effects)
|
||||
return res_treedef.unflatten(res_jax_flat)
|
||||
|
||||
# Define the fwd and bwd custom_vjp functions
|
||||
@ -254,6 +258,7 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc
|
||||
return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
|
||||
|
||||
|
||||
# Mark the effectful instances of call_tf
|
||||
CallTfEffect = enum.Enum('CallTfEffect', ['EFFECT'])
|
||||
|
||||
mlir.lowerable_effects.add(CallTfEffect.EFFECT)
|
||||
@ -264,7 +269,8 @@ custom_derivatives.allowed_effects.add(CallTfEffect.EFFECT)
|
||||
|
||||
def _call_tf_abstract_eval(*_,
|
||||
function_flat_tf,
|
||||
args_flat_sig_tf, **__):
|
||||
args_flat_sig_tf,
|
||||
has_side_effects, **__):
|
||||
# Called only when we form a Jaxpr, i.e., under jit, scan, etc.
|
||||
|
||||
concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
|
||||
@ -272,6 +278,7 @@ def _call_tf_abstract_eval(*_,
|
||||
|
||||
def is_fully_known_shape(s):
|
||||
return s.rank is not None and all([d is not None for d in s])
|
||||
effects = {CallTfEffect.EFFECT} if has_side_effects else set()
|
||||
if all([is_fully_known_shape(s)
|
||||
for s in concrete_function_flat_tf.output_shapes]):
|
||||
return (
|
||||
@ -281,7 +288,7 @@ def _call_tf_abstract_eval(*_,
|
||||
for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
|
||||
concrete_function_flat_tf.output_shapes)
|
||||
]),
|
||||
{CallTfEffect.EFFECT})
|
||||
effects)
|
||||
|
||||
# There are some cases when TF shape inference is not powerful enough to
|
||||
# figure out the output shapes (e.g., b/128924522), even in situations where
|
||||
@ -292,9 +299,7 @@ def _call_tf_abstract_eval(*_,
|
||||
# it should not matter which platform we use.
|
||||
_, result_avals = _code_generator_and_avals(function_flat_tf, args_flat_sig_tf,
|
||||
"CPU")
|
||||
# Add an effect to the abstract eval rule of call_tf so that JAX's DCE pass
|
||||
# doesn't prune args passed to call_tf.
|
||||
return tuple(result_avals), {CallTfEffect.EFFECT}
|
||||
return tuple(result_avals), effects
|
||||
|
||||
call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
|
||||
|
||||
|
@ -624,6 +624,16 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
res = tf.function(f_tf2, autograph=False)(x)
|
||||
self.assertAllClose(res.numpy(), f_jax(x))
|
||||
|
||||
def test_effectful(self):
|
||||
if not config.jax_array:
|
||||
raise unittest.SkipTest("Test not intended to work without jax.Array")
|
||||
|
||||
x = np.ones((3,), dtype=np.float32)
|
||||
lower_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=True)).lower(x)
|
||||
self.assertNotEmpty(lower_effect._lowering.compile_args["unordered_effects"])
|
||||
|
||||
lower_no_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=False)).lower(x)
|
||||
self.assertEmpty(lower_no_effect._lowering.compile_args["unordered_effects"])
|
||||
|
||||
def test_module_documentation(self):
|
||||
def cos_tf(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user