mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Add testing for select_p conversion.
This commit is contained in:
parent
59e92d3154
commit
81770e82e1
@ -519,6 +519,26 @@ lax_pad = tuple(
|
||||
]
|
||||
)
|
||||
|
||||
def _make_select_harness(name, *, shape_pred=(2, 3), shape_args=(2, 3),
|
||||
dtype=np.float32):
|
||||
return Harness(f"{name}_shapepred={jtu.format_shape_dtype_string(shape_pred, np.bool_)}_shapeargs={jtu.format_shape_dtype_string(shape_args, dtype)}",
|
||||
lax.select,
|
||||
[RandArg(shape_pred, np.bool_), RandArg(shape_args, dtype),
|
||||
RandArg(shape_args, dtype)],
|
||||
shape_pred=shape_pred,
|
||||
shape_args=shape_args,
|
||||
dtype=dtype)
|
||||
|
||||
lax_select = tuple( # Validate dtypes
|
||||
_make_select_harness("dtypes", dtype=dtype)
|
||||
for dtype in jtu.dtypes.all
|
||||
) + tuple( # Validate shapes
|
||||
_make_select_harness("shapes", shape_pred=shape_pred, shape_args=shape_args)
|
||||
for shape_pred, shape_args in [
|
||||
((), (18,)), # scalar pred
|
||||
]
|
||||
)
|
||||
|
||||
def _make_cumreduce_harness(name, *, f_jax=lax_control_flow.cummin,
|
||||
shape=(8, 9), dtype=np.float32,
|
||||
axis=0, reverse=False):
|
||||
|
@ -111,6 +111,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
def test_pad(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_select)
|
||||
def test_select(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_control_flow_cumreduce)
|
||||
def test_cumreduce(self, harness: primitive_harness.Harness):
|
||||
f_jax, dtype = harness.params["f_jax"], harness.params["dtype"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user