# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An example of using host_callback.call to invoke on the host functions
|
|
written in Tensorflow. The interesting aspect here is how we can differentiate
|
|
through the outside computation, using tf.GradientTape on the host.
|
|
|
|
This is separate from host_callback_test because it needs a TF dependency.
|
|
"""
|
|
from typing import Callable
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax.config import config
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax.experimental import host_callback as hcb

import numpy as np

try:
  import tensorflow as tf
except ImportError:
  tf = None

config.parse_flags_with_absl()
FLAGS = config.FLAGS


def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):
  """The simplest implementation of calling into TF, without AD support.

  We must use hcb.call because the TF invocation must happen outside the
  JAX staged computation."""

  def tf_to_numpy(t):
    # Turn the Tensor into a NumPy array without copying.
    return np.asarray(memoryview(t)) if isinstance(t, tf.Tensor) else t

  return hcb.call(lambda arg: tf.nest.map_structure(tf_to_numpy,
                                                    tf_fun(arg)),
                  arg, result_shape=result_shape)

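
# A usage sketch (illustrative only, not exercised by the tests below):
# `result_shape` may also be a jax.ShapeDtypeStruct when the TF result has a
# different shape than the argument. The names `x` and `row_sums` are
# hypothetical.
#
#   x = jnp.ones((3, 4), dtype=np.float32)
#   row_sums = call_tf_no_ad(lambda t: tf.math.reduce_sum(t, axis=0), x,
#                            result_shape=jax.ShapeDtypeStruct((4,), np.float32))
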
def call_tf_simple_ad(tf_fun: Callable, arg, *, result_shape):
  """Calls a TensorFlow function with simple support for reverse AD.

  Works only for 1st order AD and only for arguments and results being a single
  ndarray (no pytrees). Functions whose names start with "tf_" are TensorFlow
  functions and must be called outside the JAX computation.
  """

  @jax.custom_vjp
  def make_call(arg):
    """We wrap it all in `make_call` so that we can attach custom VJP."""
    return call_tf_no_ad(tf_fun, arg, result_shape=result_shape)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(arg):
    # Return the primal argument as the residual. Use `make_call` for the
    # primal computation to enable higher-order AD.
    return make_call(arg), arg

  def make_call_vjp_bwd(res, ct_res):
    arg = res  # residual is the primal argument

    def tf_vjp_fun(arg_and_ct_res):
      """Invoke TF gradient; used with hcb.call."""
      arg, ct_res = arg_and_ct_res
      arg_var = tf.Variable(arg)
      with tf.GradientTape(persistent=True) as tape:
        res = tf_fun(arg_var)

      dres_darg = tape.gradient(res, sources=arg_var,
                                output_gradients=ct_res,
                                unconnected_gradients=tf.UnconnectedGradients.ZERO)
      return dres_darg

    return (call_tf_simple_ad(tf_vjp_fun, (arg, ct_res),
                              result_shape=arg),)

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return make_call(arg)

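
# A gradient sketch (illustrative; mirrors test_grad below): jax.grad triggers
# make_call_vjp_bwd above, which runs tf.GradientTape on the host, outside the
# staged JAX computation.
#
#   f = lambda x: call_tf_simple_ad(tf.math.sin, x, result_shape=x)
#   jax.grad(f)(0.5)   # approximately jnp.cos(0.5)
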
def call_tf_full_ad(tf_fun: Callable, arg, *, result_shape):
  """Calls a TensorFlow function with support for reverse AD.

  Supports higher-order AD and pytree arguments.
  """

  @jax.custom_vjp
  def make_call(arg):
    """We wrap it all in `make_call` so that we can attach custom VJP."""
    return call_tf_no_ad(tf_fun, arg, result_shape=result_shape)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(arg):
    return make_call(arg), arg  # Return the primal argument as the residual

  def make_call_vjp_bwd(res, ct_res):
    arg = res  # residual is the primal argument

    def tf_vjp_fun(arg_and_ct_res):
      """Invoke TF gradient; used with hcb.call."""
      arg, ct_res = arg_and_ct_res

      def make_var(a):
        return a if isinstance(a, tf.Variable) else tf.Variable(a)

      arg_var = tf.nest.map_structure(make_var, arg)

      with tf.GradientTape(persistent=True) as tape:
        res = tf_fun(arg_var)

      tf.nest.assert_same_structure(res, ct_res)
      accumulator = None  # Accumulate argument cotangent. Same structure as "arg"

      def acc_ct(res_, ct_res_):
        dres_darg = tape.gradient(res_, sources=arg_var,
                                  unconnected_gradients=tf.UnconnectedGradients.ZERO)
        tf.nest.assert_same_structure(dres_darg, arg)
        scaled_dres_darg = tf.nest.map_structure(lambda d: d * ct_res_, dres_darg)
        nonlocal accumulator
        accumulator = (scaled_dres_darg if accumulator is None
                       else tf.nest.map_structure(lambda x, y: x + y,
                                                  accumulator, scaled_dres_darg))

      tf.nest.map_structure(acc_ct, res, ct_res)
      return accumulator

    return (call_tf_full_ad(tf_vjp_fun, (arg, ct_res),
                            result_shape=arg),)

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return make_call(arg)

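
# A higher-order sketch (illustrative; mirrors test_higher_order_grad below):
# nested jax.grad works because the bwd rule itself goes through
# call_tf_full_ad.
#
#   f = lambda x: call_tf_full_ad(lambda t: t * t * t, x, result_shape=x)
#   jax.grad(jax.grad(f))(2.0)   # 6 * 2.0 = 12.0
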
CALL_TF_IMPLEMENTATIONS = {
    "none": call_tf_no_ad,
    "simple": call_tf_simple_ad,
    "full": call_tf_full_ad,
}

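
# The tests below pick an implementation from this table by name, e.g.
# (illustrative):
#
#   call_tf = CALL_TF_IMPLEMENTATIONS["full"]
#   call_tf(tf.math.sin, 3., result_shape=3.)
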
class CallToTFTest(jtu.JaxTestCase):

  def setUp(self):
    if tf is None:
      raise unittest.SkipTest("Test requires tensorflow")
    if xla_bridge.using_pjrt_c_api():
      raise unittest.SkipTest("host_callback not implemented in PJRT C API")
    super().setUp()

  @parameterized.named_parameters(
      dict(
          testcase_name=f"_{ad=}",
          ad=ad)
      for ad in CALL_TF_IMPLEMENTATIONS.keys())
  def test_impl(self, ad="simple"):
    call_tf = CALL_TF_IMPLEMENTATIONS[ad]

    def f_jax(x):
      return jnp.sin(x)

    def f_outside(x):
      return call_tf(tf.math.sin, x,
                     result_shape=x)

    res = f_outside(3.)
    self.assertAllClose(f_jax(3.), res)
    self.assertAllClose(f_jax(3.), jax.jit(f_outside)(3.))

  @parameterized.named_parameters(
      dict(
          testcase_name=f"_{ad=}",
          ad=ad)
      for ad in CALL_TF_IMPLEMENTATIONS.keys()
      if ad != "none")
  def test_grad(self, ad="simple"):
    call_tf = CALL_TF_IMPLEMENTATIONS[ad]

    def f_jax(x):
      return 3. * jnp.sin(2. * x)

    def f_outside(x):
      return 3. * call_tf(tf.math.sin, 2. * x, result_shape=x)

    x = 4.
    self.assertAllClose(f_jax(x), f_outside(x))

    grad_f = jax.grad(f_outside)(x)
    self.assertAllClose(jax.grad(f_jax)(x), grad_f)

  def test_grad_pytree(self):
    call_tf = call_tf_full_ad

    def f_jax(xy):
      dict_ab = dict(a=2. * xy[0], b=xy[0] * xy[1])
      return 3. * dict_ab["a"] + 4. * dict_ab["b"]

    def f_outside(xy):
      dict_ab = call_tf(
          lambda xy: dict(a=2. * xy[0], b=xy[0] * xy[1]),
          xy,
          result_shape=dict(a=xy[0], b=xy[1]))
      return 3. * dict_ab["a"] + 4. * dict_ab["b"]

    xy = (5., 6.)
    self.assertAllClose(f_jax(xy), f_outside(xy))
    res_jax = jax.grad(f_jax)(xy)
    self.assertAllClose(res_jax, jax.grad(f_outside)(xy))

  @parameterized.named_parameters(
      dict(
          testcase_name=f"_degree={degree}",
          degree=degree)
      for degree in [1, 2, 3, 4])
  def test_higher_order_grad(self, degree=4):
    call_tf = call_tf_full_ad

    def f_jax(x):
      return 2. * x * x * x

    def f_outside(x):
      return 2. * call_tf(lambda y: y * y * y, x,
                          result_shape=x)

    grad_jax = f_jax
    grad_outside = f_outside
    for _ in range(degree):
      grad_jax = jax.grad(grad_jax)
      grad_outside = jax.grad(grad_outside)

    # For f(x) = 2 x**3 the successive derivatives are 6 x**2, 12 x, 12, and 0.
    res_jax = grad_jax(5.)
    self.assertAllClose(res_jax, grad_outside(5.))

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())