[jax2tf] Reduce number of tests for sort.

This commit is contained in:
Benjamin Chetioui 2020-10-30 10:28:55 +01:00
parent 425e3d1c06
commit 43c427839d

View File

@ -494,55 +494,39 @@ lax_top_k = tuple( # random testing
for k in [1, 3, 6]
)
lax_sort = tuple( # one array, random data, all axes, all dtypes
Harness(f"one_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={dimension}_isstable={is_stable}",
lax.sort,
[RandArg(shape, dtype), StaticArg(dimension), StaticArg(is_stable)],
shape=shape,
dimension=dimension,
dtype=dtype,
is_stable=is_stable)
def _make_sort_harness(name, *, operands=None, shape=(5, 7), dtype=np.float32,
dimension=0, is_stable=False, nb_arrays=1):
if operands is None:
operands = [RandArg(shape, dtype) for _ in range(nb_arrays)]
return Harness(f"{name}_nbarrays={nb_arrays}_shape={jtu.format_shape_dtype_string(operands[0].shape, operands[0].dtype)}_axis={dimension}_isstable={is_stable}",
lambda *args: lax.sort_p.bind(*args[:-2], dimension=args[-2],
is_stable=args[-1], num_keys=1),
[*operands, StaticArg(dimension), StaticArg(is_stable)],
shape=operands[0].shape,
dimension=dimension,
dtype=operands[0].dtype,
is_stable=is_stable,
nb_arrays=nb_arrays)
lax_sort = tuple( # Validate dtypes
_make_sort_harness("dtypes", dtype=dtype)
for dtype in jtu.dtypes.all
for shape in [(5,), (5, 7)]
for dimension in range(len(shape))
for is_stable in [False, True]
) + tuple( # one array, potential edge cases
Harness(f"one_special_array_shape={jtu.format_shape_dtype_string(arr.shape, arr.dtype)}_axis={dimension}_isstable={is_stable}",
lax.sort,
[arr, StaticArg(dimension), StaticArg(is_stable)],
shape=arr.shape,
dimension=dimension,
dtype=arr.dtype,
is_stable=is_stable)
for arr, dimension in [
[np.array([+np.inf, np.nan, -np.nan, -np.inf, 2, 4, 189], dtype=np.float32), -1]
) + tuple( # Validate dimensions
[_make_sort_harness("dimensions", dimension=1)]
) + tuple( # Validate stable sort
[_make_sort_harness("is_stable", is_stable=True)]
) + tuple( # Potential edge cases
_make_sort_harness("edge_cases", operands=operands, dimension=dimension)
for operands, dimension in [
([np.array([+np.inf, np.nan, -np.nan, -np.inf, 2], dtype=np.float32)], -1)
]
) + tuple( # Validate multiple arrays
_make_sort_harness("multiple_arrays", nb_arrays=nb_arrays, dtype=dtype)
for nb_arrays, dtype in [
(2, np.float32), # equivalent to sort_key_val
(2, np.bool_), # unsupported
(3, np.float32), # unsupported
]
for is_stable in [False, True]
) + tuple( # 2 arrays, random data, all axes, all dtypes
Harness(f"two_arrays_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={dimension}_isstable={is_stable}",
lambda *args: lax.sort_p.bind(*args[:-2], dimension=args[-2], is_stable=args[-1], num_keys=1),
[RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(dimension), StaticArg(is_stable)],
shape=shape,
dimension=dimension,
dtype=dtype,
is_stable=is_stable)
for dtype in jtu.dtypes.all
for shape in [(5,), (5, 7)]
for dimension in range(len(shape))
for is_stable in [False, True]
) + tuple( # 3 arrays, random data, all axes, all dtypes
Harness(f"three_arrays_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={dimension}_isstable={is_stable}",
lambda *args: lax.sort_p.bind(*args[:-2], dimension=args[-2], is_stable=args[-1], num_keys=1),
[RandArg(shape, dtype), RandArg(shape, dtype), RandArg(shape, dtype),
StaticArg(dimension), StaticArg(is_stable)],
shape=shape,
dimension=dimension,
dtype=dtype,
is_stable=is_stable)
for dtype in jtu.dtypes.all
for shape in [(5,)]
for dimension in (0,)
for is_stable in [False, True]
)
lax_linalg_cholesky = tuple(