rocm_jax/tests/host_callback_to_tf_test.py

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of using host_callback.call to invoke on the host functions
written in Tensorflow. The interesting aspect here is how we can differentiate
through the outside computation, using tf.GradientTape on the host.
This is separate from host_callback_test because it needs a TF dependency.
"""

from typing import Callable
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import config
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax.experimental import host_callback as hcb

import numpy as np

try:
  import tensorflow as tf
except ImportError:
  tf = None

config.parse_flags_with_absl()


def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):
  """The simplest implementation of calling to TF, without AD support.

  We must use hcb.call because the TF invocation must happen outside the
  JAX staged computation.
  """
  def tf_to_numpy(t):
    # Turn the Tensor to NumPy array without copying.
    return np.asarray(memoryview(t)) if isinstance(t, tf.Tensor) else t

  return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy, tf_fun(arg)),
                  arg, result_shape=result_shape)
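

# A minimal usage sketch, not exercised by the tests below. The helper name
# `_example_no_ad` and the choice of `tf.math.log1p` are illustrative only.
# The call can be wrapped in `jax.jit` because the TF function runs on the
# host via `hcb.call`, but `jax.grad` through it fails: no VJP is attached.
def _example_no_ad(x):
  return call_tf_no_ad(tf.math.log1p, x, result_shape=x)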


def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape):
  """Calls a TensorFlow function with simple support for reverse AD.

  Works only for 1st order AD and only for arguments and results being a single
  ndarray (no pytrees). Functions whose name starts with "tf_" are TensorFlow
  functions and must be called outside the JAX computation.
  """
  @jax.custom_vjp
  def make_call(arg):
    """We wrap it all in `make_call` so that we can attach custom VJP."""
    return call_tf_no_ad(tf_fun, arg, result_shape=result_shape)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(arg):
    # Return the primal argument as the residual. Use `make_call` for the
    # primal computation to enable higher-order AD.
    return make_call(arg), arg

  def make_call_vjp_bwd(res, ct_res):
    arg = res  # residual is the primal argument

    def tf_vjp_fun(arg_and_ct_res):
      """Invoke TF gradient; used with hcb.call."""
      arg, ct_res = arg_and_ct_res
      arg_var = tf.Variable(arg)
      with tf.GradientTape(persistent=True) as tape:
        res = tf_fun(arg_var)

      dres_darg = tape.gradient(res, sources=arg_var,
                                output_gradients=ct_res,
                                unconnected_gradients=tf.UnconnectedGradients.ZERO)
      return dres_darg

    return (call_tf_simple_ad(tf_vjp_fun, (arg, ct_res),
                              result_shape=arg),)

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return make_call(arg)
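

# A minimal usage sketch, not exercised by the tests below. The helper name
# `_example_simple_ad_grad` and the choice of `tf.math.tanh` are illustrative
# only. A first-order `jax.grad` flows through the TF computation because the
# custom VJP above invokes tf.GradientTape on the host.
def _example_simple_ad_grad(x):
  return jax.grad(
      lambda y: call_tf_simple_ad(tf.math.tanh, y, result_shape=y))(x)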


def call_tf_full_ad(tf_fun: Callable, arg, *, result_shape):
  """Calls a TensorFlow function with support for reverse AD.

  Supports higher-order AD and pytree arguments.
  """
  @jax.custom_vjp
  def make_call(arg):
    """We wrap it all in `make_call` so that we can attach custom VJP."""
    return call_tf_no_ad(tf_fun, arg, result_shape=result_shape)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(arg):
    return make_call(arg), arg  # Return the primal argument as the residual

  def make_call_vjp_bwd(res, ct_res):
    arg = res  # residual is the primal argument

    def tf_vjp_fun(arg_and_ct_res):
      """Invoke TF gradient; used with hcb.call."""
      arg, ct_res = arg_and_ct_res

      def make_var(a):
        return a if isinstance(a, tf.Variable) else tf.Variable(a)

      arg_var = tf.nest.map_structure(make_var, arg)
      with tf.GradientTape(persistent=True) as tape:
        res = tf_fun(arg_var)

      tf.nest.assert_same_structure(res, ct_res)
      accumulator = None  # Accumulate argument cotangent. Same structure as "arg"

      def acc_ct(res_, ct_res_):
        dres_darg = tape.gradient(res_, sources=arg_var,
                                  unconnected_gradients=tf.UnconnectedGradients.ZERO)
        tf.nest.assert_same_structure(dres_darg, arg)
        scaled_dres_darg = tf.nest.map_structure(lambda d: d * ct_res_, dres_darg)
        nonlocal accumulator
        accumulator = (scaled_dres_darg if accumulator is None
                       else tf.nest.map_structure(lambda x, y: x + y,
                                                  accumulator, scaled_dres_darg))

      tf.nest.map_structure(acc_ct, res, ct_res)
      return accumulator

    return (call_tf_full_ad(tf_vjp_fun, (arg, ct_res),
                            result_shape=arg),)

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return make_call(arg)
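

# A minimal usage sketch, not exercised by the tests below. The helper name
# `_example_full_ad_second_order` and the cubing lambda are illustrative only.
# Because `call_tf_full_ad` defines its VJP in terms of itself, nested
# `jax.grad` calls (higher-order AD) also work.
def _example_full_ad_second_order(x):
  f = lambda y: call_tf_full_ad(lambda z: z * z * z, y, result_shape=y)
  return jax.grad(jax.grad(f))(x)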


CALL_TF_IMPLEMENTATIONS = {
    "none": call_tf_no_ad,
    "simple": call_tf_simple_ad,
    "full": call_tf_full_ad,
}


class CallToTFTest(jtu.JaxTestCase):

  def setUp(self):
    if tf is None:
      raise unittest.SkipTest("Test requires tensorflow")
    if xla_bridge.using_pjrt_c_api():
      raise unittest.SkipTest("host_callback not implemented in PJRT C API")
    super().setUp()

  @parameterized.named_parameters(
      dict(
          testcase_name=f"_{ad=}",
          ad=ad)
      for ad in CALL_TF_IMPLEMENTATIONS.keys())
  def test_impl(self, ad="simple"):
    call_tf = CALL_TF_IMPLEMENTATIONS[ad]

    def f_jax(x):
      return jnp.sin(x)

    def f_outside(x):
      return call_tf(tf.math.sin, x,
                     result_shape=x)

    res = f_outside(3.)
    self.assertAllClose(f_jax(3.), res)
    self.assertAllClose(f_jax(3.), jax.jit(f_outside)(3.))

  @parameterized.named_parameters(
      dict(
          testcase_name=f"_{ad=}",
          ad=ad)
      for ad in CALL_TF_IMPLEMENTATIONS.keys()
      if ad != "none")
  def test_grad(self, ad="simple"):
    call_tf = CALL_TF_IMPLEMENTATIONS[ad]

    def f_jax(x):
      return 3. * jnp.sin(2. * x)

    def f_outside(x):
      return 3. * call_tf(tf.math.sin, 2. * x, result_shape=x)

    x = 4.
    self.assertAllClose(f_jax(x), f_outside(x))

    grad_f = jax.grad(f_outside)(x)
    self.assertAllClose(jax.grad(f_jax)(x), grad_f)

  def test_grad_pytree(self):
    call_tf = call_tf_full_ad

    def f_jax(xy):
      dict_ab = dict(a=2. * xy[0], b=xy[0] * xy[1])
      return 3. * dict_ab["a"] + 4. * dict_ab["b"]

    def f_outside(xy):
      dict_ab = call_tf(
          lambda xy: dict(a=2. * xy[0], b=xy[0] * xy[1]),
          xy,
          result_shape=dict(a=xy[0], b=xy[1]))
      return 3. * dict_ab["a"] + 4. * dict_ab["b"]

    xy = (5., 6.)
    self.assertAllClose(f_jax(xy), f_outside(xy))
    res_jax = jax.grad(f_jax)(xy)
    self.assertAllClose(res_jax, jax.grad(f_outside)(xy))

  @parameterized.named_parameters(
      dict(
          testcase_name=f"_degree={degree}",
          degree=degree)
      for degree in [1, 2, 3, 4])
  def test_higher_order_grad(self, degree=4):
    call_tf = call_tf_full_ad

    def f_jax(x):
      return 2. * x * x * x

    def f_outside(x):
      return 2. * call_tf(lambda y: y * y * y, x,
                          result_shape=x)

    grad_jax = f_jax
    grad_outside = f_outside
    for i in range(degree):
      grad_jax = jax.grad(grad_jax)
      grad_outside = jax.grad(grad_outside)

    res_jax = grad_jax(5.)
    self.assertAllClose(res_jax, grad_outside(5.))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())