[jax2tf] Refactor tests to increase coverage (#3700)

* [jax2tf] Refactor tests to increase coverage.

This change has several goals: (a) increase the test coverage
for the elementwise primitives, (b) explicitly expose
the situations where JAX and TF produce different results,
e.g., inf vs. NaN, and (c) run all comparisons with and
without tf.function, with and without experimental_compile=True.

Previously, the test code just masked off non-finite values.

This change uncovered quite a few unimplemented cases, e.g., with
float16 and bfloat16, and conversions that cannot be compiled.
These are left as TODOs for now.
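As a sketch of the comparison pattern in (c) (hypothetical snippet, not part
of this diff; `f_jax` and `args` stand for any function and inputs under test):

    import tensorflow as tf
    from jax.experimental import jax2tf

    f_tf = jax2tf.convert(f_jax)                # TF eager
    f_fun = tf.function(f_tf, autograph=False)  # under tf.function
    f_xla = tf.function(f_tf, autograph=False,
                        experimental_compile=True)  # XLA-compiled
    # The tests assert that f_jax(*args), f_tf(*args), f_fun(*args), and
    # f_xla(*args) all agree, with known JAX/TF discrepancies handled by a
    # custom_assert callback.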

* Disable the sort tests
George Necula 2020-07-15 09:49:51 +03:00 committed by GitHub
parent 68c8dc781e
commit c380356ff0
5 changed files with 376 additions and 315 deletions


@@ -14,7 +14,6 @@
"""Tests for the jax2tf conversion for control-flow primitives."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.lax as lax
@@ -30,27 +29,19 @@ config.parse_flags_with_absl()
class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_cond(self, with_function=False):
def test_cond(self):
def f_jax(pred, x):
return lax.cond(pred, lambda t: t + 1., lambda f: f, x)
self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)
self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)
self.ConvertAndCompare(f_jax, True, 1.)
self.ConvertAndCompare(f_jax, False, 1.)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_cond_multiple_results(self, with_function=False):
def test_cond_multiple_results(self):
def f_jax(pred, x):
return lax.cond(pred, lambda t: (t + 1., 1.), lambda f: (f + 2., 2.), x)
self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)
self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)
self.ConvertAndCompare(f_jax, True, 1.)
self.ConvertAndCompare(f_jax, False, 1.)
def test_cond_partial_eval(self):
def f(x):
@@ -58,16 +49,13 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
return res
self.ConvertAndCompare(jax.grad(f), 1.)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_cond_units(self, with_function=True):
def g(x):
return lax.cond(True, lambda x: x, lambda y: y, x)
self.ConvertAndCompare(g, 0.7, with_function=with_function)
self.ConvertAndCompare(jax.grad(g), 0.7, with_function=with_function)
self.ConvertAndCompare(g, 0.7)
self.ConvertAndCompare(jax.grad(g), 0.7)
def test_cond_custom_jvp(self):
@@ -121,24 +109,17 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
self.TransformConvertAndCompare(g, arg, "vmap")
self.TransformConvertAndCompare(g, arg, "grad_vmap")
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_while_single_carry(self, with_function=False):
def test_while_single_carry(self):
"""A while with a single carry"""
def func(x):
# Equivalent to:
# for(i=x; i < 4; i++);
return lax.while_loop(lambda c: c < 4, lambda c: c + 1, x)
self.ConvertAndCompare(func, 0, with_function=with_function)
self.ConvertAndCompare(func, 0)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_while(self, with_function=False):
def test_while(self):
# Some constants to capture in the conditional branches
cond_const = np.ones(3, dtype=np.float32)
body_const1 = np.full_like(cond_const, 1.)
@@ -163,13 +144,9 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
return lax.while_loop(cond, body, (0, x))
self.ConvertAndCompare(func, cond_const, with_function=with_function)
self.ConvertAndCompare(func, cond_const)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_while_batched_cond(self, with_function=True):
def test_while_batched_cond(self):
"""A while with a single carry"""
def product(x, y):
# Equivalent to "x * y" implemented as:
@@ -190,7 +167,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
def product_xs_ys(xs, ys):
return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)
self.ConvertAndCompare(product_xs_ys, xs, ys, with_function=with_function)
self.ConvertAndCompare(product_xs_ys, xs, ys)
def test_while_custom_jvp(self):
"""Conversion of function with custom JVP, inside while.
@@ -218,11 +195,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
self.TransformConvertAndCompare(g, arg, "jvp_vmap")
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_scan(self, with_function=False):
def test_scan(self):
def f_jax(xs, ys):
body_const = np.ones((2, ), dtype=np.float32) # Test constant capture
def body(res0, inputs):
@@ -231,9 +204,9 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
return lax.scan(body, 0., (xs, ys))
arg = np.arange(10, dtype=np.float32)
self.ConvertAndCompare(f_jax, arg, arg, with_function=with_function)
self.ConvertAndCompare(f_jax, arg, arg, check_compiled=False)
def test_scan_partial_eval(self, with_function=False):
def test_scan_partial_eval(self):
def f_jax(xs, ys):
body_const = np.ones((2, ), dtype=np.float32) # Test constant capture
def body(res0, inputs):
@@ -244,7 +217,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
arg = np.arange(10, dtype=np.float32)
print(jax.make_jaxpr(jax.grad(f_jax))(arg, arg))
self.ConvertAndCompare(jax.grad(f_jax), arg, arg, with_function=with_function)
self.ConvertAndCompare(jax.grad(f_jax), arg, arg, check_compiled=False)
def test_scan_custom_jvp(self):
@@ -268,12 +241,12 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
x)[0]
arg = np.full((5,), 0.7)
self.TransformConvertAndCompare(g, arg, None)
self.TransformConvertAndCompare(g, arg, "jvp")
self.TransformConvertAndCompare(g, arg, "vmap")
self.TransformConvertAndCompare(g, arg, "jvp_vmap")
self.TransformConvertAndCompare(g, arg, "grad")
self.TransformConvertAndCompare(g, arg, "grad_vmap")
self.TransformConvertAndCompare(g, arg, None, check_compiled=False)
self.TransformConvertAndCompare(g, arg, "jvp", check_compiled=False)
self.TransformConvertAndCompare(g, arg, "vmap", check_compiled=False)
self.TransformConvertAndCompare(g, arg, "jvp_vmap", check_compiled=False)
self.TransformConvertAndCompare(g, arg, "grad", check_compiled=False)
self.TransformConvertAndCompare(g, arg, "grad_vmap", check_compiled=False)
def test_scan_custom_vjp(self):
"""Conversion of function with custom VJP, inside scan.
@@ -297,10 +270,10 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
x)[0]
arg = np.full((5,), 0.7)
self.TransformConvertAndCompare(g, arg, None)
self.TransformConvertAndCompare(g, arg, "vmap")
self.TransformConvertAndCompare(g, arg, "grad")
self.TransformConvertAndCompare(g, arg, "grad_vmap")
self.TransformConvertAndCompare(g, arg, None, check_compiled=False)
self.TransformConvertAndCompare(g, arg, "vmap", check_compiled=False)
self.TransformConvertAndCompare(g, arg, "grad", check_compiled=False)
self.TransformConvertAndCompare(g, arg, "grad_vmap", check_compiled=False)
if __name__ == "__main__":


@@ -95,10 +95,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertEqual(f_tf(1., 2.).dtype, tf.bfloat16)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_dtype={dtype.__name__}_function={with_function}",
with_function=with_function,
dict(testcase_name=f"_dtype={dtype.__name__}",
dtype=dtype)
for with_function in [False, True]
for dtype in [np.int64, np.float64]))
def test_converts_64bit(self, dtype=np.int64, with_function=False):
big_const = np.full((5,), 2 ** 33, dtype=dtype)
@@ -115,7 +113,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_function(self):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
self.ConvertAndCompare(f_jax, 0.7, with_function=True)
self.ConvertAndCompare(f_jax, 0.7)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
@@ -286,12 +284,12 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
return primal_out, tangent_out
arg = 0.7
self.TransformConvertAndCompare(f, arg, None, with_function=True)
self.TransformConvertAndCompare(f, arg, "jvp", with_function=True)
self.TransformConvertAndCompare(f, arg, "vmap", with_function=True)
self.TransformConvertAndCompare(f, arg, "jvp_vmap", with_function=True)
self.TransformConvertAndCompare(f, arg, "grad", with_function=True)
self.TransformConvertAndCompare(f, arg, "grad_vmap", with_function=True)
self.TransformConvertAndCompare(f, arg, None)
self.TransformConvertAndCompare(f, arg, "jvp")
self.TransformConvertAndCompare(f, arg, "vmap")
self.TransformConvertAndCompare(f, arg, "jvp_vmap")
self.TransformConvertAndCompare(f, arg, "grad")
self.TransformConvertAndCompare(f, arg, "grad_vmap")
def test_custom_vjp(self):
"""Conversion of function with custom VJP"""
@@ -308,10 +306,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f.defvjp(f_fwd, f_bwd)
arg = 0.7
self.TransformConvertAndCompare(f, arg, None, with_function=True)
self.TransformConvertAndCompare(f, arg, "vmap", with_function=True)
self.TransformConvertAndCompare(f, arg, "grad", with_function=True)
self.TransformConvertAndCompare(f, arg, "grad_vmap", with_function=True)
self.TransformConvertAndCompare(f, arg, None)
self.TransformConvertAndCompare(f, arg, "vmap")
self.TransformConvertAndCompare(f, arg, "grad")
self.TransformConvertAndCompare(f, arg, "grad_vmap")
if __name__ == "__main__":


@@ -33,19 +33,44 @@ FLAGS = config.FLAGS
Rng = Any # A random number generator
class RandArg(NamedTuple):
"""Descriptor for a randomly generated argument."""
"""Descriptor for a randomly generated argument.
See description of `Harness`.
"""
shape: Tuple[int, ...]
dtype: np.dtype
class StaticArg(NamedTuple):
"""Descriptor for a static argument."""
"""Descriptor for a static argument.
See description of `Harness`.
"""
value: Any
class Harness:
"""Specifies inputs and callable for a primitive.
A primitive can take dynamic and static arguments. The dynamic arguments can
be generated using an RNG and are numeric (hence appropriate for JIT).
A harness is conceptually a callable and a list of arguments that together
exercise a use case. The harness can optionally have additional parameters
that can be used by the test.
The arguments are specified through argument descriptors. An argument
descriptor can be:
* a numeric value or ndarray, or
* an instance of ``RandArg(shape, dtype)`` to be used with a PRNG to generate
a random tensor of the given shape and type, or
* an instance of ``StaticArg(value)``. These are values that specialize the
callable, but are not exposed as external arguments.
For example, a harness for ``lax.take(arr, indices, axis=None)`` may want
to expose the array and the indices as external (dynamic) arguments, and
keep the axis as a static argument (technically specializing `take` to
a given axis):
Harness(f"take_axis={axis}",
lax.take,
[RandArg((2, 4), np.float32), np.array([-1, 0, 1]), StaticArg(axis)],
axis=axis)
"""
# Descriptive name of the harness, used as a testcase_name. Unique in a group.
name: str
@@ -124,9 +149,117 @@ def parameterized(harness_group: Iterable[Harness],
cases = cases[0:1]
return testing.parameterized.named_parameters(*cases)
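# Example usage (a sketch mirroring the actual tests in the primitives test
# file, which consume a harness group via the `parameterized` helper above):
#
#   @primitive_harness.parameterized(primitive_harness.lax_pad)
#   def test_pad(self, harness: primitive_harness.Harness):
#     self.ConvertAndCompare(harness.dyn_fun,
#                            *harness.dyn_args_maker(self.rng()))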
### Harness definitions ###
###
_LAX_UNARY_ELEMENTWISE = (
lax.abs, lax.acosh, lax.asinh, lax.atanh, lax.bessel_i0e, lax.bessel_i1e,
lax.ceil, lax.cos, lax.cosh, lax.digamma, lax.erf, lax.erf_inv, lax.erfc,
lax.exp, lax.expm1, lax.floor, lax.is_finite, lax.lgamma, lax.log,
lax.log1p, lax.neg, lax.round, lax.rsqrt, lax.sign, lax.sin, lax.sinh,
lax.sqrt, lax.tan, lax.tanh)
lax_unary_elementwise = tuple(
Harness(f"{f_lax.__name__}_{jtu.dtype_str(dtype)}",
f_lax,
[arg],
lax_name=f_lax.__name__,
dtype=dtype)
for f_lax in _LAX_UNARY_ELEMENTWISE
for dtype in jtu.dtypes.all_floating
for arg in [
np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1., 1.4, 1.6], dtype=dtype)
]
)
lax_bitwise_not = tuple(
[Harness(f"{jtu.dtype_str(dtype)}",
lax.bitwise_not,
[arg],
dtype=dtype)
for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned
for arg in [
np.array([-1, -3, -2, 0, 0, 2, 1, 3], dtype=dtype),
]] +
[Harness("bool",
f_lax,
[arg],
lax_name=f_lax.__name__,
dtype=np.bool_)
for f_lax in [lax.bitwise_not]
for arg in [
np.array([True, False])
]]
)
_LAX_BINARY_ELEMENTWISE = (
lax.add, lax.atan2, lax.div, lax.igamma, lax.igammac, lax.max, lax.min,
lax.nextafter, lax.rem, lax.sub)
lax_binary_elementwise = tuple(
Harness(f"{f_lax.__name__}_{jtu.dtype_str(dtype)}",
f_lax,
[arg1, arg2],
lax_name=f_lax.__name__,
dtype=dtype
)
for f_lax in _LAX_BINARY_ELEMENTWISE
for dtype in jtu.dtypes.all_floating
for arg1, arg2 in [
(np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1., 1.4, 1.6], dtype=dtype),
np.array([-1.6, 1.4, 1.0, 0.0, 0.1, 0.2, 1., 1.4, -1.6], dtype=dtype))
]
)
_LAX_BINARY_ELEMENTWISE_LOGICAL = (
lax.bitwise_and, lax.bitwise_or, lax.bitwise_xor, lax.shift_left,
)
lax_binary_elementwise_logical = tuple(
[Harness(f"{f_lax.__name__}_{jtu.dtype_str(dtype)}",
f_lax,
[arg1, arg2],
lax_name=f_lax.__name__,
dtype=dtype)
for f_lax in _LAX_BINARY_ELEMENTWISE_LOGICAL
for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned
for arg1, arg2 in [
(np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=dtype),
np.array([1, 2, 3, 0, 1, 0, 2, 3], dtype=dtype))
]
] +
[Harness(f"{f_lax.__name__}_bool",
f_lax,
[arg1, arg2],
lax_name=f_lax.__name__,
dtype=np.bool_)
for f_lax in [lax.bitwise_and, lax.bitwise_or, lax.bitwise_xor]
for arg1, arg2 in [
(np.array([True, True, False, False]),
np.array([True, False, True, False])),
]
]
)
lax_betainc = tuple(
Harness(f"_{jtu.dtype_str(dtype)}",
lax.betainc,
[arg1, arg2, arg3],
dtype=dtype)
for dtype in jtu.dtypes.all_floating
for arg1, arg2, arg3 in [
(np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6], dtype=dtype),
np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6], dtype=dtype),
np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6],
dtype=np.float32))
]
)
_gather_input = np.arange(1000, dtype=np.float32).reshape((10, 10, 10))
lax_gather = jtu.cases_from_list(
lax_gather = tuple(
# Construct gather harnesses using take
[Harness(f"from_take_indices_shape={indices.shape}_axis={axis}",
lambda a, i, axis: jnp.take(a, i, axis=axis),
@@ -168,7 +301,7 @@ lax_gather = jtu.cases_from_list(
)
lax_pad = jtu.cases_from_list(
lax_pad = tuple(
Harness(f"_inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_pads={pads}",
lax.pad,
[RandArg(arg_shape, dtype), np.array(0, dtype), StaticArg(pads)],
@@ -226,7 +359,7 @@ lax_sort = tuple( # one array, random data, all axes, all dtypes
)
lax_slice = jtu.cases_from_list(
lax_slice = tuple(
Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore
lax.slice,
[RandArg(shape, dtype), # type: ignore
@@ -267,7 +400,7 @@ lax_dynamic_slice = [
for limit_indices in [harness.params["limit_indices"]]
]
lax_dynamic_update_slice = jtu.cases_from_list(
lax_dynamic_update_slice = tuple(
Harness((f"_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore
f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}"
f"_start_indices={start_indices}"),
@@ -292,7 +425,7 @@ lax_dynamic_update_slice = jtu.cases_from_list(
(np.float64, np.float32)
])
lax_squeeze = jtu.cases_from_list(
lax_squeeze = tuple(
Harness(f"_inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_dimensions={dimensions}", # type: ignore
lax.squeeze,
[RandArg(arg_shape, dtype), StaticArg(dimensions)], # type: ignore[has-type]
@@ -319,21 +452,21 @@ shift_inputs = [
for shift_amount in [0, 1, 2, 3, 7]
]
lax_shift_left = jtu.cases_from_list(
lax_shift_left = tuple(
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
lax.shift_left,
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])
for arg, dtype, shift_amount in shift_inputs
)
lax_shift_right_logical = jtu.cases_from_list(
lax_shift_right_logical = tuple(
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
lax.shift_right_logical,
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])
for arg, dtype, shift_amount in shift_inputs
)
lax_shift_right_arithmetic = jtu.cases_from_list(
lax_shift_right_arithmetic = tuple(
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
lax.shift_right_arithmetic,
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])


@@ -35,78 +35,21 @@ config.parse_flags_with_absl()
# Import after parsing flags
from jax.experimental.jax2tf.tests import primitive_harness
# TODO(tomhennigan) Increase coverage here.
LAX_ELEMENTWISE_UNARY = (
lax.abs,
lax.acosh,
lax.asinh,
lax.atanh,
lax.bessel_i0e,
lax.bessel_i1e,
lax.ceil,
lax.cos,
lax.cosh,
lax.digamma,
lax.erf,
lax.erf_inv,
lax.erfc,
lax.exp,
lax.expm1,
lax.floor,
lax.is_finite,
lax.lgamma,
lax.log,
lax.log1p,
lax.neg,
lax.round,
lax.rsqrt,
lax.sign,
lax.sin,
lax.sinh,
lax.sqrt,
lax.tan,
lax.tanh,
)
LAX_ELEMENTWISE_BINARY = (
lax.add,
lax.atan2,
lax.div,
lax.igamma,
lax.igammac,
lax.max,
lax.min,
lax.nextafter,
lax.rem,
lax.sub,
)
LAX_LOGICAL_ELEMENTWISE_UNARY = (
lax.bitwise_not,
)
LAX_LOGICAL_ELEMENTWISE_BINARY = (
lax.bitwise_and,
lax.bitwise_or,
lax.bitwise_xor,
lax.shift_left,
)
REDUCE = (
jnp.all,
jnp.any,
jnp.max,
jnp.min,
jnp.prod,
jnp.sum,
jnp.all,
jnp.any,
jnp.max,
jnp.min,
jnp.prod,
jnp.sum,
)
INDEX = (
jax.ops.index_add,
jax.ops.index_max,
jax.ops.index_min,
jax.ops.index_mul,
jax.ops.index_update,
jax.ops.index_add,
jax.ops.index_max,
jax.ops.index_min,
jax.ops.index_mul,
jax.ops.index_update,
)
@@ -147,14 +90,14 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
for y_dtype in types:
x = np.array([1, 2], dtype=x_dtype)
y = np.array([3, 4], dtype=y_dtype)
self.ConvertAndCompare(f_jax, x, y, with_function=True)
self.ConvertAndCompare(f_jax, x, y)
def test_concat(self):
values = [np.array([1, 2], dtype=np.float32),
np.array([1, 2], dtype=np.int32),
np.array([1, 2], dtype=np.int8)]
f_jax = jax.jit(lambda x: jnp.concatenate(x, axis=0))
self.ConvertAndCompare(f_jax, values, with_function=True)
self.ConvertAndCompare(f_jax, values)
@primitive_harness.parameterized(primitive_harness.lax_pad)
def test_pad(self, harness: primitive_harness.Harness):
@@ -164,8 +107,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
# TODO: fix pad with negative padding in XLA (fixed on 06/16/2020)
if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
raise unittest.SkipTest("pad with negative pad not supported")
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=False)
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_sort)
def test_sort(self, harness: primitive_harness.Harness):
@@ -189,190 +131,173 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
not harness.params["is_stable"]):
# TODO: fix the TF GPU test
raise unittest.SkipTest("GPU tests are running TF on CPU")
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=False)
# TODO: if we enable this test, we get the error
# iterating over `tf.Tensor` is not allowed: AutoGraph is disabled in this function.
raise unittest.SkipTest("TODO: re-enable the sort test")
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
f_jax=f_jax)
for f_jax in LAX_ELEMENTWISE_UNARY))
def test_unary_elementwise(self, f_jax=lax.abs):
x = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
dtype=np.float32)
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(x)
r_tf = f_tf(x)
self.assertAllClose(r_jax[np.isfinite(r_jax)],
r_tf[np.isfinite(r_tf)], atol=1e-4)
@primitive_harness.parameterized(primitive_harness.lax_unary_elementwise)
def test_unary_elementwise(self, harness: primitive_harness.Harness):
dtype = harness.params["dtype"]
if dtype is dtypes.bfloat16:
raise unittest.SkipTest("bfloat16 not implemented")
arg, = harness.dyn_args_maker(self.rng())
custom_assert = None
if harness.params["lax_name"] == "digamma":
# digamma is not defined at 0 and -1
def custom_assert(result_jax, result_tf):
# lax.digamma returns NaN and tf.math.digamma returns inf
special_cases = (arg == 0.) | (arg == -1.)
nr_special_cases = np.count_nonzero(special_cases)
self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
result_jax[special_cases])
self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
result_tf[special_cases])
# non-special cases are equal
self.assertAllClose(result_jax[~ special_cases],
result_tf[~ special_cases])
if harness.params["lax_name"] == "erf_inv":
# TODO(necula): fix bug with erf_inv/f16
if dtype is np.float16:
raise unittest.SkipTest("TODO: fix bug")
# erf_inv is not defined for arg <= -1 or arg >= 1
def custom_assert(result_jax, result_tf): # noqa: F811
# for arg < -1 or arg > 1
# lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
special_cases = (arg < -1.) | (arg > 1.)
nr_special_cases = np.count_nonzero(special_cases)
self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
result_jax[special_cases])
signs = np.where(arg[special_cases] < 0., -1., 1.)
self.assertAllClose(np.full((nr_special_cases,), signs * dtype(np.inf)),
result_tf[special_cases])
# non-special cases are equal
self.assertAllClose(result_jax[~ special_cases],
result_tf[~ special_cases])
self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert)
def test_bitwise_not(self):
x = np.array([-1, 3, -2, 0, 0, 2, 1, 3], dtype=np.int32)
f_jax = jax.jit(lax.bitwise_not)
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(x)
r_tf = f_tf(x)
self.assertAllClose(r_jax[np.isfinite(r_jax)],
r_tf[np.isfinite(r_tf)], atol=1e-4)
@primitive_harness.parameterized(primitive_harness.lax_bitwise_not)
def test_bitwise_not(self, harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
f_jax=f_jax)
for f_jax in LAX_ELEMENTWISE_BINARY))
def test_binary_elementwise(self, f_jax=lax.add):
a = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
dtype=np.float32)
b = np.array([-1.6, 1.4, 1.0, 0.0, 0.1, 0.2, 1, 1.4, -1.6],
dtype=np.float32)
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(a, b)
r_tf = f_tf(a, b)
# JAX outputs 0 and 1 instead of NaN for values outside the domain,
# whereas TensorFlow does so for other input combinations.
if f_jax in (lax.igamma, lax.igammac):
# Make returned array writeable.
r_jax = np.copy(r_jax)
r_jax[r_jax == 0] = np.nan
r_jax[r_jax == 1] = np.nan
r_tf = np.copy(r_tf)
r_tf[r_tf == 0] = np.nan
r_tf[r_tf == 1] = np.nan
self.assertAllClose(r_jax[np.isfinite(r_jax)],
r_tf[np.isfinite(r_tf)], atol=1e-4)
@primitive_harness.parameterized(primitive_harness.lax_binary_elementwise)
def test_binary_elementwise(self, harness):
if harness.params["dtype"] is dtypes.bfloat16:
raise unittest.SkipTest("bfloat16 not implemented")
# TODO(necula): fix bug with igamma/f16
if (harness.params["lax_name"] in ("igamma", "igammac") and
harness.params["dtype"] is np.float16):
raise unittest.SkipTest("TODO: fix bug")
# TODO(necula): fix bug with nextafter/f16
if (harness.params["lax_name"] == "nextafter" and
harness.params["dtype"] is np.float16):
raise unittest.SkipTest("TODO: understand unimplemented case")
arg1, arg2 = harness.dyn_args_maker(self.rng())
custom_assert = None
if harness.params["lax_name"] == "igamma":
# igamma is not defined when the first argument is <=0
def custom_assert(result_jax, result_tf):
# lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0
special_cases = (arg1 == 0.) & (arg2 == 0.)
nr_special_cases = np.count_nonzero(special_cases)
self.assertAllClose(np.full((nr_special_cases,), np.nan),
result_jax[special_cases])
self.assertAllClose(np.full((nr_special_cases,), 0.),
result_tf[special_cases])
# non-special cases are equal
self.assertAllClose(result_jax[~ special_cases],
result_tf[~ special_cases])
if harness.params["lax_name"] == "igammac":
# igammac is not defined when the first argument is <=0
def custom_assert(result_jax, result_tf): # noqa: F811
# lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
special_cases = (arg1 <= 0.) | (arg2 <= 0)
nr_special_cases = np.count_nonzero(special_cases)
self.assertAllClose(np.full((nr_special_cases,), 1.),
result_jax[special_cases])
self.assertAllClose(np.full((nr_special_cases,), np.nan),
result_tf[special_cases])
# non-special cases are equal
self.assertAllClose(result_jax[~ special_cases],
result_tf[~ special_cases])
self.ConvertAndCompare(harness.dyn_fun, arg1, arg2,
custom_assert=custom_assert)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
f_jax=f_jax)
for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
def test_binary_logical_elementwise(self, f_jax):
a = np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=np.uint32)
b = np.array([1, 2, 3, 0, 1, 0, 2, 3], dtype=np.uint32)
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(a, b)
r_tf = f_tf(a, b)
self.assertAllClose(r_jax[np.isfinite(r_jax)],
r_tf[np.isfinite(r_tf)], atol=1e-4)
# Checks support for bools.
if f_jax in (lax.bitwise_and, lax.bitwise_or, lax.bitwise_xor):
a = np.array([True, True, False, False])
b = np.array([True, False, True, False])
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(a, b)
r_tf = f_tf(a, b)
self.assertArraysEqual(r_jax, np.asarray(r_tf))
@primitive_harness.parameterized(primitive_harness.lax_binary_elementwise_logical)
def test_binary_elementwise_logical(self, harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
f_jax=f_jax)
for f_jax in LAX_LOGICAL_ELEMENTWISE_UNARY))
def test_unary_logical_elementwise(self, f_jax):
a = np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=np.uint32)
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(a)
r_tf = f_tf(a)
self.assertAllClose(r_jax[np.isfinite(r_jax)],
r_tf[np.isfinite(r_tf)], atol=1e-4)
# Checks support for bools.
a = np.array([True, False])
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(a)
r_tf = f_tf(a)
self.assertArraysEqual(r_jax, np.asarray(r_tf))
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
f_jax=f_jax)
for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
def test_binary_logical_elementwise_bool(self, f_jax):
if f_jax == lax.shift_left:
self.skipTest("Shift of bool not supported")
a = np.array([0, 0, 1, 1, 0, 0, 1, 1], dtype=np.bool_)
b = np.array([0, 1, 0, 1, 0, 1, 0, 1], dtype=np.bool_)
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(a, b)
r_tf = f_tf(a, b)
self.assertAllClose(r_jax, r_tf)
@primitive_harness.parameterized(primitive_harness.lax_betainc)
def test_betainc(self, harness: primitive_harness.Harness):
if harness.params["dtype"] is dtypes.bfloat16:
raise unittest.SkipTest("bfloat16 not implemented")
# TODO(necula): fix bug with betainc/f16
if harness.params["dtype"] is np.float16:
raise unittest.SkipTest("TODO: understand betainc/f16 bug")
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
# TODO(necula): combine tests that are identical except for the harness;
# wait until we get more experience with using harnesses.
@primitive_harness.parameterized(primitive_harness.lax_shift_left)
def test_shift_left(self, harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_shift_right_logical)
def test_shift_right_logical(self, harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_shift_right_arithmetic)
def test_shift_right_arithmetic(self, harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_slice)
def test_slice(self, harness):
# JAX.slice rejects negative indices; check, and skip jax2tf
if any(si < 0 or si >= sh or li < 0 or li > sh
for sh, si, li in zip(harness.params["shape"],
harness.params["start_indices"],
harness.params["limit_indices"])):
harness.params["start_indices"],
harness.params["limit_indices"])):
with self.assertRaisesRegex(TypeError, ""):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
else:
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_dynamic_slice)
def test_dynamic_slice(self, harness):
# JAX.dynamic_slice rejects slice sizes too big; check, and skip jax2tf
# JAX.dynamic_slice rejects slice sizes too big; check this, and skip jax2tf
if any(li - si < 0 or li - si >= sh
for sh, si, li in zip(harness.params["shape"],
harness.params["start_indices"],
harness.params["limit_indices"])):
harness.params["start_indices"],
harness.params["limit_indices"])):
with self.assertRaisesRegex(TypeError, ""):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
else:
# TF compiler gives an error for tf.slice(start_indices < 0)
if any(si < 0 for si in harness.params["start_indices"]):
raise unittest.SkipTest("TF gives error for negative start_indices")
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)
check_compiled=False)
@primitive_harness.parameterized(primitive_harness.lax_dynamic_update_slice)
def test_dynamic_update_slice(self, harness):
# JAX.dynamic_update_slice rejects update slices too big; check, and skip jax2tf
if any(ush > sh
for sh, ush in zip(harness.params["shape"],
harness.params["update_shape"])):
harness.params["update_shape"])):
with self.assertRaisesRegex(TypeError, ""):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
else:
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
f_jax=f_jax)
for f_jax in (lax.betainc,)))
def test_trinary_elementwise(self, f_jax):
a = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6],
dtype=np.float32)
b = np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6],
dtype=np.float32)
c = np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6],
dtype=np.float32)
f_tf = tf.function(jax2tf.convert(f_jax))
r_jax = f_jax(a, b, c)
r_tf = f_tf(a, b, c)
self.assertAllClose(r_jax[np.isfinite(r_jax)],
r_tf[np.isfinite(r_tf)], atol=1e-4)
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_squeeze)
def test_squeeze(self, harness: primitive_harness.Harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_gather)
def test_gather(self, harness: primitive_harness.Harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=False)
check_compiled=False)
def test_boolean_gather(self):
values = np.array([[True, True], [False, True], [False, False]],
@@ -380,13 +305,14 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
indices = np.array([0, 1], dtype=np.int32)
for axis in [0, 1]:
f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis)) # pylint: disable=cell-var-from-loop
self.ConvertAndCompare(f_jax, values, indices, with_function=True)
# TODO: why can't we compile this code?
self.ConvertAndCompare(f_jax, values, indices, check_compiled=False)
def test_gather_rank_change(self):
params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
indices = jnp.array([[1, 1, 2], [0, 1, 0]])
f_jax = jax.jit(lambda i: params[i])
self.ConvertAndCompare(f_jax, indices, with_function=True)
self.ConvertAndCompare(f_jax, indices, check_compiled=False)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
@@ -394,7 +320,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
for f_jax in REDUCE))
def test_reduce_ops_with_numerical_input(self, f_jax):
values = np.array([1, 2, 3], dtype=np.float32)
self.ConvertAndCompare(f_jax, values, with_function=True)
self.ConvertAndCompare(f_jax, values)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
@@ -402,7 +328,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
for f_jax in (jnp.cumsum, jnp.cumprod)))
def test_cumulated_ops(self, f_jax):
values = np.array([1, 2, 3], dtype=np.float32)
self.ConvertAndCompare(f_jax, values, with_function=True)
self.ConvertAndCompare(f_jax, values)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{op.__name__}",
@@ -412,7 +338,9 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
values = np.ones((5, 6), dtype=np.float32)
update = np.float32(6.)
f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
self.ConvertAndCompare(f_jax, values, update, with_function=True)
# TODO: compilation fails
self.ConvertAndCompare(f_jax, values, update,
check_compiled=False)
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
@@ -420,8 +348,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
for f_jax in REDUCE))
def test_reduce_ops_with_boolean_input(self, f_jax):
values = np.array([True, False, True], dtype=np.bool_)
self.ConvertAndCompare(f_jax, values, with_function=True)
self.ConvertAndCompare(f_jax, values)
def test_prngsplit(self):
f_jax = jax.jit(lambda key: jax.random.split(key, 2))
@@ -431,8 +358,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
np.array([0, 0xFFFFFFFF], dtype=np.uint32),
np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)
]:
self.ConvertAndCompare(f_jax, rng_key, with_function=True)
self.ConvertAndCompare(f_jax, rng_key)
def test_zeros_like(self):
v = np.float32(2.)
@@ -443,5 +369,6 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
f = jax2tf.convert(lax.stop_gradient)
self.assertEqual(f(tf.ones([])), 1.)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())


@@ -60,24 +60,54 @@ class JaxToTfTestCase(jtu.JaxTestCase):
to_numpy_dtype(jtu._dtype(y)))
def ConvertAndCompare(self, func_jax: Callable, *args,
with_function: bool = False,
custom_assert: Optional[Callable] = None,
check_compiled: bool = True,
atol=None,
rtol=None) -> Tuple[Any, Any]:
"""Compares jax_func(*args) with convert(jax_func)(*args)."""
"""Compares jax_func(*args) with convert(jax_func)(*args).
It compares the results of JAX, TF, TF with tf.function, and TF with
tf.function(experimental_compile=True).
Args:
check_compiled: check that JAX and tf.function(experimental_compile=True)
produce the same results.
custom_assert: a function that will be called
`custom_assert(result_jax, result_tf)` to assert equality of the
results. Use this function when JAX and TF produce different results.
This function is not used for the experimental_compile case, because
in that case we always expect the results to be equal.
"""
result_jax = func_jax(*args)
func_tf = jax2tf.convert(func_jax)
if with_function:
func_tf = tf.function(func_tf, autograph=False)
res_jax = func_jax(*args)
#logging.info(f"res_jax is {res_jax} on {res_jax.device_buffer.device()}")
res_tf = func_tf(*args)
#logging.info(f"res_tf is {res_tf} on {res_tf.backing_device}")
self.assertAllClose(res_jax, res_tf, atol=atol, rtol=rtol)
return (res_jax, res_tf)
result_tf = func_tf(*args)
# Sometimes JAX and TF(compile=False) give different results
if custom_assert is not None:
custom_assert(result_jax, result_tf)
else:
self.assertAllClose(result_jax, result_tf, atol=atol, rtol=rtol)
# Using tf.function should not make a difference. Is this even something
# we should test here?
result_tf_function = tf.function(func_tf, autograph=False)(*args)
self.assertAllClose(result_tf, result_tf_function, atol=atol, rtol=rtol)
# The results of JAX and TF with compile=True are always the same
# TODO: enable compilation
if check_compiled:
result_tf_function_compile = tf.function(func_tf, autograph=False,
experimental_compile=True)(*args)
self.assertAllClose(result_jax, result_tf_function_compile,
atol=atol, rtol=rtol)
return (result_jax, result_tf)
def TransformConvertAndCompare(self, func: Callable,
arg,
transform: Optional[str],
with_function: bool = False):
check_compiled=True):
"""Like ConvertAndCompare but first applies a transformation.
`func` must be a function from one argument to one result. `arg` is
@@ -86,27 +116,27 @@ class JaxToTfTestCase(jtu.JaxTestCase):
`transform` can be None, "jvp", "grad", "vmap", "jvp_vmap", "grad_vmap"
"""
if transform is None:
return self.ConvertAndCompare(func, arg, with_function=with_function)
return self.ConvertAndCompare(func, arg, check_compiled=check_compiled)
if transform == "jvp":
t_func = lambda x, xt: jax.jvp(func, (x,), (xt,))
return self.ConvertAndCompare(t_func, arg, np.full_like(arg, 0.1),
with_function=with_function)
check_compiled=check_compiled)
if transform == "grad":
return self.ConvertAndCompare(jax.grad(func), arg,
with_function=with_function)
check_compiled=check_compiled)
if transform == "vmap":
t_arg = np.stack([arg] * 4)
return self.ConvertAndCompare(jax.vmap(func),
t_arg, with_function=with_function)
return self.ConvertAndCompare(jax.vmap(func), t_arg,
check_compiled=check_compiled)
if transform == "jvp_vmap":
jvp_func = lambda x, xt: jax.jvp(jax.vmap(func), (x,), (xt,))
t_arg = np.stack([arg] * 4)
return self.ConvertAndCompare(jvp_func, t_arg,
np.full_like(t_arg, 0.1),
with_function=with_function)
check_compiled=check_compiled)
if transform == "grad_vmap":
grad_func = jax.grad(lambda x: jnp.sum(jax.vmap(func)(x)))
t_arg = np.stack([arg] * 4)
return self.ConvertAndCompare(grad_func, t_arg,
with_function=with_function)
check_compiled=check_compiled)
assert False, transform