diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 2496951ff..d1fd498a0 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -28,6 +28,7 @@ from jax._src import state
 from jax.interpreters import ad
 from jax.interpreters import batching
 from jax.interpreters import partial_eval as pe
+from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax._src import ad_util
 from jax._src import core as jax_core
@@ -346,6 +347,27 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
   return name
 
 
+def _pallas_call_default_lowering(
+    ctx: mlir.LoweringRuleContext,
+    *in_nodes,
+    interpret: bool,
+    **params):
+  platforms = ctx.module_context.platforms
+  if len(platforms) > 1:
+    raise ValueError("Can only lower pallas_call on a single platform.")
+  platform = platforms[0]
+  if platform != "cpu":
+    raise ValueError(
+        f"Cannot lower pallas_call on platform: {platform}. "
+        "To use Pallas on GPU, please install Triton and JAX-Triton. "
+        "To use Pallas on TPU, please install Jaxlib TPU and libtpu.")
+  if not interpret:
+    raise ValueError("Only interpret mode is supported on CPU backend.")
+  impl = partial(_pallas_call_impl, **params, interpret=True)
+  return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)
+mlir.register_lowering(pallas_call_p, _pallas_call_default_lowering)
+
+
 def pallas_call(
     f: Callable[..., None],
     out_shape: Any,
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index d2aee6dfd..7e4100a65 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -125,12 +125,15 @@ class PallasTest(parameterized.TestCase):
   INTERPRET = False
 
   def setUp(self):
-    if not jtu.test_device_matches(["gpu"]):
-      self.skipTest("Only works on GPU")
-    try:
-      import triton  # noqa: F401
-    except ImportError:
-      self.skipTest("Triton is not installed. Skipping PallasTest.")
+    if jax.config.x64_enabled:
+      self.skipTest("Only works in 32-bit")
+    if not self.INTERPRET:
+      if not jtu.test_device_matches(["gpu"]):
+        self.skipTest("Only works on GPU")
+      try:
+        import triton  # noqa: F401
+      except ImportError:
+        self.skipTest("Triton is not installed. Skipping PallasTest.")
     super().setUp()
     if compile_jaxpr:
       compile_jaxpr.cache_clear()
@@ -139,6 +142,12 @@ class PallasTest(parameterized.TestCase):
 
   def pallas_call(self, *args, **kwargs):
     return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
 
+  def check_gpu_capability_at_least(self, capability,
+                                    device: int = 0):
+    if self.INTERPRET:
+      return True
+    return plgpu.get_compute_capability(device) >= capability
+
 
 class PallasCallTest(PallasTest):
@@ -252,7 +261,8 @@ class PallasCallTest(PallasTest):
   ])
   def test_reshape(self, in_shape, out_shape):
     # TODO(sharadmv): re-enable when `reshape` works again
-    self.skipTest("Reshape not yet supported in Triton-MLIR")
+    if not self.INTERPRET:
+      self.skipTest("Reshape not yet supported in Triton-MLIR")
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), grid=1)
@@ -303,13 +313,13 @@ class PallasCallTest(PallasTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
          "Matmul only works on GPUs with capability >= sm70")
-    if (plgpu.get_compute_capability(0) <= 75
+    if not self.INTERPRET and (plgpu.get_compute_capability(0) <= 75
         and (bm > 128 or bn > 128 or bk > 32)):
       raise unittest.SkipTest("Block sizes too big for sm70.")
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
     out, expected = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
@@ -330,14 +340,14 @@ class PallasCallTest(PallasTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
          "Matmul only works on GPUs with capability >= sm70")
-    if (plgpu.get_compute_capability(0) <= 75
+    if not self.INTERPRET and (plgpu.get_compute_capability(0) <= 75
         and (bm > 128 or bn > 128 or bk > 32)):
       raise unittest.SkipTest("Block sizes too big for sm70.")
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
     out, expected = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
@@ -350,7 +360,7 @@ class PallasCallTest(PallasTest):
       for dtype in ["float32", "float16"]
   ))
   def test_dot(self, size, dtype):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
          "Matmul only works on GPUs with capability >= sm70")
 
@@ -363,7 +373,7 @@ class PallasCallTest(PallasTest):
       y = y_ref[:, :]
       o_ref[:, :] = pl.dot(x, y).astype(o_ref.dtype)
 
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (size, size), dtype=dtype)
     y = random.normal(k2, (size, size), dtype=dtype)
     out, expected = dot(x, y), jnp.dot(x, y)
@@ -394,7 +404,7 @@ class PallasCallTest(PallasTest):
       softmax_output = numerator / denominator
       pl.store(o_ref, row_idxs, softmax_output, mask=mask)
 
-    key = random.PRNGKey(0)
+    key = random.key(0)
    x = random.normal(key, [batch_size, size], dtype=dtype)
    np.testing.assert_allclose(softmax(x), jax.nn.softmax(x, axis=-1),
        atol=1e-5, rtol=1e-5)
@@ -416,7 +426,7 @@ class PallasCallTest(PallasTest):
       x = pl.load(x_ref, (idx,), mask=mask)
       pl.store(o_ref, (idx,), x + 1., mask=mask)
 
-    key = random.PRNGKey(0)
+    key = random.key(0)
     x = random.normal(key, (size,))
     np.testing.assert_allclose(add_one(x), x + 1., atol=1e-5, rtol=1e-5)
 
@@ -431,7 +441,7 @@ class PallasCallTest(PallasTest):
       x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]))
       pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), x + 1.)
 
-    key = random.PRNGKey(0)
+    key = random.key(0)
     x = random.normal(key, (m, n))
     np.testing.assert_allclose(load(x), x + 1., atol=1e-5, rtol=1e-5)
 
@@ -449,8 +459,8 @@ class PallasCallTest(PallasTest):
       y = pl.swap(y_ref, (slice(None),), x)
       x_ref[:] = y
 
-    x = random.normal(random.PRNGKey(0), (m, n))
-    y = random.normal(random.PRNGKey(1), (m, n))
+    x = random.normal(random.key(0), (m, n))
+    y = random.normal(random.key(1), (m, n))
     out = swap(x, y)
     np.testing.assert_array_equal(out[0], y)
     np.testing.assert_array_equal(out[1], x)
@@ -469,9 +479,9 @@ class PallasCallTest(PallasTest):
       y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:])
       x_ref[:] = y
 
-    x = random.normal(random.PRNGKey(0), (m, n))
-    y = random.normal(random.PRNGKey(1), (m, n))
-    mask = random.bernoulli(random.PRNGKey(2), shape=(m, n))
+    x = random.normal(random.key(0), (m, n))
+    y = random.normal(random.key(1), (m, n))
+    mask = random.bernoulli(random.key(2), shape=(m, n))
     out = masked_swap(x, y, mask)
     np.testing.assert_array_equal(out[0], jnp.where(mask, y, x))
     np.testing.assert_array_equal(out[1], jnp.where(mask, x, y))
@@ -487,7 +497,7 @@ class PallasCallTest(PallasTest):
       pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]),
                jnp.ones_like(o_ref))
 
-    key = random.PRNGKey(0)
+    key = random.key(0)
    x = random.normal(key, (m, n))
    np.testing.assert_allclose(dummy(x), jnp.ones_like(x),
        atol=1e-5, rtol=1e-5)
@@ -505,7 +515,7 @@ class PallasCallTest(PallasTest):
     grid = (8,)
     size = 8
     dtype = "float32"
-    k1 = random.PRNGKey(0)
+    k1 = random.key(0)
     block_size = 1
     x = random.normal(k1, [size], dtype=dtype)
     kernel = functools.partial(add_inplace_kernel, block_size=block_size)
@@ -526,7 +536,7 @@ class PallasCallTest(PallasTest):
      ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
   ])
   def test_scalar_atomic(self, op, value, numpy_op):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
          "Atomic ops onl works on GPUs with capability >= sm70")
 
@@ -559,7 +569,7 @@ class PallasCallTest(PallasTest):
 
   @parameterized.parameters(*[(0,), (1,)])
   def test_array_atomic_add(self, axis):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
          "Atomic ops onl works on GPUs with capability >= sm70")
 
@@ -582,7 +592,7 @@ class PallasCallTest(PallasTest):
       idx = (jnp.arange(m), i)
       x = pl.load(x_ref, idx)
       pl.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x)
-    x = random.normal(random.PRNGKey(0), (m, n))
+    x = random.normal(random.key(0), (m, n))
     y = jnp.zeros(out_shape.shape, out_shape.dtype)
     y = reduce(x, y)
     y_ref = np.sum(x, axis=axis)
@@ -591,7 +601,7 @@ class PallasCallTest(PallasTest):
   @parameterized.parameters(False, True)
   def test_reduce_only_dim(self, use_store):
     m = 32
-    x = random.normal(random.PRNGKey(0), (m,), dtype=jnp.float32)
+    x = random.normal(random.key(0), (m,), dtype=jnp.float32)
     out_shape = jax.ShapeDtypeStruct((), x.dtype)
     @functools.partial(
         self.pallas_call,
@@ -634,7 +644,7 @@ class PallasCallTest(PallasTest):
      else:
        return random.normal(key, (m, n), dtype=dtype)
    out_shape = jax.ShapeDtypeStruct(
-        op(make_x(random.PRNGKey(0)), axis=axis).shape, out_dtype)
+        op(make_x(random.key(0)), axis=axis).shape, out_dtype)
    if isinstance(axis, int):
      grid = tuple(a for i, a in enumerate((m, n)) if i != axis)
    else:
@@ -647,7 +657,7 @@ class PallasCallTest(PallasTest):
      x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None]))
      y = op(x, axis=axis)
      pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y)
-    for i, key in enumerate(random.split(random.PRNGKey(0), 20)):
+    for i, key in enumerate(random.split(random.key(0), 20)):
      x = make_x(key)
      y = reduce(x)
      y_ref = op(x, axis=axis)
@@ -663,7 +673,7 @@ class PallasCallTest(PallasTest):
    def slice_kernel(x_ref, y_ref):
      x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4)))
      pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x)
-    x = random.normal(random.PRNGKey(0), (m, n))
+    x = random.normal(random.key(0), (m, n))
    y = slice_kernel(x)
    y_ref = x[:4]
    np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2)
@@ -734,8 +744,7 @@ class PallasCallTest(PallasTest):
   ])
   def test_atomic_counter(self, num_threads):
     if self.INTERPRET:
-      self.skipTest("While loop not supported in interpret mode yet.")
-
+      self.skipTest("While loop not supported in interpreter mode.")
     @functools.partial(
         self.pallas_call, out_shape=(
           jax.ShapeDtypeStruct((), jnp.int32),
@@ -766,7 +775,7 @@ class PallasCallTest(PallasTest):
      return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
 
    m, n = 16, 32
-    x = random.normal(random.PRNGKey(0), (m, n))
+    x = random.normal(random.key(0), (m, n))
 
    @functools.partial(self.pallas_call, out_shape=x, grid=1)
    def softmax_kernel(x_ref, y_ref):
@@ -1199,7 +1208,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
      x = x_ref[()]
      o_ref[()] = impl(x)
 
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
    x = random.normal(k1)
    t = random.normal(k2)
    out_primal, out_tangent = jax.jvp(pallas_impl, (x,), (t,))
@@ -1221,7 +1230,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
      x = x_ref[()]
      o_ref[()] = jax.grad(impl)(x)
 
-    x = random.normal(random.PRNGKey(0))
+    x = random.normal(random.key(0))
    out_grad = pallas_impl(x)
    out_grad_ref = jax.grad(impl)(x)
    np.testing.assert_allclose(out_grad, out_grad_ref, atol=1e-5, rtol=1e-5)
@@ -1237,7 +1246,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
      o_ref[jnp.arange(2)] = jnp.zeros(2)
      o_ref[2 + jnp.arange(2)] = impl(x)
 
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
    x = random.normal(k1, (8,))
    t = random.normal(k2, (8,))
    out_primal, out_tangent = jax.jvp(pallas_impl, (x,), (t,))
@@ -1250,7 +1259,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
 
  # TODO(sharadmv): enable this when we update Triton
  # def test_jvp_matmul(self):
-  #   k1, k2 = random.split(random.PRNGKey(0))
+  #   k1, k2 = random.split(random.key(0))
  #   x = random.normal(k1, (256, 128))
  #   y = random.normal(k2, (128, 64))
  #   bm, bn, bk, gm = 64, 128, 32, 8
@@ -1400,7 +1409,7 @@ class PallasCallVmapTest(PallasTest):
    add_one = jax.vmap(jax.vmap(add_one))
    add_one_ref = lambda x: x + 1
 
-    x = random.randint(random.PRNGKey(0), (4, 65536, 2), 0, 10000)
+    x = random.randint(random.key(0), (4, 65536, 2), 0, 10000)
 
    out = add_one(x)
    out_ref = add_one_ref(x)
@@ -1418,7 +1427,7 @@ class PallasCallVmapTest(PallasTest):
    add_one = jax.vmap(jax.vmap(jax.vmap(add_one)))
    add_one_ref = lambda x: x + 1
 
-    x = random.randint(random.PRNGKey(0), (2, 2, 65536, 2), 0, 10000)
+    x = random.randint(random.key(0), (2, 2, 65536, 2), 0, 10000)
 
    out = add_one(x)
    out_ref = add_one_ref(x)
@@ -1438,8 +1447,8 @@ class PallasOpsTest(PallasTest):
    def ne(x_ref, y_ref, o_ref):
      o_ref[:] = x_ref[...] != y_ref[...]
 
-    x = jnp.ones(8)
-    y = jnp.arange(8)
+    x = jnp.ones(8, dtype=jnp.int32)
+    y = jnp.arange(8, dtype=jnp.int32)
    not_equal = ne(x, y)
    np.testing.assert_allclose(not_equal, x != y)
@@ -1561,11 +1570,11 @@ class FusedAttentionTest(PallasTest):
      use_segment_ids,
      kwargs,
  ):
-    if plgpu.get_compute_capability(0) < 80:
+    if not self.check_gpu_capability_at_least(80):
      raise unittest.SkipTest(
          "Fused attention only works on GPUs with capability >= sm80")
 
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
    q = random.normal(
        k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
    )
@@ -1638,10 +1647,10 @@ class FusedAttentionTest(PallasTest):
  def test_fused_attention_bwd(
      self, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
  ):
-    if plgpu.get_compute_capability(0) < 80:
+    if not self.check_gpu_capability_at_least(80):
      raise unittest.SkipTest(
          "Fused attention only works on GPUs with capability >= sm80")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
    q = random.normal(
        k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
    )
@@ -1671,6 +1680,9 @@ class FusedAttentionTest(PallasTest):
    np.testing.assert_allclose(dv, dv_ref, atol=0.05)
 
 
+class FusedAttentionInterpreterTest(PallasTest):
+  INTERPRET = True
+
 class FusedLayerNormTest(PallasTest):
 
  @parameterized.parameters(*[
@@ -1678,10 +1690,10 @@ class FusedLayerNormTest(PallasTest):
      (2, 384, 192),
  ])
  def test_fused_layernorm_fwd(self, batch_size, seq_len, embed_dim):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
      raise unittest.SkipTest(
          "Fused layernorm only works on GPUs with capability >= sm70")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
    x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
    w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
    b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1695,10 +1707,10 @@ class FusedLayerNormTest(PallasTest):
      (2, 384, 192),
  ])
  def test_fused_layernorm_bwd(self, batch_size, seq_len, embed_dim):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
      raise unittest.SkipTest(
          "Fused layernorm only works on GPUs with capability >= sm70")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
    x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
    w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
    b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1716,6 +1728,10 @@ class FusedLayerNormTest(PallasTest):
    np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2)
 
 
+class FusedLayerNormInterpreterTest(PallasTest):
+  INTERPRET = True
+
+
 class RmsNormTest(PallasTest):
 
  @parameterized.parameters(*[
@@ -1723,10 +1739,10 @@ class RmsNormTest(PallasTest):
      (2, 384, 192),
  ])
  def test_rms_fwd(self, batch_size, seq_len, embed_dim):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
      raise unittest.SkipTest(
          "Rms norm only works on GPUs with capability >= sm70")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
    x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
    w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
    b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1740,10 +1756,10 @@ class RmsNormTest(PallasTest):
      (2, 384, 192),
  ])
  def test_rms_norm_bwd(self, batch_size, seq_len, embed_dim):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
      raise unittest.SkipTest(
          "Rms norm only works on GPUs with capability >= sm70")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
    x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
    w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
    b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1760,6 +1776,8 @@ class RmsNormTest(PallasTest):
    np.testing.assert_allclose(dw, dw_ref, rtol=1e-2, atol=1e-2)
    np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2)
 
+class RmsNormInterpreterTest(PallasTest):
+  INTERPRET = True
 
 class SoftmaxTest(PallasTest):
 
@@ -1773,7 +1791,7 @@ class SoftmaxTest(PallasTest):
    if dtype == jnp.bfloat16:
      raise absltest.SkipTest("Disabled due to Triton lowering bug")
 
-    x = jax.random.normal(random.PRNGKey(0), shape, dtype=dtype)
+    x = jax.random.normal(random.key(0), shape, dtype=dtype)
 
    atol, rtol = {
        jnp.bfloat16: (1e-2, 1e-4),
@@ -1789,5 +1807,8 @@ class SoftmaxTest(PallasTest):
    )
 
 
+class SoftmaxInterpreterTest(PallasTest):
+  INTERPRET = True
+
 if __name__ == "__main__":
  absltest.main()
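With the default lowering rule above registered, a pallas_call should lower on CPU via the interpreter whenever interpret=True is passed. The snippet below is an illustrative sketch of that usage, not code from this patch; the add_one_kernel name and shapes are hypothetical.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_one_kernel(x_ref, o_ref):
      # Read the whole block, add one, and write the result to the output ref.
      o_ref[...] = x_ref[...] + 1.0

    @jax.jit
    def add_one(x):
      return pl.pallas_call(
          add_one_kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          # interpret=True is required on CPU: the new lowering rule only
          # supports interpret mode there and raises otherwise.
          interpret=True,
      )(x)

    x = jnp.arange(8, dtype=jnp.float32)
    print(add_one(x))  # [1. 2. 3. 4. 5. 6. 7. 8.]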