mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[xla:cpu] Support for up to 16 sorted inputs
+ enable more jax/lax tests for XLA CPU thunks PiperOrigin-RevId: 655249641
This commit is contained in:
parent
dc42ba0e41
commit
e3fc63cafb
15
tests/BUILD
15
tests/BUILD
@ -445,7 +445,10 @@ jax_test(
|
||||
"gpu": 40,
|
||||
"tpu": 50,
|
||||
},
|
||||
tags = ["noasan"], # Test times out on all backends
|
||||
tags = [
|
||||
"noasan", # Test times out on all backends
|
||||
"test_cpu_thunks",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -456,6 +459,7 @@ jax_test(
|
||||
"gpu": 30,
|
||||
"tpu": 40,
|
||||
},
|
||||
tags = ["test_cpu_thunks"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -466,6 +470,7 @@ jax_test(
|
||||
"gpu": 20,
|
||||
"tpu": 20,
|
||||
},
|
||||
tags = ["test_cpu_thunks"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -486,16 +491,19 @@ jax_test(
|
||||
"gpu": 10,
|
||||
"tpu": 10,
|
||||
},
|
||||
tags = ["test_cpu_thunks"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lax_numpy_ufuncs_test",
|
||||
srcs = ["lax_numpy_ufuncs_test.py"],
|
||||
tags = ["test_cpu_thunks"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lax_numpy_vectorize_test",
|
||||
srcs = ["lax_numpy_vectorize_test.py"],
|
||||
tags = ["test_cpu_thunks"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -560,6 +568,7 @@ jax_test(
|
||||
"gpu": 40,
|
||||
"tpu": 40,
|
||||
},
|
||||
tags = ["test_cpu_thunks"],
|
||||
deps = [
|
||||
"//jax:internal_test_util",
|
||||
"//jax:lax_reference",
|
||||
@ -589,6 +598,7 @@ jax_test(
|
||||
"gpu": 40,
|
||||
"tpu": 20,
|
||||
},
|
||||
tags = ["test_cpu_thunks"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -599,6 +609,7 @@ jax_test(
|
||||
"gpu": 40,
|
||||
"tpu": 40,
|
||||
},
|
||||
tags = ["test_cpu_thunks"],
|
||||
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
@ -610,6 +621,7 @@ jax_test(
|
||||
"gpu": 40,
|
||||
"tpu": 40,
|
||||
},
|
||||
tags = ["test_cpu_thunks"],
|
||||
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
@ -652,6 +664,7 @@ jax_test(
|
||||
"gpu": 40,
|
||||
"tpu": 40,
|
||||
},
|
||||
tags = ["test_cpu_thunks"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
Loading…
x
Reference in New Issue
Block a user