From 8720e955704c2aa3bf638ff04b10971928369721 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 29 Jan 2025 06:11:44 -0800 Subject: [PATCH] [export] Fixes for export_harnesses_multi_platform_test. The test was mistakenly skipped on slow tests. This is a highly-parameterized test, and if there are some individual instances that are slow, only those should be skipped. The slowest of all instances takes 3s. I have also ensured that when running natively, we also use jit, like in export mode, to reduce chances that we see numerical discrepancies between eager and jit mode. This fixes a failure on GPU in Kokoro. PiperOrigin-RevId: 720946449 --- tests/export_harnesses_multi_platform_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index d5878fa50..ef9d1e04c 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -79,7 +79,6 @@ class PrimitiveTest(jtu.JaxTestCase): message=("Using reduced precision for gradient of reduce-window min/max " "operator to work around missing XLA support for pair-reductions") ) - @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_prim(self, harness: test_harnesses.Harness): if "eigh_" in harness.fullname: self.skipTest("Eigenvalues are sorted and it is not correct to compare " @@ -158,7 +157,7 @@ class PrimitiveTest(jtu.JaxTestCase): lambda x: jax.device_put(x, device), args ) logging.info("Running harness natively on %s", device) - native_res = func_jax(*device_args) + native_res = jax.jit(func_jax)(*device_args) logging.info("Running exported harness on %s", device) exported_res = exp.call(*device_args) if tol is not None: