Skip unsupported tests on XLA:CPU MLIR.

PiperOrigin-RevId: 490754048
This commit is contained in:
Johannes Reifferscheid 2022-11-24 09:56:27 -08:00 committed by jax authors
parent a711166569
commit 575c2f3783
2 changed files with 16 additions and 0 deletions

View File

@ -303,6 +303,20 @@ def set_host_platform_device_count(nr_devices: int):
xla_bridge.get_backend.cache_clear()
return undo
def skip_on_xla_cpu_mlir(test_method):
"""A decorator to skip tests when MLIR lowering is enabled."""
@functools.wraps(test_method)
def test_method_wrapper(self, *args, **kwargs):
xla_flags = os.getenv('XLA_FLAGS') or ''
if '--xla_cpu_use_xla_runtime' in xla_flags or '--xla_cpu_enable_mlir_lowering' in xla_flags:
test_name = getattr(test_method, '__name__', '[unknown test]')
raise unittest.SkipTest(
f'{test_name} not supported on XLA:CPU MLIR')
return test_method(self, *args, **kwargs)
return test_method_wrapper
def skip_on_flag(flag_name, skip_value):
"""A decorator for test methods to skip the test when flags are set."""
def skip(test_method): # pylint: disable=missing-docstring

View File

@ -1103,12 +1103,14 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIsInstance(f.as_text(), (str, type(None)))
self.assertIsInstance(g.as_text(), (str, type(None)))
@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_compile_cost_analysis(self):
f = self.jit(lambda x: x).lower(1.).compile()
g = self.jit(lambda x: x + 4).lower(1.).compile()
f.cost_analysis() # doesn't raise
g.cost_analysis() # doesn't raise
@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_compile_memory_analysis(self):
f = self.jit(lambda x: x).lower(1.).compile()
g = self.jit(lambda x: x + 4).lower(1.).compile()