mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add jax_array coverage to debug_nans_test
PiperOrigin-RevId: 478079509
This commit is contained in:
parent
ec41de2c9b
commit
fb8558cfdd
@ -81,13 +81,9 @@ jax_test(
|
||||
srcs = ["custom_object_test.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
jax_test(
|
||||
name = "debug_nans_test",
|
||||
srcs = ["debug_nans_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -150,7 +150,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
if jax.device_count() < 2:
|
||||
raise SkipTest("test requires >=2 devices")
|
||||
|
||||
p = jax.experimental.PartitionSpec('x')
|
||||
p = pjit.PartitionSpec('x')
|
||||
f = pjit.pjit(lambda x: 0. / x,
|
||||
in_axis_resources=p,
|
||||
out_axis_resources=p)
|
||||
|
Loading…
x
Reference in New Issue
Block a user