mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[export] Fixes for export_harnesses_multi_platform_test.
The test was mistakenly skipped on slow tests. This is a highly-parameterized test, and if there are some individual instances that are slow, only those should be skipped. The slowest of all instances takes 3s. I have also ensured that when running natively, we also use jit, like in export mode, to reduce chances that we see numerical discrepancies between eager and jit mode. This fixes a failure on GPU in Kokoro. PiperOrigin-RevId: 720946449
This commit is contained in:
parent
c9dfdb4e23
commit
8720e95570
@ -79,7 +79,6 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
message=("Using reduced precision for gradient of reduce-window min/max "
|
||||
"operator to work around missing XLA support for pair-reductions")
|
||||
)
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def test_prim(self, harness: test_harnesses.Harness):
|
||||
if "eigh_" in harness.fullname:
|
||||
self.skipTest("Eigenvalues are sorted and it is not correct to compare "
|
||||
@ -158,7 +157,7 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
lambda x: jax.device_put(x, device), args
|
||||
)
|
||||
logging.info("Running harness natively on %s", device)
|
||||
native_res = func_jax(*device_args)
|
||||
native_res = jax.jit(func_jax)(*device_args)
|
||||
logging.info("Running exported harness on %s", device)
|
||||
exported_res = exp.call(*device_args)
|
||||
if tol is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user