[Pallas] Enable interpreter mode as default lowering for CPU

PiperOrigin-RevId: 580700740
Sharad Vikram 2023-11-08 16:35:03 -08:00 committed by jax authors
parent 21260a7a65
commit 8fbcfce2dd
2 changed files with 99 additions and 56 deletions


@@ -28,6 +28,7 @@ from jax._src import state
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import ad_util
from jax._src import core as jax_core
@@ -346,6 +347,27 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
  return name


def _pallas_call_default_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    interpret: bool,
    **params):
  platforms = ctx.module_context.platforms
  if len(platforms) > 1:
    raise ValueError("Can only lower pallas_call on a single platform.")
  platform = platforms[0]
  if platform != "cpu":
    raise ValueError(
        f"Cannot lower pallas_call on platform: {platform}. "
        "To use Pallas on GPU, please install Triton and JAX-Triton. "
        "To use Pallas on TPU, please install Jaxlib TPU and libtpu.")
  if not interpret:
    raise ValueError("Only interpret mode is supported on CPU backend.")
  impl = partial(_pallas_call_impl, **params, interpret=True)
  return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)
mlir.register_lowering(pallas_call_p, _pallas_call_default_lowering)


def pallas_call(
    f: Callable[..., None],
    out_shape: Any,
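
With this default lowering registered, an interpret-mode pallas_call can be traced, lowered, and executed on CPU without Triton. Below is a minimal usage sketch, not part of this commit; it assumes the public jax.experimental.pallas API and uses a made-up add_one_kernel:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Read the whole input block and write the incremented values back out.
  o_ref[...] = x_ref[...] + 1.0

@jax.jit
def add_one(x):
  # On CPU, interpret=True routes through the default lowering added above;
  # interpret=False raises "Only interpret mode is supported on CPU backend."
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,
  )(x)

print(add_one(jnp.arange(4, dtype=jnp.float32)))  # [1. 2. 3. 4.]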


@@ -125,6 +125,9 @@ class PallasTest(parameterized.TestCase):
INTERPRET = False
def setUp(self):
if jax.config.x64_enabled:
self.skipTest("Only works in 32-bit")
if not self.INTERPRET:
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Only works on GPU")
try:
@@ -139,6 +142,12 @@ class PallasTest(parameterized.TestCase):
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
def check_gpu_capability_at_least(self, capability,
device: int = 0):
if self.INTERPRET:
return True
return plgpu.get_compute_capability(device) >= capability
class PallasCallTest(PallasTest):
@@ -252,6 +261,7 @@ class PallasCallTest(PallasTest):
])
def test_reshape(self, in_shape, out_shape):
# TODO(sharadmv): re-enable when `reshape` works again
if not self.INTERPRET:
self.skipTest("Reshape not yet supported in Triton-MLIR")
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
@@ -303,13 +313,13 @@ class PallasCallTest(PallasTest):
if block_size_m <= m and block_size_n <= n and block_size_k <= k
])
def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Matmul only works on GPUs with capability >= sm70")
if (plgpu.get_compute_capability(0) <= 75
if not self.INTERPRET and (plgpu.get_compute_capability(0) <= 75
and (bm > 128 or bn > 128 or bk > 32)):
raise unittest.SkipTest("Block sizes too big for sm70.")
k1, k2 = random.split(random.PRNGKey(0))
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (m, k), dtype=dtype)
y = random.normal(k2, (k, n), dtype=dtype)
out, expected = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
@@ -330,14 +340,14 @@ class PallasCallTest(PallasTest):
if block_size_m <= m and block_size_n <= n and block_size_k <= k
])
def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Matmul only works on GPUs with capability >= sm70")
if (plgpu.get_compute_capability(0) <= 75
if not self.INTERPRET and (plgpu.get_compute_capability(0) <= 75
and (bm > 128 or bn > 128 or bk > 32)):
raise unittest.SkipTest("Block sizes too big for sm70.")
k1, k2 = random.split(random.PRNGKey(0))
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (m, k), dtype=dtype)
y = random.normal(k2, (k, n), dtype=dtype)
out, expected = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
@@ -350,7 +360,7 @@ class PallasCallTest(PallasTest):
for dtype in ["float32", "float16"]
))
def test_dot(self, size, dtype):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Matmul only works on GPUs with capability >= sm70")
@@ -363,7 +373,7 @@ class PallasCallTest(PallasTest):
y = y_ref[:, :]
o_ref[:, :] = pl.dot(x, y).astype(o_ref.dtype)
k1, k2 = random.split(random.PRNGKey(0))
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (size, size), dtype=dtype)
y = random.normal(k2, (size, size), dtype=dtype)
out, expected = dot(x, y), jnp.dot(x, y)
@@ -394,7 +404,7 @@ class PallasCallTest(PallasTest):
softmax_output = numerator / denominator
pl.store(o_ref, row_idxs, softmax_output, mask=mask)
key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, [batch_size, size], dtype=dtype)
np.testing.assert_allclose(softmax(x), jax.nn.softmax(x, axis=-1),
atol=1e-5, rtol=1e-5)
@@ -416,7 +426,7 @@ class PallasCallTest(PallasTest):
x = pl.load(x_ref, (idx,), mask=mask)
pl.store(o_ref, (idx,), x + 1., mask=mask)
key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (size,))
np.testing.assert_allclose(add_one(x), x + 1., atol=1e-5, rtol=1e-5)
@@ -431,7 +441,7 @@ class PallasCallTest(PallasTest):
x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]))
pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), x + 1.)
key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (m, n))
np.testing.assert_allclose(load(x), x + 1., atol=1e-5, rtol=1e-5)
@@ -449,8 +459,8 @@ class PallasCallTest(PallasTest):
y = pl.swap(y_ref, (slice(None),), x)
x_ref[:] = y
x = random.normal(random.PRNGKey(0), (m, n))
y = random.normal(random.PRNGKey(1), (m, n))
x = random.normal(random.key(0), (m, n))
y = random.normal(random.key(1), (m, n))
out = swap(x, y)
np.testing.assert_array_equal(out[0], y)
np.testing.assert_array_equal(out[1], x)
@@ -469,9 +479,9 @@ class PallasCallTest(PallasTest):
y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:])
x_ref[:] = y
x = random.normal(random.PRNGKey(0), (m, n))
y = random.normal(random.PRNGKey(1), (m, n))
mask = random.bernoulli(random.PRNGKey(2), shape=(m, n))
x = random.normal(random.key(0), (m, n))
y = random.normal(random.key(1), (m, n))
mask = random.bernoulli(random.key(2), shape=(m, n))
out = masked_swap(x, y, mask)
np.testing.assert_array_equal(out[0], jnp.where(mask, y, x))
np.testing.assert_array_equal(out[1], jnp.where(mask, x, y))
@@ -487,7 +497,7 @@ class PallasCallTest(PallasTest):
pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]),
jnp.ones_like(o_ref))
key = random.PRNGKey(0)
key = random.key(0)
x = random.normal(key, (m, n))
np.testing.assert_allclose(dummy(x), jnp.ones_like(x), atol=1e-5, rtol=1e-5)
@@ -505,7 +515,7 @@ class PallasCallTest(PallasTest):
grid = (8,)
size = 8
dtype = "float32"
k1 = random.PRNGKey(0)
k1 = random.key(0)
block_size = 1
x = random.normal(k1, [size], dtype=dtype)
kernel = functools.partial(add_inplace_kernel, block_size=block_size)
@@ -526,7 +536,7 @@ class PallasCallTest(PallasTest):
("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
])
def test_scalar_atomic(self, op, value, numpy_op):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Atomic ops onl works on GPUs with capability >= sm70")
@@ -559,7 +569,7 @@ class PallasCallTest(PallasTest):
@parameterized.parameters(*[(0,), (1,)])
def test_array_atomic_add(self, axis):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Atomic ops onl works on GPUs with capability >= sm70")
@@ -582,7 +592,7 @@ class PallasCallTest(PallasTest):
idx = (jnp.arange(m), i)
x = pl.load(x_ref, idx)
pl.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x)
x = random.normal(random.PRNGKey(0), (m, n))
x = random.normal(random.key(0), (m, n))
y = jnp.zeros(out_shape.shape, out_shape.dtype)
y = reduce(x, y)
y_ref = np.sum(x, axis=axis)
@@ -591,7 +601,7 @@ class PallasCallTest(PallasTest):
@parameterized.parameters(False, True)
def test_reduce_only_dim(self, use_store):
m = 32
x = random.normal(random.PRNGKey(0), (m,), dtype=jnp.float32)
x = random.normal(random.key(0), (m,), dtype=jnp.float32)
out_shape = jax.ShapeDtypeStruct((), x.dtype)
@functools.partial(
self.pallas_call,
@@ -634,7 +644,7 @@ class PallasCallTest(PallasTest):
else:
return random.normal(key, (m, n), dtype=dtype)
out_shape = jax.ShapeDtypeStruct(
op(make_x(random.PRNGKey(0)), axis=axis).shape, out_dtype)
op(make_x(random.key(0)), axis=axis).shape, out_dtype)
if isinstance(axis, int):
grid = tuple(a for i, a in enumerate((m, n)) if i != axis)
else:
@@ -647,7 +657,7 @@ class PallasCallTest(PallasTest):
x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None]))
y = op(x, axis=axis)
pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y)
for i, key in enumerate(random.split(random.PRNGKey(0), 20)):
for i, key in enumerate(random.split(random.key(0), 20)):
x = make_x(key)
y = reduce(x)
y_ref = op(x, axis=axis)
@@ -663,7 +673,7 @@ class PallasCallTest(PallasTest):
def slice_kernel(x_ref, y_ref):
x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4)))
pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x)
x = random.normal(random.PRNGKey(0), (m, n))
x = random.normal(random.key(0), (m, n))
y = slice_kernel(x)
y_ref = x[:4]
np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2)
@@ -734,8 +744,7 @@ class PallasCallTest(PallasTest):
])
def test_atomic_counter(self, num_threads):
if self.INTERPRET:
self.skipTest("While loop not supported in interpret mode yet.")
self.skipTest("While loop not supported in interpreter mode.")
@functools.partial(
self.pallas_call, out_shape=(
jax.ShapeDtypeStruct((), jnp.int32),
@@ -766,7 +775,7 @@ class PallasCallTest(PallasTest):
return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
m, n = 16, 32
x = random.normal(random.PRNGKey(0), (m, n))
x = random.normal(random.key(0), (m, n))
@functools.partial(self.pallas_call, out_shape=x, grid=1)
def softmax_kernel(x_ref, y_ref):
@@ -1199,7 +1208,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
x = x_ref[()]
o_ref[()] = impl(x)
k1, k2 = random.split(random.PRNGKey(0))
k1, k2 = random.split(random.key(0))
x = random.normal(k1)
t = random.normal(k2)
out_primal, out_tangent = jax.jvp(pallas_impl, (x,), (t,))
@@ -1221,7 +1230,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
x = x_ref[()]
o_ref[()] = jax.grad(impl)(x)
x = random.normal(random.PRNGKey(0))
x = random.normal(random.key(0))
out_grad = pallas_impl(x)
out_grad_ref = jax.grad(impl)(x)
np.testing.assert_allclose(out_grad, out_grad_ref, atol=1e-5, rtol=1e-5)
@@ -1237,7 +1246,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
o_ref[jnp.arange(2)] = jnp.zeros(2)
o_ref[2 + jnp.arange(2)] = impl(x)
k1, k2 = random.split(random.PRNGKey(0))
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (8,))
t = random.normal(k2, (8,))
out_primal, out_tangent = jax.jvp(pallas_impl, (x,), (t,))
@@ -1250,7 +1259,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
# TODO(sharadmv): enable this when we update Triton
# def test_jvp_matmul(self):
# k1, k2 = random.split(random.PRNGKey(0))
# k1, k2 = random.split(random.key(0))
# x = random.normal(k1, (256, 128))
# y = random.normal(k2, (128, 64))
# bm, bn, bk, gm = 64, 128, 32, 8
@@ -1400,7 +1409,7 @@ class PallasCallVmapTest(PallasTest):
add_one = jax.vmap(jax.vmap(add_one))
add_one_ref = lambda x: x + 1
x = random.randint(random.PRNGKey(0), (4, 65536, 2), 0, 10000)
x = random.randint(random.key(0), (4, 65536, 2), 0, 10000)
out = add_one(x)
out_ref = add_one_ref(x)
@@ -1418,7 +1427,7 @@ class PallasCallVmapTest(PallasTest):
add_one = jax.vmap(jax.vmap(jax.vmap(add_one)))
add_one_ref = lambda x: x + 1
x = random.randint(random.PRNGKey(0), (2, 2, 65536, 2), 0, 10000)
x = random.randint(random.key(0), (2, 2, 65536, 2), 0, 10000)
out = add_one(x)
out_ref = add_one_ref(x)
@@ -1438,8 +1447,8 @@ class PallasOpsTest(PallasTest):
def ne(x_ref, y_ref, o_ref):
o_ref[:] = x_ref[...] != y_ref[...]
x = jnp.ones(8)
y = jnp.arange(8)
x = jnp.ones(8, dtype=jnp.int32)
y = jnp.arange(8, dtype=jnp.int32)
not_equal = ne(x, y)
np.testing.assert_allclose(not_equal, x != y)
@@ -1561,11 +1570,11 @@ class FusedAttentionTest(PallasTest):
use_segment_ids,
kwargs,
):
if plgpu.get_compute_capability(0) < 80:
if not self.check_gpu_capability_at_least(80):
raise unittest.SkipTest(
"Fused attention only works on GPUs with capability >= sm80")
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(
k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
)
@@ -1638,10 +1647,10 @@ class FusedAttentionTest(PallasTest):
def test_fused_attention_bwd(
self, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
):
if plgpu.get_compute_capability(0) < 80:
if not self.check_gpu_capability_at_least(80):
raise unittest.SkipTest(
"Fused attention only works on GPUs with capability >= sm80")
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(
k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
)
@@ -1671,6 +1680,9 @@ class FusedAttentionTest(PallasTest):
np.testing.assert_allclose(dv, dv_ref, atol=0.05)
class FusedAttentionInterpreterTest(PallasTest):
INTERPRET = True
class FusedLayerNormTest(PallasTest):
@parameterized.parameters(*[
@@ -1678,10 +1690,10 @@ class FusedLayerNormTest(PallasTest):
(2, 384, 192),
])
def test_fused_layernorm_fwd(self, batch_size, seq_len, embed_dim):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Fused layernorm only works on GPUs with capability >= sm70")
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
k1, k2, k3 = random.split(random.key(0), 3)
x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1695,10 +1707,10 @@ class FusedLayerNormTest(PallasTest):
(2, 384, 192),
])
def test_fused_layernorm_bwd(self, batch_size, seq_len, embed_dim):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Fused layernorm only works on GPUs with capability >= sm70")
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
k1, k2, k3 = random.split(random.key(0), 3)
x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1716,6 +1728,10 @@ class FusedLayerNormTest(PallasTest):
np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2)
class FusedLayerNormInterpreterTest(PallasTest):
INTERPRET = True
class RmsNormTest(PallasTest):
@parameterized.parameters(*[
@@ -1723,10 +1739,10 @@ class RmsNormTest(PallasTest):
(2, 384, 192),
])
def test_rms_fwd(self, batch_size, seq_len, embed_dim):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Rms norm only works on GPUs with capability >= sm70")
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
k1, k2, k3 = random.split(random.key(0), 3)
x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1740,10 +1756,10 @@ class RmsNormTest(PallasTest):
(2, 384, 192),
])
def test_rms_norm_bwd(self, batch_size, seq_len, embed_dim):
if plgpu.get_compute_capability(0) < 70:
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest(
"Rms norm only works on GPUs with capability >= sm70")
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
k1, k2, k3 = random.split(random.key(0), 3)
x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1760,6 +1776,8 @@ class RmsNormTest(PallasTest):
np.testing.assert_allclose(dw, dw_ref, rtol=1e-2, atol=1e-2)
np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2)
class RmsNormInterpreterTest(PallasTest):
INTERPRET = True
class SoftmaxTest(PallasTest):
@@ -1773,7 +1791,7 @@ class SoftmaxTest(PallasTest):
if dtype == jnp.bfloat16:
raise absltest.SkipTest("Disabled due to Triton lowering bug")
x = jax.random.normal(random.PRNGKey(0), shape, dtype=dtype)
x = jax.random.normal(random.key(0), shape, dtype=dtype)
atol, rtol = {
jnp.bfloat16: (1e-2, 1e-4),
@@ -1789,5 +1807,8 @@ class SoftmaxTest(PallasTest):
)
class SoftmaxInterpreterTest(PallasTest):
INTERPRET = True
if __name__ == "__main__":
absltest.main()