[jax2tf] Add testing for conversion of comparison operators.

This commit is contained in:
Benjamin Chetioui 2020-11-16 18:36:52 +01:00
parent 44e671fac0
commit 356e38b1a2
4 changed files with 39 additions and 0 deletions

View File

@@ -105,7 +105,11 @@ conversion to Tensorflow.
| erf_inv | Missing TF support | Primitive is unimplemented in TF | float16 | CPU, GPU, TPU |
| erfc | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
| fft | Missing TF support | Primitive is unimplemented in TF; this is a problem only in compiled mode (experimental_compile=True) | complex128, float64 | CPU, GPU, TPU |
| ge | Missing TF support | Primitive is unimplemented in TF | bool, uint16, uint32, uint64 | CPU, GPU, TPU |
| gt | Missing TF support | Primitive is unimplemented in TF | bool, uint16, uint32, uint64 | CPU, GPU, TPU |
| le | Missing TF support | Primitive is unimplemented in TF | bool, uint16, uint32, uint64 | CPU, GPU, TPU |
| lgamma | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
| lt | Missing TF support | Primitive is unimplemented in TF | bool, uint16, uint32, uint64 | CPU, GPU, TPU |
| lu | Missing TF support | Primitive is unimplemented in TF | complex64 | TPU |
| max | Missing TF support | Primitive is unimplemented in TF | bool, complex128, complex64, int8, uint16, uint32, uint64 | CPU, GPU, TPU |
| min | Missing TF support | Primitive is unimplemented in TF | bool, complex128, complex64, int8, uint16, uint32, uint64 | CPU, GPU, TPU |

View File

@@ -261,6 +261,10 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
# operations.
tf_unimpl(np_dtype)
if prim in [lax.le_p, lax.lt_p, lax.ge_p, lax.gt_p]:
if np_dtype in [np.bool_, np.uint16, np.uint32, np.uint64]:
tf_unimpl(np_dtype)
if prim is lax.fft_p:
if np_dtype in [np.float64, np.complex128]:
tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "

View File

@@ -178,6 +178,33 @@ lax_unary_elementwise = tuple(
]
)
# The lax comparison primitives exercised by the comparator harnesses below.
_LAX_COMPARATORS = (lax.eq, lax.ge, lax.gt, lax.le, lax.lt, lax.ne)


def _make_comparator_harness(name, *, dtype=np.float32, op=lax.eq, lhs_shape=(),
                             rhs_shape=()):
  """Builds a Harness applying comparison op `op` to two random operands.

  The harness name encodes the op and both operand shape/dtype specs so that
  each generated test case is uniquely identifiable.
  """
  lhs_spec = jtu.format_shape_dtype_string(lhs_shape, dtype)
  rhs_spec = jtu.format_shape_dtype_string(rhs_shape, dtype)
  harness_name = f"{name}_op={op.__name__}_lhs={lhs_spec}_rhs={rhs_spec}"
  return Harness(harness_name,
                 op,
                 [RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype)],
                 lhs_shape=lhs_shape,
                 rhs_shape=rhs_shape,
                 dtype=dtype)
# Harnesses validating each comparator over its supported dtypes. Equality
# ops (eq/ne) accept every dtype; ordered ops exclude complex dtypes.
_dtype_harnesses = tuple(
    _make_comparator_harness("dtypes", dtype=dtype, op=op)
    for op in _LAX_COMPARATORS
    for dtype in (jtu.dtypes.all if op in [lax.eq, lax.ne] else
                  set(jtu.dtypes.all) - set(jtu.dtypes.complex)))

# Harnesses validating broadcasting between the two operands: a scalar
# against a matrix, and broadcasting along a single axis.
_broadcasting_harnesses = tuple(
    _make_comparator_harness("broadcasting", lhs_shape=lhs_shape,
                             rhs_shape=rhs_shape, op=op)
    for op in _LAX_COMPARATORS
    for lhs_shape, rhs_shape in [
        ((), (2, 3)),      # broadcast scalar
        ((1, 2), (3, 2)),  # broadcast along specific axis
    ])

lax_comparators = _dtype_harnesses + _broadcasting_harnesses
lax_bitwise_not = tuple(
[Harness(f"{jtu.dtype_str(dtype)}",
lax.bitwise_not,

View File

@@ -566,6 +566,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
atol=atol)
@primitive_harness.parameterized(primitive_harness.lax_comparators)
def test_comparators(self, harness: primitive_harness.Harness):
  """Checks TF conversion of a comparison-op harness against its JAX result."""
  # Draw fresh random arguments per the harness spec, then compare the
  # converted TF function against the original JAX function on them.
  dyn_args = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(harness.dyn_fun, *dyn_args)
@primitive_harness.parameterized(primitive_harness.lax_bitwise_not)
def test_bitwise_not(self, harness):
  """Checks TF conversion of a bitwise_not harness against its JAX result."""
  # Generate the harness's random arguments and verify the converted TF
  # function agrees with the JAX original.
  dyn_args = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(harness.dyn_fun, *dyn_args)