mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Skip unsupported tests on XLA:CPU MLIR.
PiperOrigin-RevId: 490754048
This commit is contained in:
parent
a711166569
commit
575c2f3783
@ -303,6 +303,20 @@ def set_host_platform_device_count(nr_devices: int):
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
return undo
|
||||
|
||||
|
||||
def skip_on_xla_cpu_mlir(test_method):
|
||||
"""A decorator to skip tests when MLIR lowering is enabled."""
|
||||
@functools.wraps(test_method)
|
||||
def test_method_wrapper(self, *args, **kwargs):
|
||||
xla_flags = os.getenv('XLA_FLAGS') or ''
|
||||
if '--xla_cpu_use_xla_runtime' in xla_flags or '--xla_cpu_enable_mlir_lowering' in xla_flags:
|
||||
test_name = getattr(test_method, '__name__', '[unknown test]')
|
||||
raise unittest.SkipTest(
|
||||
f'{test_name} not supported on XLA:CPU MLIR')
|
||||
return test_method(self, *args, **kwargs)
|
||||
return test_method_wrapper
|
||||
|
||||
|
||||
def skip_on_flag(flag_name, skip_value):
|
||||
"""A decorator for test methods to skip the test when flags are set."""
|
||||
def skip(test_method): # pylint: disable=missing-docstring
|
||||
|
@ -1103,12 +1103,14 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsInstance(f.as_text(), (str, type(None)))
|
||||
self.assertIsInstance(g.as_text(), (str, type(None)))
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def test_jit_lower_compile_cost_analysis(self):
|
||||
f = self.jit(lambda x: x).lower(1.).compile()
|
||||
g = self.jit(lambda x: x + 4).lower(1.).compile()
|
||||
f.cost_analysis() # doesn't raise
|
||||
g.cost_analysis() # doesn't raise
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def test_jit_lower_compile_memory_analysis(self):
|
||||
f = self.jit(lambda x: x).lower(1.).compile()
|
||||
g = self.jit(lambda x: x + 4).lower(1.).compile()
|
||||
|
Loading…
x
Reference in New Issue
Block a user