Support None leaves in arguments to gradient of a call_tf wrapped function.

PiperOrigin-RevId: 662115139
This commit is contained in:
Zhuo Peng 2024-08-12 09:23:40 -07:00 committed by jax authors
parent 4eb5ef28ef
commit ad74e55dbc
2 changed files with 21 additions and 6 deletions

View File

@ -224,9 +224,11 @@ def call_tf(
def tf_vjp_fun(args_tf, ct_res_tf):
"""Invoke TF gradient."""
# TF does not like us to watch non-float vars
def replace_non_float(arg_tf):
if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex:
# TF does not like us to watch non-float vars or Nones.
def replace_non_float_or_none(arg_tf):
if arg_tf is not None and (
arg_tf.dtype.is_floating or arg_tf.dtype.is_complex
):
return arg_tf
else:
# When watched, this will be ignored. When used in results it will
@ -234,17 +236,20 @@ def call_tf(
# replace it with a float0)
return tf.zeros((), dtype=tf.float32)
watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
watched_args_tf = tf.nest.map_structure(
replace_non_float_or_none, args_tf
)
with tf.GradientTape(persistent=True) as tape:
tape.watch(watched_args_tf)
res = callable_tf(*args_tf)
tf.nest.assert_same_structure(res, ct_res_tf)
dres_darg = tape.gradient(
tf.nest.map_structure(replace_non_float, res),
tf.nest.map_structure(replace_non_float_or_none, res),
sources=watched_args_tf,
output_gradients=ct_res_tf,
unconnected_gradients=tf.UnconnectedGradients.ZERO)
unconnected_gradients=tf.UnconnectedGradients.ZERO,
)
dres_darg = tree_util.tree_map(
lambda x: x if x is None else tf.convert_to_tensor(x),

View File

@ -1136,6 +1136,16 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
# Jit mode
self.assertAllClose(jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x))
def test_grad_pytree_arg_with_none_leaf(self):
  """Gradient through call_tf works when a pytree argument has a None leaf.

  The differentiated argument is a dict with one array leaf ("y") and one
  None leaf ("other"). The expected gradient keeps the same dict structure:
  d(x * y)/dy == x for "y", and None stays None for "other".
  """
  def tf_f(x, params):
    # Only params["y"] takes part in the computation; "other" is never read.
    return x * params["y"]

  x = jnp.array(1.0)
  y = jnp.array(2.0)
  params = {"y": y, "other": None}

  # argnums is a 1-tuple, so the result is a 1-tuple holding the gradient
  # with respect to `params`.
  grad_fn = jax.grad(jax2tf.call_tf(tf_f), argnums=(1,))
  actual = grad_fn(x, params)

  expected_params_grad = {"y": x, "other": None}
  self.assertDictEqual(actual[0], expected_params_grad)
class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
"Reloading output of call_tf into TF with jax2tf."