[xla:cpu] Support for up to 16 sorted inputs

+ enable more jax/lax tests for XLA CPU thunks

PiperOrigin-RevId: 655249641
This commit is contained in:
Eugene Zhulenev 2024-07-23 11:53:48 -07:00 committed by jax authors
parent dc42ba0e41
commit e3fc63cafb

View File

@ -445,7 +445,10 @@ jax_test(
"gpu": 40,
"tpu": 50,
},
tags = ["noasan"], # Test times out on all backends
tags = [
"noasan", # Test times out on all backends
"test_cpu_thunks",
],
)
jax_test(
@ -456,6 +459,7 @@ jax_test(
"gpu": 30,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
)
jax_test(
@ -466,6 +470,7 @@ jax_test(
"gpu": 20,
"tpu": 20,
},
tags = ["test_cpu_thunks"],
)
jax_test(
@ -486,16 +491,19 @@ jax_test(
"gpu": 10,
"tpu": 10,
},
tags = ["test_cpu_thunks"],
)
jax_test(
name = "lax_numpy_ufuncs_test",
srcs = ["lax_numpy_ufuncs_test.py"],
tags = ["test_cpu_thunks"],
)
jax_test(
name = "lax_numpy_vectorize_test",
srcs = ["lax_numpy_vectorize_test.py"],
tags = ["test_cpu_thunks"],
)
jax_test(
@ -560,6 +568,7 @@ jax_test(
"gpu": 40,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
deps = [
"//jax:internal_test_util",
"//jax:lax_reference",
@ -589,6 +598,7 @@ jax_test(
"gpu": 40,
"tpu": 20,
},
tags = ["test_cpu_thunks"],
)
jax_test(
@ -599,6 +609,7 @@ jax_test(
"gpu": 40,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
)
@ -610,6 +621,7 @@ jax_test(
"gpu": 40,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
)
@ -652,6 +664,7 @@ jax_test(
"gpu": 40,
"tpu": 40,
},
tags = ["test_cpu_thunks"],
)
jax_test(