pallas/mosaic test: avoid leaking global config state

This commit is contained in:
Jake VanderPlas 2024-06-06 16:00:02 -07:00
parent d457f9a116
commit a2c31f4d15
4 changed files with 6 additions and 15 deletions

View File

@ -47,7 +47,6 @@ else:
# ruff: noqa: F405
config.update("jax_traceback_filtering", "off")
config.parse_flags_with_absl()
def nd_loop(bounds, body, *, _idxs = ()):
@ -164,16 +163,9 @@ class TestCase(parameterized.TestCase):
self.skipTest("Only works on GPU with capability >= sm90")
super().setUp()
self.prng = np.random.default_rng(1234)
self.ctx = mlir.make_ir_context()
self.ctx.__enter__()
self.loc = ir.Location.unknown()
self.loc.__enter__()
def tearDown(self):
self.loc.__exit__(None, None, None)
self.ctx.__exit__(None, None, None)
del self.loc, self.ctx
super().tearDown()
self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
self.enter_context(mlir.make_ir_context())
self.enter_context(ir.Location.unknown())
class TestUtilTest(TestCase):

View File

@ -32,12 +32,12 @@ else:
from jax.experimental.mosaic.gpu.examples import matmul
config.update("jax_traceback_filtering", "off")
config.parse_flags_with_absl()
os.environ["XLA_FLAGS"] = (
os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0")
@jtu.with_config(jax_traceback_filtering="off")
class MatmulTestCase(jtu.JaxTestCase):
def setUp(self):

View File

@ -30,10 +30,10 @@ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
# pylint: disable=no-value-for-parameter
config.update("jax_traceback_filtering", "off")
config.parse_flags_with_absl()
@jtu.with_config(jax_traceback_filtering="off")
class DecodeAttentionTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -49,8 +49,6 @@ else:
# TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
# pylint: disable=no-value-for-parameter
config.update("jax_traceback_filtering", "off")
config.parse_flags_with_absl()
@functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk",
@ -121,6 +119,7 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False):
return matmul_kernel(x, y)
@jtu.with_config(jax_traceback_filtering="off")
class PallasTest(jtu.JaxTestCase):
INTERPRET = False