[jax2tf] Added testing and fixed conversion of conj.

This commit is contained in:
Benjamin Chetioui 2020-11-11 13:50:57 +01:00
parent d4b1215491
commit 850e7a87f4
3 changed files with 34 additions and 1 deletions

View File

@ -878,7 +878,18 @@ tf_impl[lax.bessel_i0e_p] = tf.math.bessel_i0e
tf_impl[lax.bessel_i1e_p] = tf.math.bessel_i1e
tf_impl[lax.complex_p] = tf.complex
tf_impl[lax.conj_p] = tf.math.conj
def _conj(x, **kwargs):
# The only dtypes that are allowed are: float32, float64, complex64, and
# complex128.
if x.dtype == tf.float32:
return tf.cast(x, tf.complex64)
elif x.dtype == tf.float64:
return tf.cast(x, tf.complex128)
else:
return tf.math.conj(x)
tf_impl[lax.conj_p] = _conj
tf_impl[lax.real_p] = tf.math.real
tf_impl[lax.imag_p] = tf.math.imag

View File

@ -849,6 +849,24 @@ lax_slice = tuple(
for dtype in [np.float32]
)
def _make_conj_harness(name, *, shape=(3, 4), dtype=np.float32, **kwargs):
return Harness(f"{name}_operand={jtu.format_shape_dtype_string(shape, dtype)}_kwargs={kwargs}".replace(" ", ""),
lambda x: lax.conj_p.bind(x, **kwargs),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
**kwargs)
lax_conj = tuple( # Validate dtypes
_make_conj_harness("dtypes", dtype=dtype)
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
) + tuple( # Validate kwargs
_make_conj_harness("kwargs", **kwargs)
for kwargs in [
{ "_input_dtype": np.float32 }, # expected kwarg for ad
]
)
# Use lax_slice, but (a) make the start_indices dynamic arg, and (b) no strides.
lax_dynamic_slice = [
Harness(harness.name,

View File

@ -673,6 +673,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
else:
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_conj)
def test_conj(self, harness: primitive_harness.Harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_dynamic_slice)
def test_dynamic_slice(self, harness):
# JAX.dynamic_slice rejects slice sizes too big; check this, and skip jax2tf