mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Support None
leaves in arguments to gradient of a call_tf wrapped function.
PiperOrigin-RevId: 662115139
This commit is contained in:
parent
4eb5ef28ef
commit
ad74e55dbc
@ -224,9 +224,11 @@ def call_tf(
|
||||
def tf_vjp_fun(args_tf, ct_res_tf):
|
||||
"""Invoke TF gradient."""
|
||||
|
||||
# TF does not like us to watch non-float vars
|
||||
def replace_non_float(arg_tf):
|
||||
if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex:
|
||||
# TF does not like us to watch non-float vars or Nones.
|
||||
def replace_non_float_or_none(arg_tf):
|
||||
if arg_tf is not None and (
|
||||
arg_tf.dtype.is_floating or arg_tf.dtype.is_complex
|
||||
):
|
||||
return arg_tf
|
||||
else:
|
||||
# When watched, this will be ignored. When used in results it will
|
||||
@ -234,17 +236,20 @@ def call_tf(
|
||||
# replace it with a float0)
|
||||
return tf.zeros((), dtype=tf.float32)
|
||||
|
||||
watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
|
||||
watched_args_tf = tf.nest.map_structure(
|
||||
replace_non_float_or_none, args_tf
|
||||
)
|
||||
with tf.GradientTape(persistent=True) as tape:
|
||||
tape.watch(watched_args_tf)
|
||||
res = callable_tf(*args_tf)
|
||||
|
||||
tf.nest.assert_same_structure(res, ct_res_tf)
|
||||
dres_darg = tape.gradient(
|
||||
tf.nest.map_structure(replace_non_float, res),
|
||||
tf.nest.map_structure(replace_non_float_or_none, res),
|
||||
sources=watched_args_tf,
|
||||
output_gradients=ct_res_tf,
|
||||
unconnected_gradients=tf.UnconnectedGradients.ZERO)
|
||||
unconnected_gradients=tf.UnconnectedGradients.ZERO,
|
||||
)
|
||||
|
||||
dres_darg = tree_util.tree_map(
|
||||
lambda x: x if x is None else tf.convert_to_tensor(x),
|
||||
|
@ -1136,6 +1136,16 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
|
||||
# Jit mode
|
||||
self.assertAllClose(jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x))
|
||||
|
||||
def test_grad_pytree_arg_with_none_leaf(self):
|
||||
def tf_f(x, params):
|
||||
return x * params["y"]
|
||||
|
||||
x = jnp.array(1.0)
|
||||
y = jnp.array(2.0)
|
||||
actual = jax.grad(
|
||||
jax2tf.call_tf(tf_f), argnums=(1,))(x, {"y": y, "other": None})
|
||||
self.assertDictEqual(actual[0], {"y": x, "other": None})
|
||||
|
||||
|
||||
class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
"Reloading output of call_tf into TF with jax2tf."
|
||||
|
Loading…
x
Reference in New Issue
Block a user