mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #21718 from jakevdp:pallas-config
PiperOrigin-RevId: 641349981
This commit is contained in:
commit
0d047a116a
@ -47,7 +47,6 @@ else:
|
||||
|
||||
|
||||
# ruff: noqa: F405
|
||||
config.update("jax_traceback_filtering", "off")
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
def nd_loop(bounds, body, *, _idxs = ()):
|
||||
@ -164,16 +163,9 @@ class TestCase(parameterized.TestCase):
|
||||
self.skipTest("Only works on GPU with capability >= sm90")
|
||||
super().setUp()
|
||||
self.prng = np.random.default_rng(1234)
|
||||
self.ctx = mlir.make_ir_context()
|
||||
self.ctx.__enter__()
|
||||
self.loc = ir.Location.unknown()
|
||||
self.loc.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.loc.__exit__(None, None, None)
|
||||
self.ctx.__exit__(None, None, None)
|
||||
del self.loc, self.ctx
|
||||
super().tearDown()
|
||||
self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
|
||||
self.enter_context(mlir.make_ir_context())
|
||||
self.enter_context(ir.Location.unknown())
|
||||
|
||||
|
||||
class TestUtilTest(TestCase):
|
||||
|
@ -32,12 +32,12 @@ else:
|
||||
from jax.experimental.mosaic.gpu.examples import matmul
|
||||
|
||||
|
||||
config.update("jax_traceback_filtering", "off")
|
||||
config.parse_flags_with_absl()
|
||||
os.environ["XLA_FLAGS"] = (
|
||||
os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0")
|
||||
|
||||
|
||||
@jtu.with_config(jax_traceback_filtering="off")
|
||||
class MatmulTestCase(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -30,10 +30,10 @@ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
|
||||
# pylint: disable=no-value-for-parameter
|
||||
|
||||
|
||||
config.update("jax_traceback_filtering", "off")
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_traceback_filtering="off")
|
||||
class DecodeAttentionTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -49,8 +49,6 @@ else:
|
||||
# TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
|
||||
# pylint: disable=no-value-for-parameter
|
||||
|
||||
|
||||
config.update("jax_traceback_filtering", "off")
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk",
|
||||
@ -121,6 +119,7 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False):
|
||||
return matmul_kernel(x, y)
|
||||
|
||||
|
||||
@jtu.with_config(jax_traceback_filtering="off")
|
||||
class PallasTest(jtu.JaxTestCase):
|
||||
INTERPRET = False
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user