[shape_poly] Copy many of the jax2tf/shape_poly_test to live outside of jax2tf.

Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf.

For now we leave the existing tests in jax2tf, because some of those tests exercise
other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515).

Since we now run these tests in GitHub and Kokoro, this has revealed a couple
of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py.

PiperOrigin-RevId: 583816243
This commit is contained in:
George Necula 2023-11-19 08:59:23 -08:00 committed by jax authors
parent 52b31a4973
commit 4fbf50dd60
4 changed files with 2386 additions and 6 deletions

View File

@ -2215,7 +2215,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
for poly in ["b, ...", "b, w, w"]
for left in ([True, False] if dtype == np.float32 else [True])
for right in ([True, False] if dtype == np.float32 else [False])
@ -2519,7 +2519,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices)],
polymorphic_shapes=[poly],
tol=(None if config.jax2tf_default_native_serialization.value else 1e-5))
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
# m and n must be static for now
for shape, poly, full_matrices in [
((2, 0, 4), "b, ...", False), # m = 0
@ -2822,7 +2822,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
for compute_schur_vectors in [True, False]
for (shape, poly) in [
((3, 3), "w, w"),
@ -2943,7 +2943,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
for (left_side, a_shape, b_shape, a_poly, b_poly) in [
(True, (3, 4, 4), (3, 4, 5), "b, ...", "b, ..."),
(True, (3, 4, 4), (3, 4, 5), "b, k, k", "b, k, m"),
@ -3017,11 +3017,16 @@ def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]:
harness_groups[h.group_name].append(h)
selected_harnesses = []
for group_name, hlist in harness_groups.items():
for _, hlist in harness_groups.items():
# Pick the dtype with the most harnesses in this group. Some harness
# groups only test different use cases at a few dtypes.
c = collections.Counter([h.dtype for h in hlist])
(dtype, _), = c.most_common(1)
(_, max_count), = c.most_common(1)
# Pick the first alphabetically among those with max_count, to ensure
# that we generate deterministic tests.
dtypes_with_max_count = (dtype for dtype, count in c.items()
if count == max_count)
dtype, *_ = sorted(dtypes_with_max_count, key=str)
selected_harnesses.extend([h for h in hlist if h.dtype == dtype])
batch_size = 3

View File

@ -153,6 +153,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
return f1(*args1)
# TODO(necula): clean up the test harnesses to not require these flags
@jtu.with_config(jax_numpy_rank_promotion="allow",
jax_numpy_dtype_promotion='standard',
jax_legacy_prng_key="allow")

View File

@ -1287,6 +1287,27 @@ jax_test(
],
)
jax_test(
name = "shape_poly_test",
srcs = ["shape_poly_test.py"],
disable_configs = [
"gpu_a100", # TODO(b/269593297): matmul precision issues
],
enable_configs = [
"cpu",
"cpu_x32",
],
shard_count = {
"cpu": 4,
"gpu": 4,
"tpu": 4,
},
deps = [
"//jax:internal_test_harnesses",
"//jax/experimental/export",
],
)
jax_test(
name = "export_harnesses_multi_platform_test",
srcs = ["export_harnesses_multi_platform_test.py"],

2353
tests/shape_poly_test.py Normal file

File diff suppressed because it is too large Load Diff