[jax2tf] Add testing for select_p conversion.

This commit is contained in:
Benjamin Chetioui 2020-11-17 14:08:00 +01:00
parent 59e92d3154
commit 81770e82e1
2 changed files with 24 additions and 0 deletions

View File

@ -519,6 +519,26 @@ lax_pad = tuple(
]
)
def _make_select_harness(name, *, shape_pred=(2, 3), shape_args=(2, 3),
dtype=np.float32):
return Harness(f"{name}_shapepred={jtu.format_shape_dtype_string(shape_pred, np.bool_)}_shapeargs={jtu.format_shape_dtype_string(shape_args, dtype)}",
lax.select,
[RandArg(shape_pred, np.bool_), RandArg(shape_args, dtype),
RandArg(shape_args, dtype)],
shape_pred=shape_pred,
shape_args=shape_args,
dtype=dtype)
lax_select = tuple( # Validate dtypes
_make_select_harness("dtypes", dtype=dtype)
for dtype in jtu.dtypes.all
) + tuple( # Validate shapes
_make_select_harness("shapes", shape_pred=shape_pred, shape_args=shape_args)
for shape_pred, shape_args in [
((), (18,)), # scalar pred
]
)
def _make_cumreduce_harness(name, *, f_jax=lax_control_flow.cummin,
shape=(8, 9), dtype=np.float32,
axis=0, reverse=False):

View File

@ -111,6 +111,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
def test_pad(self, harness: primitive_harness.Harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_select)
def test_select(self, harness: primitive_harness.Harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_control_flow_cumreduce)
def test_cumreduce(self, harness: primitive_harness.Harness):
f_jax, dtype = harness.params["f_jax"], harness.params["dtype"]