mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Refactor tests to increase coverage (#3700)
* [jax2tf] Refactor tests to increase coverage. This change has several goals: (a) increase the test coverage for the elementwise primitives, (b) expose explicitly the situations where JAX and TF produce different results, e.g., inf vs Nan, and (c) run all comparisons with and without tf.function, with and without experimental_compile=True. Previously the test code was just masking off non-finite values. This change uncovered quite a few unimplemented cases, e.g., with float16, bfloat16, conversions that cannot be compiled. This are left as TODO for now. * Disable the sort tests
This commit is contained in:
parent
68c8dc781e
commit
c380356ff0
@ -14,7 +14,6 @@
|
||||
"""Tests for the jax2tf conversion for control-flow primitives."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
@ -30,27 +29,19 @@ config.parse_flags_with_absl()
|
||||
|
||||
class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_function={with_function}",
|
||||
with_function=with_function)
|
||||
for with_function in [False, True]))
|
||||
def test_cond(self, with_function=False):
|
||||
def test_cond(self):
|
||||
def f_jax(pred, x):
|
||||
return lax.cond(pred, lambda t: t + 1., lambda f: f, x)
|
||||
|
||||
self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)
|
||||
self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)
|
||||
self.ConvertAndCompare(f_jax, True, 1.)
|
||||
self.ConvertAndCompare(f_jax, False, 1.)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_function={with_function}",
|
||||
with_function=with_function)
|
||||
for with_function in [False, True]))
|
||||
def test_cond_multiple_results(self, with_function=False):
|
||||
def test_cond_multiple_results(self):
|
||||
def f_jax(pred, x):
|
||||
return lax.cond(pred, lambda t: (t + 1., 1.), lambda f: (f + 2., 2.), x)
|
||||
|
||||
self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)
|
||||
self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)
|
||||
self.ConvertAndCompare(f_jax, True, 1.)
|
||||
self.ConvertAndCompare(f_jax, False, 1.)
|
||||
|
||||
def test_cond_partial_eval(self):
|
||||
def f(x):
|
||||
@ -58,16 +49,13 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
return res
|
||||
self.ConvertAndCompare(jax.grad(f), 1.)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_function={with_function}",
|
||||
with_function=with_function)
|
||||
for with_function in [False, True]))
|
||||
|
||||
def test_cond_units(self, with_function=True):
|
||||
def g(x):
|
||||
return lax.cond(True, lambda x: x, lambda y: y, x)
|
||||
|
||||
self.ConvertAndCompare(g, 0.7, with_function=with_function)
|
||||
self.ConvertAndCompare(jax.grad(g), 0.7, with_function=with_function)
|
||||
self.ConvertAndCompare(g, 0.7)
|
||||
self.ConvertAndCompare(jax.grad(g), 0.7)
|
||||
|
||||
|
||||
def test_cond_custom_jvp(self):
|
||||
@ -121,24 +109,17 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
self.TransformConvertAndCompare(g, arg, "vmap")
|
||||
self.TransformConvertAndCompare(g, arg, "grad_vmap")
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_function={with_function}",
|
||||
with_function=with_function)
|
||||
for with_function in [False, True]))
|
||||
def test_while_single_carry(self, with_function=False):
|
||||
|
||||
def test_while_single_carry(self):
|
||||
"""A while with a single carry"""
|
||||
def func(x):
|
||||
# Equivalent to:
|
||||
# for(i=x; i < 4; i++);
|
||||
return lax.while_loop(lambda c: c < 4, lambda c: c + 1, x)
|
||||
|
||||
self.ConvertAndCompare(func, 0, with_function=with_function)
|
||||
self.ConvertAndCompare(func, 0)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_function={with_function}",
|
||||
with_function=with_function)
|
||||
for with_function in [False, True]))
|
||||
def test_while(self, with_function=False):
|
||||
def test_while(self):
|
||||
# Some constants to capture in the conditional branches
|
||||
cond_const = np.ones(3, dtype=np.float32)
|
||||
body_const1 = np.full_like(cond_const, 1.)
|
||||
@ -163,13 +144,9 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
return lax.while_loop(cond, body, (0, x))
|
||||
|
||||
self.ConvertAndCompare(func, cond_const, with_function=with_function)
|
||||
self.ConvertAndCompare(func, cond_const)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_function={with_function}",
|
||||
with_function=with_function)
|
||||
for with_function in [False, True]))
|
||||
def test_while_batched_cond(self, with_function=True):
|
||||
def test_while_batched_cond(self):
|
||||
"""A while with a single carry"""
|
||||
def product(x, y):
|
||||
# Equivalent to "x * y" implemented as:
|
||||
@ -190,7 +167,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
def product_xs_ys(xs, ys):
|
||||
return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)
|
||||
|
||||
self.ConvertAndCompare(product_xs_ys, xs, ys, with_function=with_function)
|
||||
self.ConvertAndCompare(product_xs_ys, xs, ys)
|
||||
|
||||
def test_while_custom_jvp(self):
|
||||
"""Conversion of function with custom JVP, inside while.
|
||||
@ -218,11 +195,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
self.TransformConvertAndCompare(g, arg, "jvp_vmap")
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_function={with_function}",
|
||||
with_function=with_function)
|
||||
for with_function in [False, True]))
|
||||
def test_scan(self, with_function=False):
|
||||
def test_scan(self):
|
||||
def f_jax(xs, ys):
|
||||
body_const = np.ones((2, ), dtype=np.float32) # Test constant capture
|
||||
def body(res0, inputs):
|
||||
@ -231,9 +204,9 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
return lax.scan(body, 0., (xs, ys))
|
||||
|
||||
arg = np.arange(10, dtype=np.float32)
|
||||
self.ConvertAndCompare(f_jax, arg, arg, with_function=with_function)
|
||||
self.ConvertAndCompare(f_jax, arg, arg, check_compiled=False)
|
||||
|
||||
def test_scan_partial_eval(self, with_function=False):
|
||||
def test_scan_partial_eval(self):
|
||||
def f_jax(xs, ys):
|
||||
body_const = np.ones((2, ), dtype=np.float32) # Test constant capture
|
||||
def body(res0, inputs):
|
||||
@ -244,7 +217,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
arg = np.arange(10, dtype=np.float32)
|
||||
print(jax.make_jaxpr(jax.grad(f_jax))(arg, arg))
|
||||
self.ConvertAndCompare(jax.grad(f_jax), arg, arg, with_function=with_function)
|
||||
self.ConvertAndCompare(jax.grad(f_jax), arg, arg, check_compiled=False)
|
||||
|
||||
|
||||
def test_scan_custom_jvp(self):
|
||||
@ -268,12 +241,12 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
x)[0]
|
||||
|
||||
arg = np.full((5,), 0.7)
|
||||
self.TransformConvertAndCompare(g, arg, None)
|
||||
self.TransformConvertAndCompare(g, arg, "jvp")
|
||||
self.TransformConvertAndCompare(g, arg, "vmap")
|
||||
self.TransformConvertAndCompare(g, arg, "jvp_vmap")
|
||||
self.TransformConvertAndCompare(g, arg, "grad")
|
||||
self.TransformConvertAndCompare(g, arg, "grad_vmap")
|
||||
self.TransformConvertAndCompare(g, arg, None, check_compiled=False)
|
||||
self.TransformConvertAndCompare(g, arg, "jvp", check_compiled=False)
|
||||
self.TransformConvertAndCompare(g, arg, "vmap", check_compiled=False)
|
||||
self.TransformConvertAndCompare(g, arg, "jvp_vmap", check_compiled=False)
|
||||
self.TransformConvertAndCompare(g, arg, "grad", check_compiled=False)
|
||||
self.TransformConvertAndCompare(g, arg, "grad_vmap", check_compiled=False)
|
||||
|
||||
def test_scan_custom_vjp(self):
|
||||
"""Conversion of function with custom VJP, inside scan.
|
||||
@ -297,10 +270,10 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
x)[0]
|
||||
|
||||
arg = np.full((5,), 0.7)
|
||||
self.TransformConvertAndCompare(g, arg, None)
|
||||
self.TransformConvertAndCompare(g, arg, "vmap")
|
||||
self.TransformConvertAndCompare(g, arg, "grad")
|
||||
self.TransformConvertAndCompare(g, arg, "grad_vmap")
|
||||
self.TransformConvertAndCompare(g, arg, None, check_compiled=False)
|
||||
self.TransformConvertAndCompare(g, arg, "vmap", check_compiled=False)
|
||||
self.TransformConvertAndCompare(g, arg, "grad", check_compiled=False)
|
||||
self.TransformConvertAndCompare(g, arg, "grad_vmap", check_compiled=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -95,10 +95,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual(f_tf(1., 2.).dtype, tf.bfloat16)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_dtype={dtype.__name__}_function={with_function}",
|
||||
with_function=with_function,
|
||||
dict(testcase_name=f"_dtype={dtype.__name__}",
|
||||
dtype=dtype)
|
||||
for with_function in [False, True]
|
||||
for dtype in [np.int64, np.float64]))
|
||||
def test_converts_64bit(self, dtype=np.int64, with_function=False):
|
||||
big_const = np.full((5,), 2 ** 33, dtype=dtype)
|
||||
@ -115,7 +113,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
def test_function(self):
|
||||
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
|
||||
self.ConvertAndCompare(f_jax, 0.7, with_function=True)
|
||||
self.ConvertAndCompare(f_jax, 0.7)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"function={with_function}",
|
||||
@ -286,12 +284,12 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
return primal_out, tangent_out
|
||||
|
||||
arg = 0.7
|
||||
self.TransformConvertAndCompare(f, arg, None, with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, "jvp", with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, "vmap", with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, "jvp_vmap", with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, "grad", with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, "grad_vmap", with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, None)
|
||||
self.TransformConvertAndCompare(f, arg, "jvp")
|
||||
self.TransformConvertAndCompare(f, arg, "vmap")
|
||||
self.TransformConvertAndCompare(f, arg, "jvp_vmap")
|
||||
self.TransformConvertAndCompare(f, arg, "grad")
|
||||
self.TransformConvertAndCompare(f, arg, "grad_vmap")
|
||||
|
||||
def test_custom_vjp(self):
|
||||
"""Conversion of function with custom VJP"""
|
||||
@ -308,10 +306,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
f.defvjp(f_fwd, f_bwd)
|
||||
arg = 0.7
|
||||
self.TransformConvertAndCompare(f, arg, None, with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, "vmap", with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, "grad", with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, "grad_vmap", with_function=True)
|
||||
self.TransformConvertAndCompare(f, arg, None)
|
||||
self.TransformConvertAndCompare(f, arg, "vmap")
|
||||
self.TransformConvertAndCompare(f, arg, "grad")
|
||||
self.TransformConvertAndCompare(f, arg, "grad_vmap")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -33,19 +33,44 @@ FLAGS = config.FLAGS
|
||||
Rng = Any # A random number generator
|
||||
|
||||
class RandArg(NamedTuple):
|
||||
"""Descriptor for a randomly generated argument."""
|
||||
"""Descriptor for a randomly generated argument.
|
||||
|
||||
See description of `Harness`.
|
||||
"""
|
||||
shape: Tuple[int, ...]
|
||||
dtype: np.dtype
|
||||
|
||||
class StaticArg(NamedTuple):
|
||||
"""Descriptor for a static argument."""
|
||||
"""Descriptor for a static argument.
|
||||
|
||||
See description of `Harness`.
|
||||
"""
|
||||
value: Any
|
||||
|
||||
class Harness:
|
||||
"""Specifies inputs and callable for a primitive.
|
||||
|
||||
A primitive can take dynamic and static arguments. The dynamic arguments can
|
||||
be generated using a RNG, are numeric (and appropriate for JIT).
|
||||
A harness is conceptually a callable and a list of arguments, that together
|
||||
exercise a use case. The harness can optionally have additional parameters
|
||||
that can be used by the test.
|
||||
|
||||
The arguments are specified through argument descriptors. An argument
|
||||
descriptor can be:
|
||||
* a numeric value or ndarray, or
|
||||
* an instance of ``RandArg(shape, dtype)`` to be used with a PRNG to generate
|
||||
random tensor of the given shape and type, or
|
||||
* an instance of ``StaticArg(value)``. These are values that specialize the
|
||||
callable, but are not exposed as external arguments.
|
||||
|
||||
For example, a harness for ``lax.take(arr, indices, axis=None)`` may want
|
||||
to expose as external (dynamic) argument the array and the indices, and
|
||||
keep the axis as a static argument (technically specializing the `take` to
|
||||
a axis):
|
||||
|
||||
Harness(f"take_axis={axis}",
|
||||
lax.take,
|
||||
[RandArg((2, 4), np.float32), np.array([-1, 0, 1]), StaticArg(axis)],
|
||||
axis=axis)
|
||||
"""
|
||||
# Descriptive name of the harness, used as a testcase_name. Unique in a group.
|
||||
name: str
|
||||
@ -124,9 +149,117 @@ def parameterized(harness_group: Iterable[Harness],
|
||||
cases = cases[0:1]
|
||||
return testing.parameterized.named_parameters(*cases)
|
||||
|
||||
### Harness definitions ###
|
||||
###
|
||||
_LAX_UNARY_ELEMENTWISE = (
|
||||
lax.abs, lax.acosh, lax.asinh, lax.atanh, lax.bessel_i0e, lax.bessel_i1e,
|
||||
lax.ceil, lax.cos, lax.cosh, lax.digamma, lax.erf, lax.erf_inv, lax.erfc,
|
||||
lax.exp, lax.expm1, lax.floor, lax.is_finite, lax.lgamma, lax.log,
|
||||
lax.log1p, lax.neg, lax.round, lax.rsqrt, lax.sign, lax.sin, lax.sinh,
|
||||
lax.sqrt, lax.tan, lax.tanh)
|
||||
|
||||
lax_unary_elementwise = tuple(
|
||||
Harness(f"{f_lax.__name__}_{jtu.dtype_str(dtype)}",
|
||||
f_lax,
|
||||
[arg],
|
||||
lax_name=f_lax.__name__,
|
||||
dtype=dtype)
|
||||
for f_lax in _LAX_UNARY_ELEMENTWISE
|
||||
for dtype in jtu.dtypes.all_floating
|
||||
for arg in [
|
||||
np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1., 1.4, 1.6], dtype=dtype)
|
||||
]
|
||||
)
|
||||
|
||||
lax_bitwise_not = tuple(
|
||||
[Harness(f"{jtu.dtype_str(dtype)}",
|
||||
lax.bitwise_not,
|
||||
[arg],
|
||||
dtype=dtype)
|
||||
for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned
|
||||
for arg in [
|
||||
np.array([-1, -3, -2, 0, 0, 2, 1, 3], dtype=dtype),
|
||||
]] +
|
||||
|
||||
[Harness("bool",
|
||||
f_lax,
|
||||
[arg],
|
||||
lax_name=f_lax.__name__,
|
||||
dtype=np.bool_)
|
||||
for f_lax in [lax.bitwise_not]
|
||||
for arg in [
|
||||
np.array([True, False])
|
||||
]]
|
||||
)
|
||||
|
||||
_LAX_BINARY_ELEMENTWISE = (
|
||||
lax.add, lax.atan2, lax.div, lax.igamma, lax.igammac, lax.max, lax.min,
|
||||
lax.nextafter, lax.rem, lax.sub)
|
||||
|
||||
lax_binary_elementwise = tuple(
|
||||
Harness(f"{f_lax.__name__}_{jtu.dtype_str(dtype)}",
|
||||
f_lax,
|
||||
[arg1, arg2],
|
||||
lax_name=f_lax.__name__,
|
||||
dtype=dtype
|
||||
)
|
||||
for f_lax in _LAX_BINARY_ELEMENTWISE
|
||||
for dtype in jtu.dtypes.all_floating
|
||||
for arg1, arg2 in [
|
||||
(np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1., 1.4, 1.6], dtype=dtype),
|
||||
np.array([-1.6, 1.4, 1.0, 0.0, 0.1, 0.2, 1., 1.4, -1.6], dtype=dtype))
|
||||
]
|
||||
)
|
||||
|
||||
_LAX_BINARY_ELEMENTWISE_LOGICAL = (
|
||||
lax.bitwise_and, lax.bitwise_or, lax.bitwise_xor, lax.shift_left,
|
||||
)
|
||||
|
||||
lax_binary_elementwise_logical = tuple(
|
||||
[Harness(f"{f_lax.__name__}_{jtu.dtype_str(dtype)}",
|
||||
f_lax,
|
||||
[arg1, arg2],
|
||||
lax_name=f_lax.__name__,
|
||||
dtype=dtype)
|
||||
for f_lax in _LAX_BINARY_ELEMENTWISE_LOGICAL
|
||||
for dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned
|
||||
for arg1, arg2 in [
|
||||
(np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=dtype),
|
||||
np.array([1, 2, 3, 0, 1, 0, 2, 3], dtype=dtype))
|
||||
]
|
||||
] +
|
||||
|
||||
[Harness(f"{f_lax.__name__}_bool",
|
||||
f_lax,
|
||||
[arg1, arg2],
|
||||
lax_name=f_lax.__name__,
|
||||
dtype=np.bool_)
|
||||
for f_lax in [lax.bitwise_and, lax.bitwise_or, lax.bitwise_xor]
|
||||
for arg1, arg2 in [
|
||||
(np.array([True, True, False, False]),
|
||||
np.array([True, False, True, False])),
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
lax_betainc = tuple(
|
||||
Harness(f"_{jtu.dtype_str(dtype)}",
|
||||
lax.betainc,
|
||||
[arg1, arg2, arg3],
|
||||
dtype=dtype)
|
||||
for dtype in jtu.dtypes.all_floating
|
||||
for arg1, arg2, arg3 in [
|
||||
(np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6], dtype=dtype),
|
||||
np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6], dtype=dtype),
|
||||
np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6],
|
||||
dtype=np.float32))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
_gather_input = np.arange(1000, dtype=np.float32).reshape((10, 10, 10))
|
||||
lax_gather = jtu.cases_from_list(
|
||||
lax_gather = tuple(
|
||||
# Construct gather harnesses using take
|
||||
[Harness(f"from_take_indices_shape={indices.shape}_axis={axis}",
|
||||
lambda a, i, axis: jnp.take(a, i, axis=axis),
|
||||
@ -168,7 +301,7 @@ lax_gather = jtu.cases_from_list(
|
||||
)
|
||||
|
||||
|
||||
lax_pad = jtu.cases_from_list(
|
||||
lax_pad = tuple(
|
||||
Harness(f"_inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_pads={pads}",
|
||||
lax.pad,
|
||||
[RandArg(arg_shape, dtype), np.array(0, dtype), StaticArg(pads)],
|
||||
@ -226,7 +359,7 @@ lax_sort = tuple( # one array, random data, all axes, all dtypes
|
||||
)
|
||||
|
||||
|
||||
lax_slice = jtu.cases_from_list(
|
||||
lax_slice = tuple(
|
||||
Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore
|
||||
lax.slice,
|
||||
[RandArg(shape, dtype), # type: ignore
|
||||
@ -267,7 +400,7 @@ lax_dynamic_slice = [
|
||||
for limit_indices in [harness.params["limit_indices"]]
|
||||
]
|
||||
|
||||
lax_dynamic_update_slice = jtu.cases_from_list(
|
||||
lax_dynamic_update_slice = tuple(
|
||||
Harness((f"_operand={jtu.format_shape_dtype_string(shape, dtype)}" # type: ignore
|
||||
f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}"
|
||||
f"_start_indices={start_indices}"),
|
||||
@ -292,7 +425,7 @@ lax_dynamic_update_slice = jtu.cases_from_list(
|
||||
(np.float64, np.float32)
|
||||
])
|
||||
|
||||
lax_squeeze = jtu.cases_from_list(
|
||||
lax_squeeze = tuple(
|
||||
Harness(f"_inshape={jtu.format_shape_dtype_string(arg_shape, dtype)}_dimensions={dimensions}", # type: ignore
|
||||
lax.squeeze,
|
||||
[RandArg(arg_shape, dtype), StaticArg(dimensions)], # type: ignore[has-type]
|
||||
@ -319,21 +452,21 @@ shift_inputs = [
|
||||
for shift_amount in [0, 1, 2, 3, 7]
|
||||
]
|
||||
|
||||
lax_shift_left = jtu.cases_from_list(
|
||||
lax_shift_left = tuple(
|
||||
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
|
||||
lax.shift_left,
|
||||
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])
|
||||
for arg, dtype, shift_amount in shift_inputs
|
||||
)
|
||||
|
||||
lax_shift_right_logical = jtu.cases_from_list(
|
||||
lax_shift_right_logical = tuple(
|
||||
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
|
||||
lax.shift_right_logical,
|
||||
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])
|
||||
for arg, dtype, shift_amount in shift_inputs
|
||||
)
|
||||
|
||||
lax_shift_right_arithmetic = jtu.cases_from_list(
|
||||
lax_shift_right_arithmetic = tuple(
|
||||
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
|
||||
lax.shift_right_arithmetic,
|
||||
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])
|
||||
|
@ -35,78 +35,21 @@ config.parse_flags_with_absl()
|
||||
# Import after parsing flags
|
||||
from jax.experimental.jax2tf.tests import primitive_harness
|
||||
|
||||
# TODO(tomhennigan) Increase coverage here.
|
||||
LAX_ELEMENTWISE_UNARY = (
|
||||
lax.abs,
|
||||
lax.acosh,
|
||||
lax.asinh,
|
||||
lax.atanh,
|
||||
lax.bessel_i0e,
|
||||
lax.bessel_i1e,
|
||||
lax.ceil,
|
||||
lax.cos,
|
||||
lax.cosh,
|
||||
lax.digamma,
|
||||
lax.erf,
|
||||
lax.erf_inv,
|
||||
lax.erfc,
|
||||
lax.exp,
|
||||
lax.expm1,
|
||||
lax.floor,
|
||||
lax.is_finite,
|
||||
lax.lgamma,
|
||||
lax.log,
|
||||
lax.log1p,
|
||||
lax.neg,
|
||||
lax.round,
|
||||
lax.rsqrt,
|
||||
lax.sign,
|
||||
lax.sin,
|
||||
lax.sinh,
|
||||
lax.sqrt,
|
||||
lax.tan,
|
||||
lax.tanh,
|
||||
)
|
||||
|
||||
LAX_ELEMENTWISE_BINARY = (
|
||||
lax.add,
|
||||
lax.atan2,
|
||||
lax.div,
|
||||
lax.igamma,
|
||||
lax.igammac,
|
||||
lax.max,
|
||||
lax.min,
|
||||
lax.nextafter,
|
||||
lax.rem,
|
||||
lax.sub,
|
||||
)
|
||||
|
||||
LAX_LOGICAL_ELEMENTWISE_UNARY = (
|
||||
lax.bitwise_not,
|
||||
)
|
||||
|
||||
LAX_LOGICAL_ELEMENTWISE_BINARY = (
|
||||
lax.bitwise_and,
|
||||
lax.bitwise_or,
|
||||
lax.bitwise_xor,
|
||||
lax.shift_left,
|
||||
)
|
||||
|
||||
REDUCE = (
|
||||
jnp.all,
|
||||
jnp.any,
|
||||
jnp.max,
|
||||
jnp.min,
|
||||
jnp.prod,
|
||||
jnp.sum,
|
||||
jnp.all,
|
||||
jnp.any,
|
||||
jnp.max,
|
||||
jnp.min,
|
||||
jnp.prod,
|
||||
jnp.sum,
|
||||
)
|
||||
|
||||
INDEX = (
|
||||
jax.ops.index_add,
|
||||
jax.ops.index_max,
|
||||
jax.ops.index_min,
|
||||
jax.ops.index_mul,
|
||||
jax.ops.index_update,
|
||||
jax.ops.index_add,
|
||||
jax.ops.index_max,
|
||||
jax.ops.index_min,
|
||||
jax.ops.index_mul,
|
||||
jax.ops.index_update,
|
||||
)
|
||||
|
||||
|
||||
@ -147,14 +90,14 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
for y_dtype in types:
|
||||
x = np.array([1, 2], dtype=x_dtype)
|
||||
y = np.array([3, 4], dtype=y_dtype)
|
||||
self.ConvertAndCompare(f_jax, x, y, with_function=True)
|
||||
self.ConvertAndCompare(f_jax, x, y)
|
||||
|
||||
def test_concat(self):
|
||||
values = [np.array([1, 2], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.int32),
|
||||
np.array([1, 2], dtype=np.int8)]
|
||||
f_jax = jax.jit(lambda x: jnp.concatenate(x, axis=0))
|
||||
self.ConvertAndCompare(f_jax, values, with_function=True)
|
||||
self.ConvertAndCompare(f_jax, values)
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_pad)
|
||||
def test_pad(self, harness: primitive_harness.Harness):
|
||||
@ -164,8 +107,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
# TODO: fix pad with negative padding in XLA (fixed on 06/16/2020)
|
||||
if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
|
||||
raise unittest.SkipTest("pad with negative pad not supported")
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=False)
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_sort)
|
||||
def test_sort(self, harness: primitive_harness.Harness):
|
||||
@ -189,190 +131,173 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
not harness.params["is_stable"]):
|
||||
# TODO: fix the TF GPU test
|
||||
raise unittest.SkipTest("GPU tests are running TF on CPU")
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=False)
|
||||
# TODO: if we enable this test, we get the error
|
||||
# iterating over `tf.Tensor` is not allowed: AutoGraph is disabled in this function.
|
||||
raise unittest.SkipTest("TODO: re-enable the sort test")
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
f_jax=f_jax)
|
||||
for f_jax in LAX_ELEMENTWISE_UNARY))
|
||||
def test_unary_elementwise(self, f_jax=lax.abs):
|
||||
x = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
|
||||
dtype=np.float32)
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(x)
|
||||
r_tf = f_tf(x)
|
||||
self.assertAllClose(r_jax[np.isfinite(r_jax)],
|
||||
r_tf[np.isfinite(r_tf)], atol=1e-4)
|
||||
@primitive_harness.parameterized(primitive_harness.lax_unary_elementwise)
|
||||
def test_unary_elementwise(self, harness: primitive_harness.Harness):
|
||||
dtype = harness.params["dtype"]
|
||||
if dtype is dtypes.bfloat16:
|
||||
raise unittest.SkipTest("bfloat16 not implemented")
|
||||
arg, = harness.dyn_args_maker(self.rng())
|
||||
custom_assert = None
|
||||
if harness.params["lax_name"] == "digamma":
|
||||
# digamma is not defined at 0 and -1
|
||||
def custom_assert(result_jax, result_tf):
|
||||
# lax.digamma returns NaN and tf.math.digamma returns inf
|
||||
special_cases = (arg == 0.) | (arg == -1.)
|
||||
nr_special_cases = np.count_nonzero(special_cases)
|
||||
self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
|
||||
result_jax[special_cases])
|
||||
self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
|
||||
result_tf[special_cases])
|
||||
# non-special cases are equal
|
||||
self.assertAllClose(result_jax[~ special_cases],
|
||||
result_tf[~ special_cases])
|
||||
if harness.params["lax_name"] == "erf_inv":
|
||||
# TODO(necula): fix bug with erf_inv/f16
|
||||
if dtype is np.float16:
|
||||
raise unittest.SkipTest("TODO: fix bug")
|
||||
# erf_inf is not defined for arg <= -1 or arg >= 1
|
||||
def custom_assert(result_jax, result_tf): # noqa: F811
|
||||
# for arg < -1 or arg > 1
|
||||
# lax.erf_inf returns NaN; tf.math.erf_inv return +/- inf
|
||||
special_cases = (arg < -1.) | (arg > 1.)
|
||||
nr_special_cases = np.count_nonzero(special_cases)
|
||||
self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
|
||||
result_jax[special_cases])
|
||||
signs = np.where(arg[special_cases] < 0., -1., 1.)
|
||||
self.assertAllClose(np.full((nr_special_cases,), signs * dtype(np.inf)),
|
||||
result_tf[special_cases])
|
||||
# non-special cases are equal
|
||||
self.assertAllClose(result_jax[~ special_cases],
|
||||
result_tf[~ special_cases])
|
||||
self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert)
|
||||
|
||||
def test_bitwise_not(self):
|
||||
x = np.array([-1, 3, -2, 0, 0, 2, 1, 3], dtype=np.int32)
|
||||
f_jax = jax.jit(lax.bitwise_not)
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(x)
|
||||
r_tf = f_tf(x)
|
||||
self.assertAllClose(r_jax[np.isfinite(r_jax)],
|
||||
r_tf[np.isfinite(r_tf)], atol=1e-4)
|
||||
@primitive_harness.parameterized(primitive_harness.lax_bitwise_not)
|
||||
def test_bitwise_not(self, harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
f_jax=f_jax)
|
||||
for f_jax in LAX_ELEMENTWISE_BINARY))
|
||||
def test_binary_elementwise(self, f_jax=lax.add):
|
||||
a = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
|
||||
dtype=np.float32)
|
||||
b = np.array([-1.6, 1.4, 1.0, 0.0, 0.1, 0.2, 1, 1.4, -1.6],
|
||||
dtype=np.float32)
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(a, b)
|
||||
r_tf = f_tf(a, b)
|
||||
# Jax outputs 0 and 1 instead of NaN for values outside the domain.
|
||||
# Whereas tensorflow does this for other combinations,
|
||||
if f_jax in (lax.igamma, lax.igammac):
|
||||
# Make returned array writeable.
|
||||
r_jax = np.copy(r_jax)
|
||||
r_jax[r_jax == 0] = np.nan
|
||||
r_jax[r_jax == 1] = np.nan
|
||||
r_tf = np.copy(r_tf)
|
||||
r_tf[r_tf == 0] = np.nan
|
||||
r_tf[r_tf == 1] = np.nan
|
||||
self.assertAllClose(r_jax[np.isfinite(r_jax)],
|
||||
r_tf[np.isfinite(r_tf)], atol=1e-4)
|
||||
@primitive_harness.parameterized(primitive_harness.lax_binary_elementwise)
|
||||
def test_binary_elementwise(self, harness):
|
||||
if harness.params["dtype"] is dtypes.bfloat16:
|
||||
raise unittest.SkipTest("bfloat16 not implemented")
|
||||
# TODO(necula): fix bug with igamma/f16
|
||||
if (harness.params["lax_name"] in ("igamma", "igammac") and
|
||||
harness.params["dtype"] is np.float16):
|
||||
raise unittest.SkipTest("TODO: fix bug")
|
||||
# TODO(necula): fix bug with nextafter/f16
|
||||
if (harness.params["lax_name"] == "nextafter" and
|
||||
harness.params["dtype"] is np.float16):
|
||||
raise unittest.SkipTest("TODO: understand unimplemented case")
|
||||
arg1, arg2 = harness.dyn_args_maker(self.rng())
|
||||
custom_assert = None
|
||||
if harness.params["lax_name"] == "igamma":
|
||||
# igamma is not defined when the first argument is <=0
|
||||
def custom_assert(result_jax, result_tf):
|
||||
# lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0
|
||||
special_cases = (arg1 == 0.) & (arg2 == 0.)
|
||||
nr_special_cases = np.count_nonzero(special_cases)
|
||||
self.assertAllClose(np.full((nr_special_cases,), np.nan),
|
||||
result_jax[special_cases])
|
||||
self.assertAllClose(np.full((nr_special_cases,), 0.),
|
||||
result_tf[special_cases])
|
||||
# non-special cases are equal
|
||||
self.assertAllClose(result_jax[~ special_cases],
|
||||
result_tf[~ special_cases])
|
||||
if harness.params["lax_name"] == "igammac":
|
||||
# igammac is not defined when the first argument is <=0
|
||||
def custom_assert(result_jax, result_tf): # noqa: F811
|
||||
# lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
|
||||
special_cases = (arg1 <= 0.) | (arg2 <= 0)
|
||||
nr_special_cases = np.count_nonzero(special_cases)
|
||||
self.assertAllClose(np.full((nr_special_cases,), 1.),
|
||||
result_jax[special_cases])
|
||||
self.assertAllClose(np.full((nr_special_cases,), np.nan),
|
||||
result_tf[special_cases])
|
||||
# non-special cases are equal
|
||||
self.assertAllClose(result_jax[~ special_cases],
|
||||
result_tf[~ special_cases])
|
||||
self.ConvertAndCompare(harness.dyn_fun, arg1, arg2,
|
||||
custom_assert=custom_assert)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
f_jax=f_jax)
|
||||
for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
|
||||
def test_binary_logical_elementwise(self, f_jax):
|
||||
a = np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=np.uint32)
|
||||
b = np.array([1, 2, 3, 0, 1, 0, 2, 3], dtype=np.uint32)
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(a, b)
|
||||
r_tf = f_tf(a, b)
|
||||
self.assertAllClose(r_jax[np.isfinite(r_jax)],
|
||||
r_tf[np.isfinite(r_tf)], atol=1e-4)
|
||||
# Checks support for bools.
|
||||
if f_jax in (lax.bitwise_and, lax.bitwise_or, lax.bitwise_xor):
|
||||
a = np.array([True, True, False, False])
|
||||
b = np.array([True, False, True, False])
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(a, b)
|
||||
r_tf = f_tf(a, b)
|
||||
self.assertArraysEqual(r_jax, np.asarray(r_tf))
|
||||
@primitive_harness.parameterized(primitive_harness.lax_binary_elementwise_logical)
|
||||
def test_binary_elementwise_logical(self, harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
f_jax=f_jax)
|
||||
for f_jax in LAX_LOGICAL_ELEMENTWISE_UNARY))
|
||||
def test_unary_logical_elementwise(self, f_jax):
|
||||
a = np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=np.uint32)
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(a)
|
||||
r_tf = f_tf(a)
|
||||
self.assertAllClose(r_jax[np.isfinite(r_jax)],
|
||||
r_tf[np.isfinite(r_tf)], atol=1e-4)
|
||||
# Checks support for bools.
|
||||
a = np.array([True, False])
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(a)
|
||||
r_tf = f_tf(a)
|
||||
self.assertArraysEqual(r_jax, np.asarray(r_tf))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
f_jax=f_jax)
|
||||
for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
|
||||
def test_binary_logical_elementwise_bool(self, f_jax):
|
||||
if f_jax == lax.shift_left:
|
||||
self.skipTest("Shift of bool not supported")
|
||||
a = np.array([0, 0, 1, 1, 0, 0, 1, 1], dtype=np.bool_)
|
||||
b = np.array([0, 1, 0, 1, 0, 1, 0, 1], dtype=np.bool_)
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(a, b)
|
||||
r_tf = f_tf(a, b)
|
||||
self.assertAllClose(r_jax, r_tf)
|
||||
@primitive_harness.parameterized(primitive_harness.lax_betainc)
|
||||
def test_betainc(self, harness: primitive_harness.Harness):
|
||||
if harness.params["dtype"] is dtypes.bfloat16:
|
||||
raise unittest.SkipTest("bfloat16 not implemented")
|
||||
# TODO(necula): fix bug with betainc/f16
|
||||
if harness.params["dtype"] is np.float16:
|
||||
raise unittest.SkipTest("TODO: understand betainc/f16 bug")
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
# TODO(necula): combine tests that are identical except for the harness
|
||||
# wait until we get more experience with using harnesses.
|
||||
@primitive_harness.parameterized(primitive_harness.lax_shift_left)
|
||||
def test_shift_left(self, harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=True)
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_shift_right_logical)
|
||||
def test_shift_right_logical(self, harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=True)
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_shift_right_arithmetic)
|
||||
def test_shift_right_arithmetic(self, harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=True)
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_slice)
|
||||
def test_slice(self, harness):
|
||||
# JAX.slice rejects negative indices; check, and skip jax2tf
|
||||
if any(si < 0 or si >= sh or li < 0 or li > sh
|
||||
for sh, si, li in zip(harness.params["shape"],
|
||||
harness.params["start_indices"],
|
||||
harness.params["limit_indices"])):
|
||||
harness.params["start_indices"],
|
||||
harness.params["limit_indices"])):
|
||||
with self.assertRaisesRegex(TypeError, ""):
|
||||
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
|
||||
else:
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=True)
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_dynamic_slice)
|
||||
def test_dynamic_slice(self, harness):
|
||||
# JAX.dynamic_slice rejects slice sizes too big; check, and skip jax2tf
|
||||
# JAX.dynamic_slice rejects slice sizes too big; check this, and skip jax2tf
|
||||
if any(li - si < 0 or li - si >= sh
|
||||
for sh, si, li in zip(harness.params["shape"],
|
||||
harness.params["start_indices"],
|
||||
harness.params["limit_indices"])):
|
||||
harness.params["start_indices"],
|
||||
harness.params["limit_indices"])):
|
||||
with self.assertRaisesRegex(TypeError, ""):
|
||||
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
|
||||
else:
|
||||
# TF compiler gives an error for tf.slice(start_indices < 0)
|
||||
if any(si < 0 for si in harness.params["start_indices"]):
|
||||
raise unittest.SkipTest("TF gives error for negative start_indices")
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=True)
|
||||
check_compiled=False)
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_dynamic_update_slice)
|
||||
def test_dynamic_update_slice(self, harness):
|
||||
# JAX.dynamic_update_slice rejects update slices too big; check, and skip jax2tf
|
||||
if any(ush > sh
|
||||
for sh, ush in zip(harness.params["shape"],
|
||||
harness.params["update_shape"])):
|
||||
harness.params["update_shape"])):
|
||||
with self.assertRaisesRegex(TypeError, ""):
|
||||
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
|
||||
else:
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
f_jax=f_jax)
|
||||
for f_jax in (lax.betainc,)))
|
||||
def test_trinary_elementwise(self, f_jax):
|
||||
a = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6],
|
||||
dtype=np.float32)
|
||||
b = np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6],
|
||||
dtype=np.float32)
|
||||
c = np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6],
|
||||
dtype=np.float32)
|
||||
f_tf = tf.function(jax2tf.convert(f_jax))
|
||||
r_jax = f_jax(a, b, c)
|
||||
r_tf = f_tf(a, b, c)
|
||||
self.assertAllClose(r_jax[np.isfinite(r_jax)],
|
||||
r_tf[np.isfinite(r_tf)], atol=1e-4)
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_squeeze)
|
||||
def test_squeeze(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=True)
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_gather)
|
||||
def test_gather(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
|
||||
with_function=False)
|
||||
check_compiled=False)
|
||||
|
||||
def test_boolean_gather(self):
|
||||
values = np.array([[True, True], [False, True], [False, False]],
|
||||
@ -380,13 +305,14 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
indices = np.array([0, 1], dtype=np.int32)
|
||||
for axis in [0, 1]:
|
||||
f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis)) # pylint: disable=cell-var-from-loop
|
||||
self.ConvertAndCompare(f_jax, values, indices, with_function=True)
|
||||
# TODO: why can't we compile this code?
|
||||
self.ConvertAndCompare(f_jax, values, indices, check_compiled=False)
|
||||
|
||||
def test_gather_rank_change(self):
|
||||
params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
|
||||
indices = jnp.array([[1, 1, 2], [0, 1, 0]])
|
||||
f_jax = jax.jit(lambda i: params[i])
|
||||
self.ConvertAndCompare(f_jax, indices, with_function=True)
|
||||
self.ConvertAndCompare(f_jax, indices, check_compiled=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
@ -394,7 +320,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
for f_jax in REDUCE))
|
||||
def test_reduce_ops_with_numerical_input(self, f_jax):
|
||||
values = np.array([1, 2, 3], dtype=np.float32)
|
||||
self.ConvertAndCompare(f_jax, values, with_function=True)
|
||||
self.ConvertAndCompare(f_jax, values)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
@ -402,7 +328,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
for f_jax in (jnp.cumsum, jnp.cumprod)))
|
||||
def test_cumulated_ops(self, f_jax):
|
||||
values = np.array([1, 2, 3], dtype=np.float32)
|
||||
self.ConvertAndCompare(f_jax, values, with_function=True)
|
||||
self.ConvertAndCompare(f_jax, values)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{op.__name__}",
|
||||
@ -412,7 +338,9 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
values = np.ones((5, 6), dtype=np.float32)
|
||||
update = np.float32(6.)
|
||||
f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
|
||||
self.ConvertAndCompare(f_jax, values, update, with_function=True)
|
||||
# TODO: compilation fails
|
||||
self.ConvertAndCompare(f_jax, values, update,
|
||||
check_compiled=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{f_jax.__name__}",
|
||||
@ -420,8 +348,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
for f_jax in REDUCE))
|
||||
def test_reduce_ops_with_boolean_input(self, f_jax):
|
||||
values = np.array([True, False, True], dtype=np.bool_)
|
||||
self.ConvertAndCompare(f_jax, values, with_function=True)
|
||||
|
||||
self.ConvertAndCompare(f_jax, values)
|
||||
|
||||
def test_prngsplit(self):
|
||||
f_jax = jax.jit(lambda key: jax.random.split(key, 2))
|
||||
@ -431,8 +358,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
np.array([0, 0xFFFFFFFF], dtype=np.uint32),
|
||||
np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)
|
||||
]:
|
||||
self.ConvertAndCompare(f_jax, rng_key, with_function=True)
|
||||
|
||||
self.ConvertAndCompare(f_jax, rng_key)
|
||||
|
||||
def test_zeros_like(self):
|
||||
v = np.float32(2.)
|
||||
@ -443,5 +369,6 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
f = jax2tf.convert(lax.stop_gradient)
|
||||
self.assertEqual(f(tf.ones([])), 1.)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -60,24 +60,54 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
to_numpy_dtype(jtu._dtype(y)))
|
||||
|
||||
def ConvertAndCompare(self, func_jax: Callable, *args,
|
||||
with_function: bool = False,
|
||||
custom_assert: Optional[Callable] = None,
|
||||
check_compiled: bool = True,
|
||||
atol=None,
|
||||
rtol=None) -> Tuple[Any, Any]:
|
||||
"""Compares jax_func(*args) with convert(jax_func)(*args)."""
|
||||
"""Compares jax_func(*args) with convert(jax_func)(*args).
|
||||
|
||||
It compares the result of JAX, TF, TF with tf.function, and TF with
|
||||
tf.function(experimental_compile=True).
|
||||
|
||||
Args:
|
||||
check_compiled: check that JAX and tf.function(experimental_compile=True)
|
||||
produce the same results.
|
||||
custom_assert: a function that will be called
|
||||
`custom_assert(result_jax, result_tf)` to assert equality of the
|
||||
results. Use this function when JAX and TF produce different results.
|
||||
This function is not used for the experimental_compile case, because
|
||||
in that case we expect always the results to be equal.
|
||||
|
||||
"""
|
||||
result_jax = func_jax(*args)
|
||||
|
||||
func_tf = jax2tf.convert(func_jax)
|
||||
if with_function:
|
||||
func_tf = tf.function(func_tf, autograph=False)
|
||||
res_jax = func_jax(*args)
|
||||
#logging.info(f"res_jax is {res_jax} on {res_jax.device_buffer.device()}")
|
||||
res_tf = func_tf(*args)
|
||||
#logging.info(f"res_tf is {res_tf} on {res_tf.backing_device}")
|
||||
self.assertAllClose(res_jax, res_tf, atol=atol, rtol=rtol)
|
||||
return (res_jax, res_tf)
|
||||
result_tf = func_tf(*args)
|
||||
# Sometimes JAX and TF(compile=False) give different results
|
||||
if custom_assert is not None:
|
||||
custom_assert(result_jax, result_tf)
|
||||
else:
|
||||
self.assertAllClose(result_jax, result_tf, atol=atol, rtol=rtol)
|
||||
|
||||
# Using tf.function should not make a difference. Is this even something
|
||||
# we should test here?
|
||||
result_tf_function = tf.function(func_tf, autograph=False)(*args)
|
||||
self.assertAllClose(result_tf, result_tf_function, atol=atol, rtol=rtol)
|
||||
|
||||
# The result of JAX and TF with compile=True are always the same
|
||||
# TODO: enable compilation
|
||||
if check_compiled:
|
||||
result_tf_function_compile = tf.function(func_tf, autograph=False,
|
||||
experimental_compile=True)(*args)
|
||||
self.assertAllClose(result_jax, result_tf_function_compile,
|
||||
atol=atol, rtol=rtol)
|
||||
|
||||
return (result_jax, result_tf)
|
||||
|
||||
def TransformConvertAndCompare(self, func: Callable,
|
||||
arg,
|
||||
transform: Optional[str],
|
||||
with_function: bool = False):
|
||||
check_compiled=True):
|
||||
"""Like ConvertAndCompare but first applies a transformation.
|
||||
|
||||
`func` must be a function from one argument to one result. `arg` is
|
||||
@ -86,27 +116,27 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
`transform` can be None, "jvp", "grad", "vmap", "jvp_vmap", "grad_vmap"
|
||||
"""
|
||||
if transform is None:
|
||||
return self.ConvertAndCompare(func, arg, with_function=with_function)
|
||||
return self.ConvertAndCompare(func, arg, check_compiled=check_compiled)
|
||||
if transform == "jvp":
|
||||
t_func = lambda x, xt: jax.jvp(func, (x,), (xt,))
|
||||
return self.ConvertAndCompare(t_func, arg, np.full_like(arg, 0.1),
|
||||
with_function=with_function)
|
||||
check_compiled=check_compiled)
|
||||
if transform == "grad":
|
||||
return self.ConvertAndCompare(jax.grad(func), arg,
|
||||
with_function=with_function)
|
||||
check_compiled=check_compiled)
|
||||
if transform == "vmap":
|
||||
t_arg = np.stack([arg] * 4)
|
||||
return self.ConvertAndCompare(jax.vmap(func),
|
||||
t_arg, with_function=with_function)
|
||||
return self.ConvertAndCompare(jax.vmap(func), t_arg,
|
||||
check_compiled=check_compiled)
|
||||
if transform == "jvp_vmap":
|
||||
jvp_func = lambda x, xt: jax.jvp(jax.vmap(func), (x,), (xt,))
|
||||
t_arg = np.stack([arg] * 4)
|
||||
return self.ConvertAndCompare(jvp_func, t_arg,
|
||||
np.full_like(t_arg, 0.1),
|
||||
with_function=with_function)
|
||||
check_compiled=check_compiled)
|
||||
if transform == "grad_vmap":
|
||||
grad_func = jax.grad(lambda x: jnp.sum(jax.vmap(func)(x)))
|
||||
t_arg = np.stack([arg] * 4)
|
||||
return self.ConvertAndCompare(grad_func, t_arg,
|
||||
with_function=with_function)
|
||||
check_compiled=check_compiled)
|
||||
assert False, transform
|
||||
|
Loading…
x
Reference in New Issue
Block a user