[export] Fixes for export_harnesses_multi_platform_test.

The test was mistakenly skipped on slow tests. This is a highly-parameterized test, and if there are some individual instances that are slow, only those should be skipped. The slowest of all instances takes 3s.

I have also ensured that when running natively, we also use jit, like in export mode, to reduce chances that we see numerical discrepancies between eager and jit mode. This fixes a failure on GPU in Kokoro.

PiperOrigin-RevId: 720946449
This commit is contained in:
George Necula 2025-01-29 06:11:44 -08:00 committed by jax authors
parent c9dfdb4e23
commit 8720e95570

View File

@ -79,7 +79,6 @@ class PrimitiveTest(jtu.JaxTestCase):
message=("Using reduced precision for gradient of reduce-window min/max "
"operator to work around missing XLA support for pair-reductions")
)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_prim(self, harness: test_harnesses.Harness):
if "eigh_" in harness.fullname:
self.skipTest("Eigenvalues are sorted and it is not correct to compare "
@ -158,7 +157,7 @@ class PrimitiveTest(jtu.JaxTestCase):
lambda x: jax.device_put(x, device), args
)
logging.info("Running harness natively on %s", device)
native_res = func_jax(*device_args)
native_res = jax.jit(func_jax)(*device_args)
logging.info("Running exported harness on %s", device)
exported_res = exp.call(*device_args)
if tol is not None: