mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Reduce number of tests for sort.
This commit is contained in:
parent
425e3d1c06
commit
43c427839d
@ -494,55 +494,39 @@ lax_top_k = tuple( # random testing
|
||||
for k in [1, 3, 6]
|
||||
)
|
||||
|
||||
lax_sort = tuple( # one array, random data, all axes, all dtypes
|
||||
Harness(f"one_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={dimension}_isstable={is_stable}",
|
||||
lax.sort,
|
||||
[RandArg(shape, dtype), StaticArg(dimension), StaticArg(is_stable)],
|
||||
shape=shape,
|
||||
dimension=dimension,
|
||||
dtype=dtype,
|
||||
is_stable=is_stable)
|
||||
def _make_sort_harness(name, *, operands=None, shape=(5, 7), dtype=np.float32,
|
||||
dimension=0, is_stable=False, nb_arrays=1):
|
||||
if operands is None:
|
||||
operands = [RandArg(shape, dtype) for _ in range(nb_arrays)]
|
||||
return Harness(f"{name}_nbarrays={nb_arrays}_shape={jtu.format_shape_dtype_string(operands[0].shape, operands[0].dtype)}_axis={dimension}_isstable={is_stable}",
|
||||
lambda *args: lax.sort_p.bind(*args[:-2], dimension=args[-2],
|
||||
is_stable=args[-1], num_keys=1),
|
||||
[*operands, StaticArg(dimension), StaticArg(is_stable)],
|
||||
shape=operands[0].shape,
|
||||
dimension=dimension,
|
||||
dtype=operands[0].dtype,
|
||||
is_stable=is_stable,
|
||||
nb_arrays=nb_arrays)
|
||||
|
||||
lax_sort = tuple( # Validate dtypes
|
||||
_make_sort_harness("dtypes", dtype=dtype)
|
||||
for dtype in jtu.dtypes.all
|
||||
for shape in [(5,), (5, 7)]
|
||||
for dimension in range(len(shape))
|
||||
for is_stable in [False, True]
|
||||
) + tuple( # one array, potential edge cases
|
||||
Harness(f"one_special_array_shape={jtu.format_shape_dtype_string(arr.shape, arr.dtype)}_axis={dimension}_isstable={is_stable}",
|
||||
lax.sort,
|
||||
[arr, StaticArg(dimension), StaticArg(is_stable)],
|
||||
shape=arr.shape,
|
||||
dimension=dimension,
|
||||
dtype=arr.dtype,
|
||||
is_stable=is_stable)
|
||||
for arr, dimension in [
|
||||
[np.array([+np.inf, np.nan, -np.nan, -np.inf, 2, 4, 189], dtype=np.float32), -1]
|
||||
) + tuple( # Validate dimensions
|
||||
[_make_sort_harness("dimensions", dimension=1)]
|
||||
) + tuple( # Validate stable sort
|
||||
[_make_sort_harness("is_stable", is_stable=True)]
|
||||
) + tuple( # Potential edge cases
|
||||
_make_sort_harness("edge_cases", operands=operands, dimension=dimension)
|
||||
for operands, dimension in [
|
||||
([np.array([+np.inf, np.nan, -np.nan, -np.inf, 2], dtype=np.float32)], -1)
|
||||
]
|
||||
) + tuple( # Validate multiple arrays
|
||||
_make_sort_harness("multiple_arrays", nb_arrays=nb_arrays, dtype=dtype)
|
||||
for nb_arrays, dtype in [
|
||||
(2, np.float32), # equivalent to sort_key_val
|
||||
(2, np.bool_), # unsupported
|
||||
(3, np.float32), # unsupported
|
||||
]
|
||||
for is_stable in [False, True]
|
||||
) + tuple( # 2 arrays, random data, all axes, all dtypes
|
||||
Harness(f"two_arrays_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={dimension}_isstable={is_stable}",
|
||||
lambda *args: lax.sort_p.bind(*args[:-2], dimension=args[-2], is_stable=args[-1], num_keys=1),
|
||||
[RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(dimension), StaticArg(is_stable)],
|
||||
shape=shape,
|
||||
dimension=dimension,
|
||||
dtype=dtype,
|
||||
is_stable=is_stable)
|
||||
for dtype in jtu.dtypes.all
|
||||
for shape in [(5,), (5, 7)]
|
||||
for dimension in range(len(shape))
|
||||
for is_stable in [False, True]
|
||||
) + tuple( # 3 arrays, random data, all axes, all dtypes
|
||||
Harness(f"three_arrays_shape={jtu.format_shape_dtype_string(shape, dtype)}_axis={dimension}_isstable={is_stable}",
|
||||
lambda *args: lax.sort_p.bind(*args[:-2], dimension=args[-2], is_stable=args[-1], num_keys=1),
|
||||
[RandArg(shape, dtype), RandArg(shape, dtype), RandArg(shape, dtype),
|
||||
StaticArg(dimension), StaticArg(is_stable)],
|
||||
shape=shape,
|
||||
dimension=dimension,
|
||||
dtype=dtype,
|
||||
is_stable=is_stable)
|
||||
for dtype in jtu.dtypes.all
|
||||
for shape in [(5,)]
|
||||
for dimension in (0,)
|
||||
for is_stable in [False, True]
|
||||
)
|
||||
|
||||
lax_linalg_cholesky = tuple(
|
||||
|
Loading…
x
Reference in New Issue
Block a user