# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for JAX2TF converted.
Specific JAX primitive conversion tests are in primitives_test."""
import collections
import os
from typing import Callable, Dict, Optional, Tuple
import unittest
from absl import logging
from absl.testing import absltest
import jax
from jax import ad_checkpoint
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax._src import lib as jaxlib
from jax._src import source_info_util
from jax._src import test_util as jtu
import jax._src.lib.xla_bridge
from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.pjit import FROM_GDA
from jax.experimental.pjit import pjit
from jax.interpreters import mlir
from jax.interpreters.pxla import PartitionSpec as P
import numpy as np
import tensorflow as tf # type: ignore[import]
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import
config.parse_flags_with_absl()
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
super().setUp()
# TODO(b/252943725): re-enable these tests
if config.jax_array and config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("Test disabled for JAX_ARRAY")
def test_empty(self):
f_jax = lambda x, y: x
self.ConvertAndCompare(f_jax, 0.7, 1)
def test_basics(self):
f_jax = lambda x: jnp.sin(jnp.cos(x))
self.ConvertAndCompare(f_jax, 0.7)
def test_input_output_naming(self):
@jax2tf.convert
def f(xs, y):
return [jnp.add(x, y) for x in xs]
@tf.function(autograph=False)
def u(xs, y):
xs = tf.nest.map_structure(tf.convert_to_tensor, xs)
with tf.GradientTape() as tape:
tf.nest.map_structure(tape.watch, xs)
y = f(xs, y)
tape.gradient(y, xs)
return y
cf = u.get_concrete_function([1., 2., 3.], 4.)
g = cf.graph
g.get_operation_by_name("jax2tf_arg_0")
g.get_operation_by_name("jax2tf_arg_1")
g.get_operation_by_name("jax2tf_arg_2")
g.get_operation_by_name("jax2tf_arg_3")
g.get_operation_by_name("jax2tf_out")
g.get_operation_by_name("jax2tf_out_1")
g.get_operation_by_name("jax2tf_out_2")
with self.assertRaises(KeyError):
g.get_operation_by_name("jax2tf_arg_4")
with self.assertRaises(KeyError):
g.get_operation_by_name("jax2tf_out_3")
g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_0")
g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1")
g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_2")
g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_3")
g.get_operation_by_name("jax2tf_vjp/jax2tf_out")
g.get_operation_by_name("jax2tf_vjp/jax2tf_out_1")
g.get_operation_by_name("jax2tf_vjp/jax2tf_out_2")
g.get_operation_by_name("jax2tf_vjp/jax2tf_out_3")
def test_pytrees(self):
# Take and return pytrees
def f_jax(x: Tuple[float, Dict[str, float]]) -> Tuple[float, Dict[str, float]]:
x_a, x_dict = x
return x_a * 2., {k: v * 3. for k, v in x_dict.items()}
x = (.7, {"a": .8, "b": .9})
self.ConvertAndCompare(f_jax, x)
def test_variable_input(self):
f_jax = lambda x: jnp.sin(jnp.cos(x))
f_tf = jax2tf.convert(f_jax)
v = tf.Variable(0.7, dtype=jax2tf.dtype_of_val(0.7))
self.assertIsInstance(f_tf(v), tf.Tensor)
self.assertAllClose(f_jax(0.7), f_tf(v))
def test_jit(self):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
self.ConvertAndCompare(f_jax, 0.7)
def test_nested_jit(self):
f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x)))
x = 0.7
self.ConvertAndCompare(f_jax, x)
def test_nested_jit_pytree(self):
@jax.jit
def f_jax(xy):
x, y = xy
return x + y
xy = (0.7, 0.8)
self.ConvertAndCompare(f_jax, xy)
def test_nested_jit_is_compiled(self):
# Check that nested jax.jit calls are compiled with tf.function(jit_compile=True).
# We do this by looking for the _XlaMustCompile attribute in the function graph.
def has_xla_must_compile(f_tf, x):
f_conc = tf.function(f_tf, autograph=True).get_concrete_function(tf.convert_to_tensor(x))
for n in f_conc.graph._nodes_by_id.values():
try:
n.get_attr("_XlaMustCompile")
return True
except ValueError:
continue
return False
x = np.array(0.7)
f_no_jit = lambda x: x
self.assertFalse(has_xla_must_compile(jax2tf.convert(f_no_jit), x))
f_jit = lambda x: jax.jit(jnp.sin)(x)
# TODO(b/207464757): TF compilation is disabled
self.assertFalse(has_xla_must_compile(jax2tf.convert(f_jit), x))
def test_converts_jax_arrays(self):
f_tf = tf.function(lambda x: x)
self.assertEqual(f_tf(jnp.zeros([])).numpy(), 0.)
self.assertEqual(f_tf(jnp.ones([])).numpy(), 1.)
f_tf = tf.function(lambda x: x + x)
self.assertEqual(f_tf(jnp.ones([])).numpy(), 2.)
# Test with ShardedDeviceArray.
n = jax.local_device_count()
mk_sharded = lambda f: jax.pmap(lambda x: x)(f([n]))
f_tf = tf.function(lambda x: x)
self.assertAllClose(f_tf(mk_sharded(jnp.zeros)).numpy(),
jnp.zeros([n]))
self.assertAllClose(f_tf(mk_sharded(jnp.ones)).numpy(),
jnp.ones([n]))
@jtu.skip_on_devices("gpu")
def test_bfloat16_passed_by_tf(self):
f_jax = lambda a, b: a + b
f_tf = tf.function(jax2tf.convert(f_jax),
input_signature=[tf.TensorSpec([512, 512], tf.bfloat16),
tf.TensorSpec([512, 512], tf.bfloat16)])
self.assertIsNotNone(f_tf.get_concrete_function())
@jtu.skip_on_devices("gpu")
def test_bfloat16_returned_by_jax(self):
f_jax = lambda a, b: (a + b).astype(jnp.bfloat16)
f_tf = jax2tf.convert(f_jax)
self.assertEqual(f_tf(1., 2.).dtype, tf.bfloat16)
@jtu.skip_on_devices("gpu")
def test_bfloat16_tf_grad(self):
f_jax = lambda a, b: a + b
def _tf_grad(a, b):
with tf.GradientTape() as tape:
tape.watch(a)
result = jax2tf.convert(f_jax)(a, b)
return result, tape.gradient(result, a)
f_tf = tf.function(_tf_grad,
input_signature=[tf.TensorSpec([512, 512], tf.bfloat16),
tf.TensorSpec([512, 512], tf.bfloat16)])
self.assertIsNotNone(f_tf.get_concrete_function())
@jtu.sample_product(
dtype=[np.int64, np.float64],
with_function=[True, False],
)
def test_converts_64bit(self, dtype=np.int64, with_function=False):
if not config.jax_enable_x64:
self.skipTest("requires x64 mode")
big_const = np.full((5,), 2 ** 33, dtype=dtype)
self.ConvertAndCompare(jnp.sin, big_const)
f_conv = jax2tf.convert(jnp.sin)
if with_function:
f_conv = tf.function(f_conv)
# We also check passing a tf.Variable or a tf.Tensor into the
# converted function
self.assertAllClose(jnp.sin(big_const),
f_conv(tf.Variable(big_const)))
self.assertAllClose(jnp.sin(big_const),
f_conv(tf.constant(big_const)))
def test_64bit_behavior_enable_x64(self):
if not config.jax_enable_x64:
self.skipTest("requires x64 mode")
# JAX and TF have different default float types if JAX_ENABLE_X64=1
self.assertEqual(tf.math.sin(0.7).dtype, tf.float32)
self.assertEqual(jnp.sin(0.7).dtype, jnp.float64)
# jax2tf.convert has the same behavior as JAX
self.assertEqual(jax2tf.convert(jnp.sin)(0.7).dtype, tf.float64)
def test_64bit_behavior_not_enable_x64(self):
if config.jax_enable_x64:
self.skipTest("requires not x64 mode")
# JAX and TF have the same default float types if JAX_ENABLE_X64=0
self.assertEqual(tf.math.sin(0.7).dtype, tf.float32)
self.assertEqual(jnp.sin(0.7).dtype, jnp.float32)
# Except that JAX forces values to 32-bit
self.assertEqual(jnp.sin(np.float64(0.7)).dtype, jnp.float32)
# jax2tf.convert has the same behavior as JAX
self.assertEqual(jax2tf.convert(jnp.sin)(0.7).dtype, tf.float32)
self.assertEqual(jax2tf.convert(jnp.sin)(np.float64(0.7)).dtype, tf.float32)
def test_function(self):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
self.ConvertAndCompare(f_jax, 0.7)
@jtu.sample_product(with_function=[False, True])
def test_gradients_disabled(self, with_function=False):
f_tf = jax2tf.convert(jnp.tan, with_gradient=False)
if with_function:
f_tf = tf.function(f_tf, autograph=False)
x = tf.ones([])
# With tf.function the error is raised when we evaluate f_tf(x); in
# eager mode it is raised when we evaluate tape.gradient(y, x)
with self.assertRaisesRegex(LookupError,
"Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"):
with tf.GradientTape() as tape:
tape.watch(x)
y = f_tf(x)
_ = tape.gradient(y, x)
@jtu.sample_product(with_function=[False, True])
def test_gradients(self, with_function=True):
def f(x, y):
return x * x, x * y
f_tf = jax2tf.convert(f, with_gradient=True)
if with_function:
f_tf = tf.function(f_tf, autograph=False)
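# jax2tf.dtype_of_val applies JAX's typing rules to the value, so the
# tf.Variables below are float64 when x64 mode is enabled and float32 otherwise.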
default_float_type = jax2tf.dtype_of_val(4.)
x = tf.Variable(4., dtype=default_float_type)
y = tf.Variable(5., dtype=default_float_type)
with tf.GradientTape(persistent=True) as tape:
u, v = f_tf(x, y)
self.assertAllClose(2. * 4., tape.gradient(u, x))
self.assertAllClose(0., tape.gradient(u, y))
self.assertAllClose(5., tape.gradient(v, x))
self.assertAllClose(4., tape.gradient(v, y))
@jtu.sample_product(with_function=[False, True])
def test_gradients_pytree(self, with_function=True):
def f(xy: Tuple[float, float]) -> Dict[str, float]:
x, y = xy
return dict(one=x * x, two=x * y)
f_tf = jax2tf.convert(f, with_gradient=True)
if with_function:
f_tf = tf.function(f_tf, autograph=False)
default_float_dtype = jax2tf.dtype_of_val(4.)
x = tf.Variable(4., dtype=default_float_dtype)
y = tf.Variable(5., dtype=default_float_dtype)
with tf.GradientTape(persistent=True) as tape:
uv = f_tf((x, y))
self.assertAllClose(2. * 4., tape.gradient(uv["one"], x))
self.assertAllClose(0., tape.gradient(uv["one"], y))
self.assertAllClose(5., tape.gradient(uv["two"], x))
self.assertAllClose(4., tape.gradient(uv["two"], y))
def test_custom_pytree_readme(self):
# Code examples from README.md
class CustomPair:
def __init__(self, a, b):
self.a = a
self.b = b
jax.tree_util.register_pytree_node(CustomPair,
lambda x: ((x.a, x.b), None),
lambda _, ab: CustomPair(*ab))
def f_jax(pair: CustomPair):
return np.float32(2.) * pair.a + np.float32(3.) * pair.b
f_tf = jax2tf.convert(f_jax)
x = CustomPair(np.float32(4.), np.float32(5.))
res_jax = f_jax(x)
# TF execution works as long as JAX can flatten the arguments and results
res_tf = f_tf(x)
self.assertAllClose(res_jax, res_tf.numpy())
res_tf_2 = tf.function(f_tf, autograph=False, jit_compile=True)(x)
self.assertAllClose(res_jax, res_tf_2)
# Wrap the TF function so that it takes only standard containers.
def f_tf_wrapped(a, b):
return f_tf(CustomPair(a, b))
# Try to put into SavedModel
my_model = tf.Module()
# Save a function that can take scalar inputs.
my_model.f = tf.function(f_tf_wrapped, autograph=False,
input_signature=[tf.TensorSpec([], tf.float32),
tf.TensorSpec([], tf.float32)])
model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(my_model)))
tf.saved_model.save(my_model, model_dir,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
# Restoring (note: the restored model does *not* require JAX to run, just XLA).
restored_model = tf.saved_model.load(model_dir)
def restored_f(pair: CustomPair):
return restored_model.f(pair.a, pair.b)
res_tf_3 = restored_f(x)
self.assertAllClose(res_jax, res_tf_3)
grad_jax = jax.grad(f_jax)(x)
x_v = [tf.Variable(x.a), tf.Variable(x.b)]
with tf.GradientTape() as tape:
res = f_tf_wrapped(*x_v)
grad_tf = tape.gradient(res, x_v)
self.assertAllClose(grad_jax.a, grad_tf[0])
self.assertAllClose(grad_jax.b, grad_tf[1])
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_ordered_dict_input(self, with_function=True):
def f(inputs):
out = 0.0
for v in inputs.values():
out += jnp.sum(v)
return out
f_tf = jax2tf.convert(f, with_gradient=True)
if with_function:
f_tf = tf.function(f_tf, autograph=False)
default_float_type = jax2tf.dtype_of_val(4.)
x = tf.Variable([4.], dtype=default_float_type)
y = tf.Variable([4., 5.], dtype=default_float_type)
inputs = collections.OrderedDict()
inputs['r'] = x
inputs['d'] = y
with tf.GradientTape(persistent=True) as tape:
u = f_tf(inputs)
self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy())
self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy())
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_custom_jvp(self, with_function=True):
"""Check gradients, for a function with custom JVP."""
@jax.custom_jvp
def f(x):
return x * x
@f.defjvp
def f_jvp(primals, tangents):
# The custom tangent is 3 * x * x_dot
x, = primals
x_dot, = tangents
primal_out = f(x)
tangent_out = 3. * x * x_dot
return primal_out, tangent_out
self.assertAllClose(4. * 4., f(4.))
self.assertAllClose(3. * 4., jax.grad(f)(4.))
f_tf = jax2tf.convert(f, with_gradient=True)
if with_function:
f_tf = tf.function(f_tf, autograph=False)
self.assertAllClose(4. * 4., f_tf(4.))
x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.))
with tf.GradientTape() as tape:
tape.watch(x)
y = f_tf(x)
self.assertAllClose(4. * 4., y)
self.assertAllClose(3. * 4., tape.gradient(y, x))
@jtu.sample_product(with_function=[False, True])
def test_gradients_with_custom_vjp(self, with_function=True):
"""Check gradients, for a function with custom VJP."""
@jax.custom_vjp
def f(x):
return x * x
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), 3. * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * ct_b,
f.defvjp(f_fwd, f_bwd)
self.assertAllClose(4. * 4., f(4.))
self.assertAllClose(3. * 4., jax.grad(f)(4.))
f_tf = jax2tf.convert(f, with_gradient=True)
if with_function:
f_tf = tf.function(f_tf, autograph=False)
self.assertAllClose(4. * 4., f_tf(4.))
x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.))
with tf.GradientTape() as tape:
tape.watch(x)
y = f_tf(x)
self.assertAllClose(4. * 4., y)
self.assertAllClose(3. * 4., tape.gradient(y, x))
def test_gradient_with_float0_intermediate(self):
# Gradient over integer-argument functions
def f(x, y): # x is an int, y is a float
return 2 * x + y
def g(x): # x: f32
return 2. * f(3 * x.astype("int32"), x * 4.)
x = 2.
grad_g = jax.grad(g)
self.ConvertAndCompare(grad_g, x)
def test_gradient_with_float0_result(self):
# Gradient over integer-argument functions, with float0 result
def f(x, y): # x is an int, y is a float
return 2 * x + y
def g(x): # x: i32
return jnp.sum(2. * f(3 * x, 4. * jnp.array(x, jnp.dtype("float32"))))
grad_g = jax.grad(g, allow_int=True)
x = 2
d_dx_jax = grad_g(x)
d_dx_tf = jax2tf.convert(grad_g)(x)
self.assertEqual(d_dx_jax.dtype, dtypes.float0)
self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32),
d_dx_tf.numpy())
shape = (3, 4)
x = np.ones(shape, dtype=np.int32)
d_dx_jax = grad_g(x)
d_dx_tf = jax2tf.convert(grad_g)(x)
self.assertEqual(d_dx_jax.dtype, dtypes.float0)
self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32),
d_dx_tf.numpy())
@jtu.sample_product(with_function=[False, True])
def test_gradients_unused_argument_readme(self, with_function=False):
# x1 and x3 are not used. x3 has integer type.
def fn(x0, x1, x2, x3):
return x0 * 0. + x2 * 2.
xs = [tf.Variable(x) for x in [10., 11., 12., 13]]
with tf.GradientTape(persistent=True) as tape:
res = fn(*xs)
g_tf_native = tape.gradient(res, xs)
self.assertAllClose(g_tf_native[0].numpy(), np.float32(0.))
self.assertIsNone(g_tf_native[1])
self.assertAllClose(g_tf_native[2].numpy(), np.float32(2.))
self.assertIsNone(g_tf_native[3])
g_tf_native_0 = tape.gradient(res, xs,
unconnected_gradients=tf.UnconnectedGradients.ZERO)
self.assertAllClose(g_tf_native_0[0].numpy(), np.float32(0.))
self.assertAllClose(g_tf_native_0[1].numpy(), np.float32(0.))
self.assertAllClose(g_tf_native_0[2].numpy(), np.float32(2.))
self.assertAllClose(g_tf_native_0[3].numpy(), np.int32(0))
# Now with jax2tf.convert
with tf.GradientTape(persistent=True) as tape:
conv_fn = jax2tf.convert(fn, with_gradient=True)
if with_function:
conv_fn = tf.function(conv_fn, autograph=False)
res = conv_fn(*xs)
g_jax2tf = tape.gradient(res, xs)
# Returns: 0., 0., 2., None
# Note that the gradient for x1 is 0.
self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.))
self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.))
self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
self.assertIsNone(g_jax2tf[3])
g_jax2tf = tape.gradient(res, xs,
unconnected_gradients=tf.UnconnectedGradients.ZERO)
self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.))
self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.))
self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
self.assertAllClose(g_jax2tf[3].numpy(), np.int32(0))
@jtu.sample_product(with_function=[False, True])
def test_gradients_int_argument(self, with_function=False):
# https://github.com/google/jax/issues/6975
# An expanded version of test_gradients_unused_argument_readme.
state = dict(
float_used=np.array([0.7, 0.9], dtype=np.float32),
float_passthrough=np.float16(1.),
float_unused=np.array([1.1, 2.2, 3.3], dtype=np.float32),
int_used=np.int16(5),
int_passthrough=np.int8(7),
int_unused=np.array([1, 2, 3], dtype=np.uint32),
bool_used=np.array([True, False, False, True], dtype=np.bool_),
bool_passthrough=np.array([True, False, False, True, False], dtype=np.bool_),
bool_unused=np.array([[True, False], [False, True]], dtype=np.bool_),
)
def jax_f(state):
res = dict(state,
float_used=2. * state["float_used"],
int_used=3 * state["int_used"],
bool_used=(state["bool_used"] == state["bool_used"]))
del res["float_unused"]
del res["int_unused"]
del res["bool_unused"]
return res
args = (state,)
res_jax = jax_f(*args)
# Native JAX AD
vjp_jax_fun, args_vjp = tf_test_util.TransformJaxVJP(jax_f, args, res_jax)
grad_jax, = vjp_jax_fun(*args_vjp)
def compare_with_overrides(*, what, expected, **expected_overrides):
what_keys = set(what.keys())
expected_keys = set(expected.keys())
self.assertEqual(what_keys, expected_keys)
for k, w in what.items():
e = expected[k]
if k in expected_overrides:
if expected_overrides[k] == "ZERO":
e = np.zeros_like(w)
elif expected_overrides[k] == "ZERO_INT32":
e = np.zeros(np.shape(w), dtype=np.int32)
elif expected_overrides[k] == "ONE":
e = np.ones_like(w)
else:
e = expected_overrides[k]
if e is None:
self.assertIsNone(w, msg=k)
else:
self.assertIsNotNone(w, msg=k)
w = w.numpy() if isinstance(w, tf.Tensor) else w
e = e.numpy() if isinstance(e, tf.Tensor) else e
try:
self.assertAllClose(e, w, err_msg=k)
except:
print(f"Failed at {k}")
raise
# compare_with_overrides(g_jax, {},
# bool_passthrough=np.zeros(state["bool_passthrough"].shape, dtype=dtypes.float0),
# bool_unused=np.zeros(state["bool_unused"].shape, dtype=dtypes.float0),
# bool_used=np.zeros(state["bool_used"].shape, dtype=dtypes.float0),
# float_passthrough=np.ones_like(state["float_passthrough"]),
# float_unused=np.zeros_like(state["float_unused"]),
# float_used=np.ones_like(state["float_used"]) * np.array(2., dtype=state["float_used"].dtype),
# int_passthrough=np.zeros(state["int_passthrough"].shape, dtype=dtypes.float0),
# int_unused=np.zeros(state["int_unused"].shape, dtype=dtypes.float0),
# int_used=np.zeros(state["int_used"].shape, dtype=dtypes.float0))
# Now native TF gradients, only to test how native TF AD works
_, (grad_tf_0,) = tf_test_util.ComputeTfValueAndGrad(
jax_f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO)
compare_with_overrides(what=grad_tf_0,
expected=grad_jax,
float_unused="ZERO",
bool_used="ZERO", bool_passthrough="ONE", bool_unused="ZERO",
int_used="ZERO", int_passthrough="ONE", int_unused="ZERO")
_, (grad_tf_None,) = tf_test_util.ComputeTfValueAndGrad(
jax_f, args,
unconnected_gradients=tf.UnconnectedGradients.NONE)
compare_with_overrides(what=grad_tf_None,
expected=grad_tf_0,
float_unused=None, int_used=None, int_unused=None,
bool_used=None, bool_unused=None)
f_tf_jax = jax2tf.convert(jax_f)
if with_function:
f_tf_jax = tf.function(f_tf_jax, autograph=False)
_, (grad_tf_jax_0,) = tf_test_util.ComputeTfValueAndGrad(f_tf_jax, args)
# Same results as TF native AD with tf.UnconnectedGradients.ZERO
compare_with_overrides(what=grad_tf_jax_0,
expected=grad_tf_0,
int_passthrough="ZERO", bool_passthrough="ZERO")
_, (grad_tf_jax_None,) = tf_test_util.ComputeTfValueAndGrad(
f_tf_jax, args,
unconnected_gradients=tf.UnconnectedGradients.NONE)
compare_with_overrides(what=grad_tf_jax_None,
expected=grad_tf_0,
int_used=None, int_passthrough=None, int_unused=None,
bool_unused=None, bool_used=None, bool_passthrough=None)
# Now convert the JAX gradient function
tf_vjp_jax_fun = jax2tf.convert(vjp_jax_fun)
grad_tf_vjp_jax, = tf_vjp_jax_fun(*args_vjp)
compare_with_overrides(what=grad_tf_vjp_jax,
expected=grad_tf_0,
bool_passthrough="ZERO_INT32",
bool_unused="ZERO_INT32", bool_used="ZERO_INT32",
int_passthrough="ZERO_INT32", int_unused="ZERO_INT32",
int_used="ZERO_INT32")
def test_readme_gradient_int(self):
x = np.array(2, dtype=np.int16)
def f_jax(x): # x: int16
return x.astype(np.float32) * 2.
print(jax.grad(f_jax, allow_int=True)(x))
# returns a special `float0`: array((b'',), dtype=[('float0', 'V')])
print(jax2tf.convert(jax.grad(f_jax, allow_int=True))(x))
# returns a 0 with same shape as x, but with dtype int32
def f_tf(x): # x: int16
return tf.cast(x, tf.float32) * 2.
xv = tf.Variable(x)
with tf.GradientTape(persistent=True) as tape:
print(tape.gradient(f_tf(xv), xv))
# returns None
print(tape.gradient(f_tf(xv), xv,
unconnected_gradients=tf.UnconnectedGradients.ZERO))
# returns 0 with the same shape and dtype as x
def test_convert_argument_non_callable_error(self):
with self.assertRaisesRegex(TypeError, "Expected a callable value"):
jax2tf.convert(5.)
def test_convert_argument_non_tensor_error(self):
with self.assertRaisesRegex(TypeError,
"Argument.*should be NumPy array"):
jax2tf.convert(lambda x: x)(lambda y: y)
def test_argument_eager_tensor(self):
x = jax2tf.convert(jnp.sin)(1.)
jax2tf.convert(jnp.cos)(x) # No error
def test_checkpoint_wrapper_types(self):
m = tf.Module()
m.a = [tf.Module(), tf.Module()]
m.b = (tf.Module(), tf.Module())
m.c = {'a': tf.Module(), 'b': tf.Module()}
self.assertNotEqual(type(m.a), list)
self.assertNotEqual(type(m.b), tuple)
self.assertNotEqual(type(m.c), dict)
self.assertLen(jax.tree_util.tree_leaves(m.a), 2)
self.assertLen(jax.tree_util.tree_leaves(m.b), 2)
self.assertLen(jax.tree_util.tree_leaves(m.c), 2)
def test_issue_10586(self):
class JaxModule(tf.Module):
def __init__(self):
self._params = {'w': tf.Variable(tf.ones([784, 10]), name='w'),
'b': tf.Variable(tf.ones([10]), name='b')}
def __call__(self, x):
return jax2tf.convert(lambda p, x: x @ p['w'] + p['b'])(self._params, x)
net = JaxModule()
images = tf.ones([1, 784])
with tf.GradientTape() as tape:
loss = tf.reduce_sum(net(images))
params = tape.watched_variables()
grads = tape.gradient(loss, params)
for var, grad in zip(params, grads):
self.assertEqual(var.shape, grad.shape, msg=var.name)
def test_custom_jvp(self):
"""Conversion of function with custom JVP"""
@jax.custom_jvp
def f(x):
return x * x
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
primal_out = f(x)
tangent_out = 3. * x * x_dot
return primal_out, tangent_out
arg = 0.7
self.TransformConvertAndCompare(f, arg, None)
self.TransformConvertAndCompare(f, arg, "jvp")
self.TransformConvertAndCompare(f, arg, "vmap")
self.TransformConvertAndCompare(f, arg, "jvp_vmap")
self.TransformConvertAndCompare(f, arg, "grad")
self.TransformConvertAndCompare(f, arg, "grad_vmap")
def test_custom_vjp(self):
"""Conversion of function with custom VJP"""
@jax.custom_vjp
def f(x):
return x * x
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), 3. * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * ct_b,
f.defvjp(f_fwd, f_bwd)
arg = 0.7
self.TransformConvertAndCompare(f, arg, None)
self.TransformConvertAndCompare(f, arg, "vmap")
self.TransformConvertAndCompare(f, arg, "grad")
self.TransformConvertAndCompare(f, arg, "grad_vmap")
def test_remat(self):
def f(x1):
x2 = jnp.sin(x1)
x3 = jnp.sin(x2)
x4 = jnp.sin(x3)
return x4
remat_f = ad_checkpoint.checkpoint(f)
# The gradient of remat_f computes "sin" 5 times: 3 in the forward pass,
# and 2 more to rematerialize "x2" and "x3" in the backward pass.
arg = np.array(3.)
# Check for the remat artifact in the TF graph: an optimization barrier, or
# (in older configurations) a Sin under a conditional.
f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False)
f_tf_graph = f_tf.get_concrete_function(arg).graph.as_graph_def()
if jax.config.jax_remat_opt_barrier:
if config.jax2tf_default_experimental_native_lowering:
self.assertRegex(
str(f_tf_graph), r"mhlo.optimization_barrier")
else:
self.assertRegex(
str(f_tf_graph), r"XlaOptimizationBarrier")
elif config.jax_experimental_name_stack:
self.assertRegex(str(f_tf_graph),
r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin')
else:
self.assertRegex(str(f_tf_graph),
r'switch_case/indexed_case/Sin')
def test_remat_free_var(self):
def f(x):
y = 2 * x
@ad_checkpoint.checkpoint
def g():
return y
return g()
arg = 3.
self.TransformConvertAndCompare(f, arg, None)
self.TransformConvertAndCompare(f, arg, "grad")
def test_checkpoint_name(self):
def f_jax(x):
return ad_checkpoint.checkpoint_name(jnp.sin(x), "sin")
jax2tf.convert(f_jax)(1.) # No error.
def test_convert_nullary_func(self):
# Even nullary functions are converted to TF (as opposed to constant-folded
# in JAX prior to conversion).
def f_jax():
return jnp.sin(1.)
f_tf = tf.function(jax2tf.convert(f_jax), autograph=False)
f_tf_graph = f_tf.get_concrete_function().graph.as_graph_def()
if config.jax2tf_default_experimental_native_lowering:
self.assertIn("mhlo.sine", str(f_tf_graph))
else:
self.assertIn('op: "Sin"', str(f_tf_graph))
def test_convert_of_nested_independent_jit(self):
def func(x):
def inner1(y):
return x + y
# The nested jit call has no data dependency on the argument x
return jax.jit(inner1)(1.)
jax2tf.convert(func)(2.)
def test_convert_of_nested_dependent_jit(self):
def func(x):
def inner1(y):
return x + y
# The nested jit call has a data dependency on the argument x
return jax.jit(inner1)(x)
jax2tf.convert(func)(2.) # No error
def test_nested_convert_error(self):
def outer(y):
return jax2tf.convert(jnp.sin)(y) # Inner convert takes tracer args
with self.assertRaisesRegex(
ValueError, "convert must be used outside all JAX transformations"):
jax2tf.convert(outer)(np.ones((4, )))
def test_nested_convert_error_non_tracer(self):
"""The inner convert takes non-tracer arguments"""
def outer(y):
sin_1 = jax2tf.convert(jnp.sin)(1.) # Inner convert takes non-tracer arg
return y + sin_1
with self.assertRaisesRegex(
ValueError, "convert must be used outside all JAX transformations"):
jax2tf.convert(outer)(2.)
@jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"])
def test_convert_under_transform_error(self, transform="vmap"):
def outer(y):
return jax2tf.convert(jnp.sin)(y) # Inner convert takes tracer args
with self.assertRaisesRegex(
ValueError, "convert must be used outside all JAX transformations"):
self.TransformConvertAndCompare(outer, np.ones((4,)), transform)
@jtu.sample_product(transform=["jit", "jvp", "grad", "vmap"])
def test_convert_under_transform_error_non_tracer(self, transform="vmap"):
def outer(y):
sin_1 = jax2tf.convert(jnp.sin)(1.) # Inner convert takes non-tracer arg
return y + sin_1
with self.assertRaisesRegex(
ValueError, "convert must be used outside all JAX transformations"):
self.TransformConvertAndCompare(outer, np.ones((4,)), transform)
def test_name_scope(self):
@tf.function(autograph=False)
def run():
@jax.named_call
def my_test_function(x):
return x * x
def caller(x):
return my_test_function(jnp.sin(x))
out = jax2tf.convert(caller, with_gradient=False)(2.)
return out
run_graph = run.get_concrete_function().graph.as_graph_def()
print(str(run_graph))
if config.jax2tf_default_experimental_native_lowering:
self.assertIn("my_test_function/mul", str(run_graph))
else:
self.assertIn("my_test_function/jit_fn_/Mul", str(run_graph))
def test_bfloat16_constant(self):
# Re: https://github.com/google/jax/issues/3942
def jax_fn_scalar(x):
x = x.astype(jnp.bfloat16)
x *= 2.
return x
def jax_fn_array(x):
x = x.astype(jnp.bfloat16)
x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16)
return x
tf_fn_scalar = jax2tf.convert(jax_fn_scalar)
self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750))
tf_fn_array = jax2tf.convert(jax_fn_array)
self.assertAllClose(
tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5],
jnp.bfloat16))
def test_shared_constants(self):
# Check that the constants are shared properly in converted functions
# See https://github.com/google/jax/issues/7992.
const = np.random.uniform(size=256).astype(np.float32) # A shared constant
def f(x):
return x + const + const + const + const
f_tf_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const)
self.assertEqual(f_tf_nr_consts, 1)
def test_shared_constants_under_cond(self):
# Check that the constants are shared properly in converted functions
# See https://github.com/google/jax/issues/7992.
const = np.random.uniform(size=256).astype(np.float32) # A shared constant
x = np.ones((256,), dtype=np.float32)
def f1(x):
return lax.cond(x[0] >= 0., lambda x: x + const, lambda x: x * const, x) + const
def f2(x):
return f1(x) + const # The extra const should not cost anything
f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), x)
f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), x)
self.assertEqual(f1_nr_consts, f2_nr_consts)
def test_shared_constants_under_scan(self):
# See https://github.com/google/jax/issues/7992.
const = np.random.uniform(size=256).astype(np.float32) # A shared constant
xs = np.ones((8, 256), dtype=np.float32)
def f1(xs):
res, _ = lax.scan(lambda carry, x: (carry + x + const, None),
np.zeros((256,), dtype=np.float32), xs)
return res
def f2(xs):
return f1(xs) + const # The extra const should not be saved
f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), xs)
f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), xs)
self.assertEqual(f1_nr_consts, f2_nr_consts)
def test_shared_constants_under_jit(self):
# We do not share constants under jit.
const = np.random.uniform(size=(16, 16)).astype(np.float32) # A shared constant
@jax.jit
def g_jit(x):
return x * const
def f(x):
return g_jit(x) + const + const
f_tf_graph_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const)
# TODO(b/207464757): TF compilation is disabled
self.assertEqual(f_tf_graph_nr_consts, 1)
def test_weak_types(self):
mul = jax.jit(jnp.multiply)
# The value `2` here should be weakly typed, and should not lead to
# promotion.
tf_fn = jax2tf.convert(lambda x: mul(x, 2.))
self.assertAllClose(tf_fn(tf.constant(1.375, tf.bfloat16)).numpy(),
jnp.bfloat16(2.750))
@jtu.sample_product(with_function=[False, True])
def test_kwargs(self, with_function=True):
# Re: https://github.com/google/jax/issues/6791
def f_jax(*, x):
return jnp.sum(x)
f_tf = jax2tf.convert(f_jax)
if with_function:
f_tf = tf.function(f_tf)
self.assertAllClose(
f_tf(x=np.zeros(3, dtype=np.float32)), # Call with kwargs.
np.zeros((), dtype=np.float32))
@jtu.sample_product(with_function=[False, True])
def test_grad_kwargs(self, with_function=False):
# Re: https://github.com/google/jax/issues/6791
x = (np.zeros(3, dtype=np.float32),
np.zeros(4, dtype=np.float32))
def f_jax(*, x=(1., 2.)):
return jnp.sum(x[0]) + 2. * jnp.sum(x[1])
f_tf = jax2tf.convert(f_jax)
if with_function:
f_tf = tf.function(f_tf)
xv = tf.nest.map_structure(tf.Variable, x)
with tf.GradientTape() as tape:
res = f_tf(x=xv)
grad_tf = tape.gradient(res, xv)
self.assertAllClose((np.full_like(x[0], fill_value=1.),
np.full_like(x[1], fill_value=2.)),
(grad_tf[0].numpy(), grad_tf[1].numpy()))
@jtu.skip_on_flag("jax2tf_default_experimental_native_lowering", True)
def test_enable_xla(self):
# Tests that enable_xla flag is properly scoped to a conversion.
def fun(x):
# lax.reduce is unlikely to ever be convertible with enable_xla=False
return lax.reduce(x, np.float32(0), lambda v, acc: v + acc, dimensions=(0, 1))
tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True)
tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False)
x = np.ones((2, 3), dtype=np.float32)
self.assertAllClose(fun(x), tf_fun_with_xla(x))
with self.assertRaisesRegex(NotImplementedError,
"Call to reduce cannot be converted with enable_xla=False"):
tf_fun_without_xla(x)
# Now in reverse order (we had bugs with the management of enable_xla global)
tf_fun2_without_xla = jax2tf.convert(lambda x: fun(x), enable_xla=False)
tf_fun2_with_xla = jax2tf.convert(lambda x: fun(x), enable_xla=True)
with self.assertRaisesRegex(NotImplementedError,
"Call to reduce cannot be converted with enable_xla=False"):
tf_fun2_without_xla(x)
self.assertAllClose(fun(x), tf_fun2_with_xla(x))
def test_device_array_arg(self):
self.ConvertAndCompare(jnp.sin, jnp.zeros((2, 3), jnp.float32))
def test_randint(self):
if jtu.device_under_test() == "gpu" and config.jax2tf_default_experimental_native_lowering:
raise unittest.SkipTest("randint on GPU uses custom calls; not supported")
def randint():
return jax.random.randint(
jax.random.PRNGKey(42), shape=(), minval=0, maxval=1)
self.ConvertAndCompare(randint)
def test_op_metadata_simple(self):
self.skipTest("include_xla_op_metadata not yet enabled")
# A simple example
# The user_frame is used to compute line numbers for ops in the test.
user_frame = source_info_util.user_frame(source_info_util.current())
def f_simple(x):
return jnp.sin(x)
x = np.ones((2, 3), np.float32)
self.CheckOpMetadata(
f_simple, x,
[tf_test_util.OpMetadataGraph(tf_type="Sin",
source_file=__file__,
source_line=user_frame.line_num + 2,
op_name="jax2tf(f_simple)/sin",
op_type="sin")
]
)
def test_op_metadata_sub_jit(self):
self.skipTest("include_xla_op_metadata not yet enabled")
# Calling a jitted-function
# The user_frame is used to compute line numbers for ops in the test.
user_frame = source_info_util.user_frame(source_info_util.current())
def f_callee(x):
return jnp.cos(x)
def f_caller(x):
y = jnp.tanh(x)
z = jax.jit(f_callee)(y)
return jnp.sin(z)
x = np.ones((2, 3), np.float32)
self.CheckOpMetadata(
f_caller, x,
[tf_test_util.OpMetadataGraph(tf_type="Tanh",
source_file=__file__,
source_line=user_frame.line_num + 4,
op_name="jax2tf(f_caller)/tanh",
op_type="tanh"),
tf_test_util.OpMetadataGraph(tf_type="Cos",
source_file=__file__,
source_line=user_frame.line_num + 2,
op_name="jax2tf(f_caller)/jit(f_callee)/cos",
op_type="cos"),
tf_test_util.OpMetadataGraph(tf_type="Sin",
source_file=__file__,
source_line=user_frame.line_num + 6,
op_name="jax2tf(f_caller)/sin",
op_type="sin"),
]
)
def test_op_metadata_named(self):
self.skipTest("include_xla_op_metadata not yet enabled")
# Calling a jax.named_call
# The user_frame is used to compute line numbers for ops in the test.
user_frame = source_info_util.user_frame(source_info_util.current())
def f_callee(x):
return jnp.cos(x)
def f_caller(x):
y = jnp.tanh(x)
z = jax.named_call(f_callee, name="callee")(y)
return jnp.sin(z)
x = np.ones((2, 3), np.float32)
self.CheckOpMetadata(
f_caller, x,
[tf_test_util.OpMetadataGraph(tf_type="Tanh",
source_file=__file__,
source_line=user_frame.line_num + 4,
op_name="jax2tf(f_caller)/tanh",
op_type="tanh"),
tf_test_util.OpMetadataGraph(tf_type="Cos",
source_file=__file__,
source_line=user_frame.line_num + 2,
op_name="jax2tf(f_caller)/named(callee)/cos",
op_type="cos"),
tf_test_util.OpMetadataGraph(tf_type="Sin",
source_file=__file__,
source_line=user_frame.line_num + 6,
op_name="jax2tf(f_caller)/sin",
op_type="sin"),
]
)
def test_op_metadata_while_and_cond(self):
self.skipTest("include_xla_op_metadata not yet enabled")
# An example with while and cond
# The user_frame is used to compute line numbers for ops in the test.
user_frame = source_info_util.user_frame(source_info_util.current())
def f_while_cond(x):
def body_fun(i_acc):
i, acc = i_acc
return (i + 1,
(jnp.cos(acc) +
lax.cond(jnp.mod(i, 2) == 0,
lambda acc: jnp.sin(acc),
lambda acc: acc,
acc)))
_, acc = lax.while_loop(
lambda i_acc: i_acc[0] <= 5,
body_fun, (0, x))
return acc
x = np.ones((2, 3), np.float32)
self.CheckOpMetadata(
f_while_cond, x,
[tf_test_util.OpMetadataGraph(tf_type="Cos",
source_file=__file__,
source_line=user_frame.line_num + 5,
op_name="jax2tf(f_while_cond)/while/body/cos",
op_type="cos"),
tf_test_util.OpMetadataGraph(tf_type="Sin",
source_file=__file__,
source_line=user_frame.line_num + 7,
op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin",
op_type="sin"),
tf_test_util.OpMetadataGraph(tf_type="FloorMod",
source_file=__file__,
source_line=user_frame.line_num + 6,
op_name="jax2tf(f_while_cond)/while/body/rem",
op_type="rem"),
]
)
def test_op_metadata_batched_while(self):
self.skipTest("include_xla_op_metadata not yet enabled")
# An example with a batched while loop
# The user_frame is used to compute line numbers for ops in the test.
user_frame = source_info_util.user_frame(source_info_util.current())
@jax.vmap
def f_while(x):
def body_fun(carry):
new_carry = jnp.sin(carry) # We look for "sin" in the graph
return new_carry
_, carry = lax.while_loop(
lambda carry: jnp.all(carry <= x), # We look for "le" in the graph
body_fun, x)
return carry
shape = (3, 2)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
jax_comp = jax.xla_computation(f_while)(x)
backend = jax._src.lib.xla_bridge.get_backend()
modules = backend.compile(jax_comp).hlo_modules()
jax_opt_hlo = modules[0].to_string()
print(f"JAX OPT HLO = {jax_opt_hlo}")
self.CheckOpMetadata(
f_while, x,
[tf_test_util.OpMetadataGraph(tf_type="Sin",
source_file=__file__,
source_line=user_frame.line_num + 4,
op_name="jax2tf(f_while)/while/body/sin",
op_type="sin"),
tf_test_util.OpMetadataGraph(tf_type="LessEqual",
source_file=__file__,
source_line=user_frame.line_num + 8,
op_name="jax2tf(f_while)/while/body_pred/le",
op_type="le"),
]
)
def test_op_metadata_disabled(self):
self.skipTest("include_xla_op_metadata not yet enabled")
def f_simple(x):
return jnp.sin(x)
x = np.ones((2, 3), np.float32)
self.CheckOpMetadata(
f_simple, x,
[],
include_xla_op_metadata=False
)
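# Helper for the tests below: lowers f_jax with jax.jit (or with pjit when
# use_pjit=True) and returns the lowered MHLO module as text, in the form
# expected by tfxla.call_module.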
def get_serialized_computation(
f_jax: Callable,
*args,
abstracted_axes: Optional[Tuple[Dict[int, str]]] = None,
use_pjit: bool = False,
in_axis_resources = None,
out_axis_resources = None) -> str:
if use_pjit:
assert not abstracted_axes
lowered = pjit(f_jax,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources).lower(*args)
else:
lowered = jax.jit(f_jax, abstracted_axes=abstracted_axes).lower(*args)
mhlo_module = lowered.compiler_ir(dialect='mhlo')
mhlo_module_text = mlir.module_to_string(mhlo_module)
if jaxlib.version <= (0, 3, 14):
mhlo_module_text = jax2tf.jax2tf._fixup_mhlo_module_text(mhlo_module_text)
logging.info(f'Serialized ir.Module = {mhlo_module_text}')
return mhlo_module_text
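# Illustrative sketch (not collected as a test): the typical way the text
# returned by get_serialized_computation is consumed, mirroring test_simple
# below. The helper name, input value and shape are chosen only for
# illustration.
def _example_call_module_usage():
  x = np.ones((2, 3), dtype=np.float32)
  jax_res = jnp.sin(x)  # Run once in JAX to obtain the result dtype and shape.
  return tfxla.call_module([x],
                           module=get_serialized_computation(jnp.sin, x),
                           Tout=[jax_res.dtype],
                           Sout=[jax_res.shape])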
class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
"""Unit tests for XlaCallModule. Will move these eventually to TF."""
def test_simple(self):
def f_jax(x):
return jnp.sin(x)
x = np.ones((2, 3), dtype=np.float32)
jax_res = f_jax(x)
res = tfxla.call_module([x],
module=get_serialized_computation(f_jax, x),
Tout=[jax_res.dtype],
Sout=[jax_res.shape])
self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res),
[jax_res])
def test_while(self):
# With nested computation
def f_jax(count, x):
return lax.while_loop(lambda carry: carry[0] < count, lambda carry:
(carry[0] + 1, carry[1] + 1.), (0, x))[1]
count = np.int32(5)
x = np.ones((2, 3), dtype=np.float32)
jax_res = f_jax(count, x)
res = tfxla.call_module([count, x],
module=get_serialized_computation(f_jax, count, x),
Tout=[jax_res.dtype],
Sout=[jax_res.shape])
self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res),
[jax_res])
def test_multiple_args_results(self):
def f_jax(x1, x2):
return (jnp.sin(x1), jnp.cos(x2))
x1 = np.ones((2, 3), dtype=np.float32)
x2 = np.ones((3, 4), dtype=np.float32)
jax_res = f_jax(x1, x2)
def f_tf(x1_tf, x2_tf):
return tfxla.call_module([x1_tf, x2_tf],
module=get_serialized_computation(f_jax, x1, x2),
Tout=[jax_res[0].dtype, jax_res[1].dtype],
Sout=[jax_res[0].shape, jax_res[1].shape])
res = tf.function(f_tf, jit_compile=True, autograph=False)(x1, x2)
self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res),
jax_res)
@unittest.skip("TODO(necula): 'mhlo.dynamic_iota' op can't be translated to XLA HLO")
def test_shape_poly_arange(self):
if not config.jax_dynamic_shapes:
raise unittest.SkipTest("jax_dynamic_shapes must be enabled")
def f_jax(x): # x: f32[b]
return jnp.arange(x.shape[0]) + x
x1 = np.ones((5,), dtype=np.float32)
jax_res = f_jax(x1)
def f_tf(x1_tf):
return tfxla.call_module([x1_tf],
module=get_serialized_computation(
f_jax, x1,
abstracted_axes=({
0: 'b'
},)),
Tout=[jax_res.dtype],
Sout=[jax_res.shape],
dim_args_spec=('0.0',))
res = tf.function(f_tf, jit_compile=True, autograph=False)(x1)
self.assertAllClose(
tf.nest.map_structure(lambda t: t.numpy(), res), jax_res)
# TODO(necula): figure out this failure
@jtu.skip_on_flag("jax2tf_default_experimental_native_lowering", True)
def test_global_device_array(self):
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(np.prod(global_shape)).reshape(global_shape)
return GlobalDeviceArray.from_callback(
global_shape, global_mesh, mesh_axes,
lambda idx: global_data[idx]), global_data
global_mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh_axes = P(("x", "y"))
params, _ = create_gda((8, 2), global_mesh, mesh_axes)
input_data = np.arange(16).reshape(2, 8)
# Test 1: use GDA as constants
def jax_func(input_data):
handle = pjit(
jnp.matmul,
in_axis_resources=(P("y", "x"), FROM_GDA),
out_axis_resources=None)
return handle(input_data, params)
with global_mesh:
tf_func = tf.function(
jax2tf.convert(jax_func, enable_xla=True),
jit_compile=True,
)
jax_out = jax_func(input_data=input_data)
tf_out = tf_func(input_data=input_data)
# TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
self.assertTrue(np.array_equal(jax_out._value, np.array(tf_out)))
@jtu.with_mesh([("x", 2)])
def test_pjit_basic1D(self):
def func_jax(x, y):
return x + y
shape = (8, 10)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
in_axis_resources = (P("x"), P("x"))
out_axis_resources = None
res_jax = pjit(func_jax,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources)(x, x)
module = get_serialized_computation(func_jax, x, x,
use_pjit=True,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources)
def f_tf(x_tf, y_tf):
return tfxla.call_module([x_tf, y_tf],
module=module,
Tout=[x.dtype],
Sout=[x.shape])
res_tf = tf.function(f_tf, jit_compile=True, autograph=False)(x, x)[0]
self.assertAllClose(res_tf.numpy(), res_jax)
if __name__ == "__main__":
# TODO: Remove once tensorflow is 2.10.0 everywhere.
if not hasattr(tfxla, "optimization_barrier"):
jax.config.update("jax_remat_opt_barrier", False)
absltest.main(testLoader=jtu.JaxTestLoader())