mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[shape_poly] Copy many of the jax2tf/shape_poly_test to live outside of jax2tf.
Shape polymorphism is now usable independently of jax2tf, and it deserves to have its tests independent of jax2tf. I started by branching jax2tf/tests/shape_poly_test.py into tests/shape_poly_test.py, followed by removing from the latter the tests and helper functions that do not make sense outside of jax2tf. For now we leave the existing tests in jax2tf, because some of those tests exercise other code paths. In the process of adding these tests we found two bugs (fixed separately in https://github.com/google/jax/pull/18516 and https://github.com/google/jax/pull/18515). Since we now run these tests in GitHub and Kokoro, this has revealed a couple of bugs in the tests, which we fix here both in the jax2tf/tests/shape_poly_test.py and the copy tests/shape_poly_test.py. PiperOrigin-RevId: 583816243
This commit is contained in:
parent
52b31a4973
commit
4fbf50dd60
@ -2215,7 +2215,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
# In non-native serialization, we cannot check exact match,
|
||||
# we ought to check the invariants of the result.
|
||||
check_result=config.jax2tf_default_native_serialization.value)
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
|
||||
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
|
||||
for poly in ["b, ...", "b, w, w"]
|
||||
for left in ([True, False] if dtype == np.float32 else [True])
|
||||
for right in ([True, False] if dtype == np.float32 else [False])
|
||||
@ -2519,7 +2519,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices)],
|
||||
polymorphic_shapes=[poly],
|
||||
tol=(None if config.jax2tf_default_native_serialization.value else 1e-5))
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
|
||||
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
|
||||
# m and n must be static for now
|
||||
for shape, poly, full_matrices in [
|
||||
((2, 0, 4), "b, ...", False), # m = 0
|
||||
@ -2822,7 +2822,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
# In non-native serialization, we cannot check exact match,
|
||||
# we ought to check the invariants of the result.
|
||||
check_result=config.jax2tf_default_native_serialization.value)
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
|
||||
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
|
||||
for compute_schur_vectors in [True, False]
|
||||
for (shape, poly) in [
|
||||
((3, 3), "w, w"),
|
||||
@ -2943,7 +2943,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
# In non-native serialization, we cannot check exact match,
|
||||
# we ought to check the invariants of the result.
|
||||
check_result=config.jax2tf_default_native_serialization.value)
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
|
||||
for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes()
|
||||
for (left_side, a_shape, b_shape, a_poly, b_poly) in [
|
||||
(True, (3, 4, 4), (3, 4, 5), "b, ...", "b, ..."),
|
||||
(True, (3, 4, 4), (3, 4, 5), "b, k, k", "b, k, m"),
|
||||
@ -3017,11 +3017,16 @@ def _make_vmap_primitive_harnesses() -> Sequence[PolyHarness]:
|
||||
harness_groups[h.group_name].append(h)
|
||||
|
||||
selected_harnesses = []
|
||||
for group_name, hlist in harness_groups.items():
|
||||
for _, hlist in harness_groups.items():
|
||||
# Pick the dtype with the most harnesses in this group. Some harness
|
||||
# groups only test different use cases at a few dtypes.
|
||||
c = collections.Counter([h.dtype for h in hlist])
|
||||
(dtype, _), = c.most_common(1)
|
||||
(_, max_count), = c.most_common(1)
|
||||
# Pick the first alphabetically among those with max_count, to ensure
|
||||
# that we generate deterministic tests.
|
||||
dtypes_with_max_count = (dtype for dtype, count in c.items()
|
||||
if count == max_count)
|
||||
dtype, *_ = sorted(dtypes_with_max_count, key=str)
|
||||
selected_harnesses.extend([h for h in hlist if h.dtype == dtype])
|
||||
|
||||
batch_size = 3
|
||||
|
@ -153,6 +153,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
|
||||
return f1(*args1)
|
||||
|
||||
|
||||
# TODO(necula): clean up the test harnesses to not require these flags
|
||||
@jtu.with_config(jax_numpy_rank_promotion="allow",
|
||||
jax_numpy_dtype_promotion='standard',
|
||||
jax_legacy_prng_key="allow")
|
||||
|
21
tests/BUILD
21
tests/BUILD
@ -1287,6 +1287,27 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "shape_poly_test",
|
||||
srcs = ["shape_poly_test.py"],
|
||||
disable_configs = [
|
||||
"gpu_a100", # TODO(b/269593297): matmul precision issues
|
||||
],
|
||||
enable_configs = [
|
||||
"cpu",
|
||||
"cpu_x32",
|
||||
],
|
||||
shard_count = {
|
||||
"cpu": 4,
|
||||
"gpu": 4,
|
||||
"tpu": 4,
|
||||
},
|
||||
deps = [
|
||||
"//jax:internal_test_harnesses",
|
||||
"//jax/experimental/export",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "export_harnesses_multi_platform_test",
|
||||
srcs = ["export_harnesses_multi_platform_test.py"],
|
||||
|
2353
tests/shape_poly_test.py
Normal file
2353
tests/shape_poly_test.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user