mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Added tests for the conversion of transpose.
This commit is contained in:
parent
44e671fac0
commit
e1e05140e8
@ -1386,8 +1386,8 @@ tf_impl[lax.rev_p] = _rev
|
||||
|
||||
tf_impl[lax.select_p] = tf.where
|
||||
|
||||
def _transpose(operand, permutation):
|
||||
return tf.transpose(operand, permutation)
|
||||
def _transpose(operand, *, permutation):
|
||||
return tf.transpose(operand, perm=permutation)
|
||||
tf_impl[lax.transpose_p] = _transpose
|
||||
|
||||
axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)
|
||||
|
@ -539,6 +539,26 @@ lax_select = tuple( # Validate dtypes
|
||||
]
|
||||
)
|
||||
|
||||
def _make_transpose_harness(name, *, shape=(2, 3), permutation=(1, 0),
|
||||
dtype=np.float32):
|
||||
return Harness(f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_permutation={permutation}".replace(' ', ''),
|
||||
lambda x: lax.transpose_p.bind(x, permutation=permutation),
|
||||
[RandArg(shape, dtype)],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
permutation=permutation)
|
||||
|
||||
lax_transpose = tuple( # Validate dtypes
|
||||
_make_transpose_harness("dtypes", dtype=dtype)
|
||||
for dtype in jtu.dtypes.all
|
||||
) + tuple( # Validate permutations
|
||||
_make_transpose_harness("permutations", shape=shape, permutation=permutation)
|
||||
for shape, permutation in [
|
||||
((2, 3, 4), (0, 1, 2)), # identity
|
||||
((2, 3, 4), (1, 2, 0)), # transposition
|
||||
]
|
||||
)
|
||||
|
||||
def _make_cumreduce_harness(name, *, f_jax=lax_control_flow.cummin,
|
||||
shape=(8, 9), dtype=np.float32,
|
||||
axis=0, reverse=False):
|
||||
|
@ -115,6 +115,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
def test_select(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_transpose)
|
||||
def test_transpose(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_control_flow_cumreduce)
|
||||
def test_cumreduce(self, harness: primitive_harness.Harness):
|
||||
f_jax, dtype = harness.params["f_jax"], harness.params["dtype"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user