From b7715e279dd96938452a3817564d8671bf681543 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 22 Jul 2024 11:20:15 +0000 Subject: [PATCH] Another take at enabling Pallas GPU tests on x64 Note that for_loop_p no longer assumes that the loop index is an int32. Closes #18847 --- jax/_src/core.py | 18 +++--- jax/_src/lax/control_flow/for_loop.py | 8 +-- jax/_src/pallas/pallas_call.py | 4 +- jax/_src/pallas/primitives.py | 8 ++- jax/_src/pallas/triton/BUILD | 1 + jax/_src/pallas/triton/lowering.py | 8 ++- jax/_src/state/primitives.py | 2 +- jax/_src/state/types.py | 6 ++ tests/pallas/BUILD | 2 - tests/pallas/pallas_test.py | 93 +++++++++------------------ 10 files changed, 65 insertions(+), 85 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index e94ac83e5..effe49d3e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2076,14 +2076,16 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None): if handler: return handler(aval, weak_type) raise TypeError(type(aval)) -raise_to_shaped_mappings : dict[type, Callable] = { - AbstractToken: lambda aval, _: aval, - Bot: lambda aval, _: aval, - UnshapedArray: lambda aval, _: aval, - ShapedArray: lambda aval, weak_type: ShapedArray( - aval.shape, aval.dtype, weak_type, aval.named_shape), - DConcreteArray: lambda aval, weak_type: DShapedArray( - aval.shape, aval.dtype, weak_type), +raise_to_shaped_mappings: dict[type, Callable] = { + AbstractToken: lambda aval, _: aval, + Bot: lambda aval, _: aval, + UnshapedArray: lambda aval, _: aval, + ShapedArray: lambda aval, weak_type: ShapedArray( + aval.shape, aval.dtype, weak_type, aval.named_shape + ), + DConcreteArray: lambda aval, weak_type: DShapedArray( + aval.shape, aval.dtype, weak_type + ), } ### Operations on shapes and dimension sizes. diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index eca1ec289..15249e531 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -132,7 +132,7 @@ def for_loop(nsteps: int | Sequence[int], nsteps, = nsteps flat_state, state_tree = tree_flatten(init_state) state_avals = map(state_utils.val_to_ref_aval, flat_state) - idx_aval = core.ShapedArray((), jnp.dtype("int32")) + idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64)) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( body, state_tree, [idx_aval, *state_avals]) if out_tree != tree_structure(None): @@ -251,7 +251,7 @@ def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll): def _for_impl_unrolled(body, nsteps, unroll, *args): remainder = nsteps % unroll - i = jnp.int32(0) + i = jnp.astype(0, dtypes.canonicalize_dtype(jnp.int64)) state = list(args) for _ in range(remainder): @@ -748,7 +748,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): """ flat_state, state_tree = tree_flatten(init_state) state_avals = map(state_utils.val_to_ref_aval, flat_state) - idx_aval = core.ShapedArray((), jnp.dtype("int32")) + idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64)) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( body, state_tree, [idx_aval, *state_avals]) if out_tree != tree_structure(None): @@ -756,7 +756,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) def fori_body(i, carry): - i = jnp.int32(i) + i = jnp.astype(i, dtypes.canonicalize_dtype(jnp.int64)) if reverse: i = nsteps - i - 1 out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d185133c1..4050254f3 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -274,7 +274,7 @@ def _pallas_call_impl_interpret( len(blocks), len(scratch_values), ) - blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, + blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch) blocks = blocks[grid_mapping.num_index_operands:] blocks, out_scratch = split_list(blocks, [num_inout]) @@ -787,7 +787,7 @@ def pallas_call_checkify_rule(error: checkify.Error, # errors before other arguments. jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch] assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args) - result_flat = jax.core.eval_jaxpr( + result_flat = jax_core.eval_jaxpr( checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args) output_errors, _ = split_list(result_flat, [num_err_vals]) # Store new errors back in the error refs. diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index ce87f2bc0..706d1822f 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -26,6 +26,7 @@ from jax import lax from jax import tree_util from jax._src import ad_util from jax._src import core as jax_core +from jax._src import dtypes from jax._src import effects from jax._src import pretty_printer as pp from jax._src import state @@ -359,7 +360,12 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): # of bounds, it will instead move the start_index backwards so the slice # will fit in memory. ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes) - out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes) + idx_dtype = dtypes.canonicalize_dtype(jnp.int64) + out_ones = lax.dynamic_slice( + ref, + [jnp.astype(s, idx_dtype) for s in slice_starts], + slice_sizes=slice_sizes, + ) out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims) out = out_ones[out_indexer] elif all(not isinstance(s, Slice) for s in idx.indices): diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 370cbb713..01d248098 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -52,6 +52,7 @@ pytype_strict_library( "//jax", "//jax:ad_util", "//jax:api_util", + "//jax:config", "//jax:core", "//jax:mlir", "//jax:partial_eval", diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 5653dcf6a..3602553f9 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -29,6 +29,7 @@ from jax import tree_util from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api_util +from jax._src import config from jax._src import core as jax_core from jax._src import custom_derivatives from jax._src import linear_util as lu @@ -2263,9 +2264,10 @@ def _for_lowering_rule( del which_linear if reverse or unroll != 1: raise NotImplementedError - lower_bound = _i32_constant(0) - upper_bound = _i32_constant(nsteps) - step = _i32_constant(1) + _i_constant = _i64_constant if config.enable_x64.value else _i32_constant + lower_bound = _i_constant(0) + upper_bound = _i_constant(nsteps) + step = _i_constant(1) init_args = map(_ensure_ir_value, args, ctx.avals_in) # Partially discharge state from jaxpr for non-pointers should_discharge = [ diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 224c2f351..d3f17df4e 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -190,7 +190,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef, f"Expected shape: {expected_out_shape}. " f"Value shape: {val_aval.shape}. " f"Indices: {indexers}. ") - if ref_aval.dtype != val_aval.dtype: + if ref_aval.dtype != val_aval.dtype and not val_aval.weak_type: raise ValueError("Invalid dtype for `swap`. " f"Ref dtype: {ref_aval.dtype}. " f"Value dtype: {val_aval.dtype}. ") diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 303e4da0b..9e45d5df3 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -132,6 +132,12 @@ class AbstractRef(core.AbstractValue): def __init__(self, inner_aval: core.AbstractValue): self.inner_aval = inner_aval + @property + def weak_type(self) -> bool: + if not hasattr(self.inner_aval, "weak_type"): + raise AttributeError + return self.inner_aval.weak_type + def update(self, inner_aval=None): if inner_aval is None: return AbstractRef(self.inner_aval) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 66d33aa0f..8b6c4735c 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -41,8 +41,6 @@ jax_test( disable_configs = [ "gpu", "gpu_x32", - "gpu_a100", - "gpu_h100", "gpu_p100", "gpu_p100_x32", ], diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 10f3c9ee7..b54158272 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -29,6 +29,7 @@ from jax import lax from jax import random from jax._src import checkify from jax._src import config +from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax.control_flow.for_loop import for_loop from jax._src.pallas.pallas_call import _trace_to_jaxpr @@ -55,6 +56,10 @@ def smem_on_tpu(): return None +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + @functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk", "interpret", "debug"]) def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False): @@ -65,7 +70,7 @@ def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False): debug=debug, grid=pl.cdiv(m, bm) * pl.cdiv(n, bn)) def matmul_kernel(x_ref, y_ref, o_ref): - pid = pl.program_id(axis=0) + pid = pl.program_id(axis=0).astype(intx) num_pid_m = m // bm num_pid_n = n // bn num_pid_in_group = gm * num_pid_n @@ -133,8 +138,6 @@ class PallasBaseTest(jtu.JaxTestCase): def setUp(self): if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: self.skipTest("On CPU the test works only in interpret mode") - if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled: - self.skipTest("On GPU the test works only in 32-bit") if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPU with capability >= sm80") @@ -151,13 +154,10 @@ class PallasBaseTest(jtu.JaxTestCase): class PallasCallTest(PallasBaseTest): def test_add_one(self): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32)) + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx)) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1. @@ -177,14 +177,11 @@ class PallasCallTest(PallasBaseTest): np.testing.assert_allclose(add_one(x), jnp.array([1.], jnp.float32)) def test_add_vector_block_spec(self): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8,), jnp.int32), + out_shape=jax.ShapeDtypeStruct((8,), intx), in_specs=[pl.BlockSpec((1,), lambda i: i)], out_specs=pl.BlockSpec((1,), lambda i: i), grid=8, @@ -195,14 +192,11 @@ class PallasCallTest(PallasBaseTest): np.testing.assert_allclose(add_one(jnp.arange(8)), jnp.arange(8) + 1) def test_add_matrix_block_spec(self): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8, 8), jnp.int32), + out_shape=jax.ShapeDtypeStruct((8, 8), intx), in_specs=[pl.BlockSpec((2, 2), lambda i, j: (i, j))], out_specs=pl.BlockSpec((2, 2), lambda i, j: (i, j)), grid=(4, 4), @@ -225,13 +219,10 @@ class PallasCallTest(PallasBaseTest): self.assertTrue(jnp.all(logical_and(x))) def test_vector_indexing(self): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx), grid=1) def index(x_ref, i_ref, o_ref): o_ref[()] = x_ref[i_ref[()]] @@ -485,13 +476,10 @@ class PallasCallTest(PallasBaseTest): self.assertAllClose(res[0:1], to_store) def test_vector_slicing(self): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), floatx), grid=1) def index(x_ref, idx_ref, o_ref): idx = idx_ref[()] @@ -517,9 +505,6 @@ class PallasCallTest(PallasBaseTest): if block_size_m <= m and block_size_n <= n and block_size_k <= k ]) def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: all sort of assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") k1, k2 = random.split(random.key(0)) @@ -543,9 +528,6 @@ class PallasCallTest(PallasBaseTest): if block_size_m <= m and block_size_n <= n and block_size_k <= k ]) def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: all sort of assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") k1, k2 = random.split(random.key(0)) @@ -605,15 +587,12 @@ class PallasCallTest(PallasBaseTest): np.testing.assert_allclose(dummy(x), jnp.ones_like(x), atol=1e-5, rtol=1e-5) def test_with_input_output_aliasing(self): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") def add_inplace_kernel(_, o_ref, *, block_size): pid = pl.program_id(axis=0) # we use a 1d launch grid so axis is 0 block_start = pid * block_size - offsets = block_start + jnp.arange(block_size) + offsets = block_start + jnp.arange(block_size, dtype=jnp.int32) mask = offsets < o_ref.shape[0] x = pl.load(o_ref, (offsets,), mask=mask) output = x + 1 @@ -634,13 +613,10 @@ class PallasCallTest(PallasBaseTest): np.testing.assert_allclose(out, expected) def test_using_pallas_slice(self): - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") m, n = 32, 4 - out_shape = jax.ShapeDtypeStruct((4, n), jnp.float32) + out_shape = jax.ShapeDtypeStruct((4, n), floatx) @functools.partial( self.pallas_call, out_shape=out_shape, @@ -996,7 +972,7 @@ class PallasControlFlowTest(PallasBaseTest): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), jnp.float32), + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), floatx), in_specs=[ pl.BlockSpec((), lambda _, __: ()), pl.BlockSpec((bx, 1), lambda i, _: (i, 0)), @@ -1073,15 +1049,13 @@ class PallasControlFlowTest(PallasBaseTest): # dx, jnp.float32([0., 2, 4, 6, 0, 10, 12 + 12, 14])) def test_scan_cond_vm_explicit_ref_arg(self): - if jtu.test_device_matches(["cpu"]): - # TODO: fix this - self.skipTest("Fails on CPU: assertion error") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("TODO: error on TPU") program = jnp.int32([0, 1, 2, 3, 2]) - params = jnp.arange(len(program) * 3.).reshape(len(program), 3) - x = jnp.arange(7.) + params = jnp.arange(len(program) * 3., dtype=jnp.float32) + params = params.reshape(len(program), 3) + x = jnp.arange(7., dtype=jnp.float32) bx = 4 @jax.jit @@ -1113,7 +1087,7 @@ class PallasControlFlowTest(PallasBaseTest): return state, program_ref, params_ref out_ref[...] = jax.lax.fori_loop( 0, len(program), body_fn, - (jnp.zeros(x.shape), program_ref, params_ref))[0] + (jnp.zeros(x.shape, dtype=jnp.float32), program_ref, params_ref))[0] expected = (x * params[0, 0] + 2 * x * params[1, 1] + @@ -1127,16 +1101,14 @@ class PallasControlFlowTest(PallasBaseTest): params, x) def test_scan_cond_vm_closing_over_ref(self): - if jtu.test_device_matches(["cpu"]): - # TODO: fix this - self.skipTest("Fails on CPU: assertion error") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("TODO: error on TPU") # ** Difference is the closure over params_ref in the switch branches. ** program = jnp.int32([0, 1, 2, 3, 2, -1]) - params = jnp.arange(len(program) * 3.).reshape(len(program), 3) - x = jnp.arange(7.) + params = jnp.arange(len(program) * 3., dtype=jnp.float32) + params = params.reshape(len(program), 3) + x = jnp.arange(7., dtype=jnp.float32) bx = 4 @jax.jit @@ -1169,7 +1141,7 @@ class PallasControlFlowTest(PallasBaseTest): return state, program_ref, params_ref out_ref[...] = jax.lax.fori_loop( 0, len(program), body_fn, - (jnp.zeros(x.shape), program_ref, params_ref))[0] + (jnp.zeros(x.shape, dtype=jnp.float32), program_ref, params_ref))[0] expected = (x * params[0, 0] + 2 * x * params[1, 1] + @@ -1375,7 +1347,7 @@ class PallasControlFlowTest(PallasBaseTest): kernel, grid=(1,), out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32), + out_shape=jax.ShapeDtypeStruct([1, 1], intx), in_specs=[ pl.BlockSpec( (1, 8, 128), @@ -1439,7 +1411,7 @@ class PallasControlFlowTest(PallasBaseTest): kernel, grid=(4,), out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32), + out_shape=jax.ShapeDtypeStruct([1, 1], intx), in_specs=[ pl.BlockSpec( (1, 8, 128), @@ -1634,7 +1606,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest): grad_tol = 1e-1 @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx), grid=1) def pallas_impl(x_ref, o_ref): x = x_ref[()] @@ -1656,7 +1628,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest): def test_pallas_around_grad(self, impl): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((), jnp.float32), + out_shape=jax.ShapeDtypeStruct((), floatx), name=self.id().split(".")[-1], grid=1) def pallas_impl(x_ref, o_ref): @@ -1675,7 +1647,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest): grad_tol = 1e-1 @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), floatx), grid=1) def pallas_impl(x_ref, o_ref): x = x_ref[jnp.arange(2)] @@ -1730,16 +1702,9 @@ class PallasCallAutodifferentiationTest(PallasBaseTest): class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest): INTERPRET = True - def setUp(self): - super().setUp() - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") - class PallasOutOfBoundsInterpreterTest(PallasBaseTest): - - INTERPRET: bool = True + INTERPRET = True def test_interpret_mode_out_of_bounds_access(self): block_size = 32 @@ -1818,7 +1783,7 @@ class PallasOutOfBoundsInterpreterTest(PallasBaseTest): class PallasCheckifyInterpreterTest(PallasBaseTest): # TODO(b/346651778): Support non-interpret mode checkify. - INTERPRET: bool = True + INTERPRET = True def test_no_checkify(self,): def kernel(y_ref):