[jax2tf] Added tests for the conversion of transpose.

This commit is contained in:
Benjamin Chetioui 2020-11-17 15:17:21 +01:00
parent 44e671fac0
commit e1e05140e8
3 changed files with 26 additions and 2 deletions

View File

@ -1386,8 +1386,8 @@ tf_impl[lax.rev_p] = _rev
tf_impl[lax.select_p] = tf.where
def _transpose(operand, permutation):
return tf.transpose(operand, permutation)
def _transpose(operand, *, permutation):
return tf.transpose(operand, perm=permutation)
tf_impl[lax.transpose_p] = _transpose
axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)

View File

@ -539,6 +539,26 @@ lax_select = tuple( # Validate dtypes
]
)
def _make_transpose_harness(name, *, shape=(2, 3), permutation=(1, 0),
dtype=np.float32):
return Harness(f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_permutation={permutation}".replace(' ', ''),
lambda x: lax.transpose_p.bind(x, permutation=permutation),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
permutation=permutation)
lax_transpose = tuple( # Validate dtypes
_make_transpose_harness("dtypes", dtype=dtype)
for dtype in jtu.dtypes.all
) + tuple( # Validate permutations
_make_transpose_harness("permutations", shape=shape, permutation=permutation)
for shape, permutation in [
((2, 3, 4), (0, 1, 2)), # identity
((2, 3, 4), (1, 2, 0)), # transposition
]
)
def _make_cumreduce_harness(name, *, f_jax=lax_control_flow.cummin,
shape=(8, 9), dtype=np.float32,
axis=0, reverse=False):

View File

@ -115,6 +115,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
def test_select(self, harness: primitive_harness.Harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_transpose)
def test_transpose(self, harness: primitive_harness.Harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_control_flow_cumreduce)
def test_cumreduce(self, harness: primitive_harness.Harness):
f_jax, dtype = harness.params["f_jax"], harness.params["dtype"]