Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)
[Pallas] Enable interpreter mode as default lowering for CPU
PiperOrigin-RevId: 580700740
This commit is contained in:
parent 21260a7a65
commit 8fbcfce2dd
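In short: after this change, a pallas_call that sets interpret=True lowers on CPU through the Pallas interpreter instead of failing for lack of a lowering rule, and every other CPU combination produces a descriptive error. A minimal sketch of what this enables (the kernel and shapes are illustrative, not taken from this commit; assumes a recent JAX with jax.experimental.pallas available):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Inside the kernel, refs are read and written with NumPy-style indexing.
  o_ref[...] = x_ref[...] + 1.0

@jax.jit
def add_one(x):
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,  # required on CPU at this commit; see the lowering rule below
  )(x)

print(add_one(jnp.arange(8, dtype=jnp.float32)))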
@@ -28,6 +28,7 @@ from jax._src import state
 from jax.interpreters import ad
 from jax.interpreters import batching
 from jax.interpreters import partial_eval as pe
+from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax._src import ad_util
 from jax._src import core as jax_core
@@ -346,6 +347,27 @@ def _extract_function_name(f: Callable, name: str | None) -> str:
   return name
 
 
+def _pallas_call_default_lowering(
+    ctx: mlir.LoweringRuleContext,
+    *in_nodes,
+    interpret: bool,
+    **params):
+  platforms = ctx.module_context.platforms
+  if len(platforms) > 1:
+    raise ValueError("Can only lower pallas_call on a single platform.")
+  platform = platforms[0]
+  if platform != "cpu":
+    raise ValueError(
+        f"Cannot lower pallas_call on platform: {platform}. "
+        "To use Pallas on GPU, please install Triton and JAX-Triton. "
+        "To use Pallas on TPU, please install Jaxlib TPU and libtpu.")
+  if not interpret:
+    raise ValueError("Only interpret mode is supported on CPU backend.")
+  impl = partial(_pallas_call_impl, **params, interpret=True)
+  return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)
+mlir.register_lowering(pallas_call_p, _pallas_call_default_lowering)
+
+
 def pallas_call(
     f: Callable[..., None],
     out_shape: Any,
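The new rule is registered as the default lowering for pallas_call_p; it only succeeds on the CPU backend with interpret=True, and otherwise raises one of the ValueErrors shown above rather than an obscure missing-lowering error. A hedged sketch of the failure path on a CPU-only installation (kernel is illustrative; the exact exception wrapping may vary across JAX versions):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

x = jnp.ones((8,), jnp.float32)
try:
  # interpret defaults to False, so lowering on CPU hits the new error branch.
  pl.pallas_call(copy_kernel,
                 out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
except Exception as e:
  print(e)  # expected message: "Only interpret mode is supported on CPU backend."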
@@ -125,6 +125,9 @@ class PallasTest(parameterized.TestCase):
   INTERPRET = False
 
   def setUp(self):
     if jax.config.x64_enabled:
       self.skipTest("Only works in 32-bit")
+    if not self.INTERPRET:
       if not jtu.test_device_matches(["gpu"]):
         self.skipTest("Only works on GPU")
       try:
@@ -139,6 +142,12 @@ class PallasTest(parameterized.TestCase):
   def pallas_call(self, *args, **kwargs):
     return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
 
+  def check_gpu_capability_at_least(self, capability,
+                                    device: int = 0):
+    if self.INTERPRET:
+      return True
+    return plgpu.get_compute_capability(device) >= capability
+
 
 class PallasCallTest(PallasTest):
 
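check_gpu_capability_at_least lets every capability-gated test run under the interpreter: in interpreter mode it simply reports success, otherwise it queries plgpu as before. A small self-contained sketch of the same pattern outside the Pallas test suite (the class, helper, and capability values here are illustrative stand-ins, not part of this commit):

import unittest

class FakeKernelTest(unittest.TestCase):
  # Hypothetical stand-in for PallasTest; interpreter-mode subclasses flip this flag.
  INTERPRET = True

  def check_gpu_capability_at_least(self, capability, device=0):
    if self.INTERPRET:
      return True  # interpreter mode needs no GPU at all
    return self._compute_capability(device) >= capability

  def _compute_capability(self, device):
    return 0  # placeholder; the real tests call plgpu.get_compute_capability

  def test_needs_sm70(self):
    if not self.check_gpu_capability_at_least(70):
      raise unittest.SkipTest("requires a GPU with capability >= sm70")
    self.assertEqual(1 + 1, 2)  # stand-in for a real kernel check

if __name__ == "__main__":
  unittest.main()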
@@ -252,6 +261,7 @@ class PallasCallTest(PallasTest):
   ])
   def test_reshape(self, in_shape, out_shape):
     # TODO(sharadmv): re-enable when `reshape` works again
+    if not self.INTERPRET:
       self.skipTest("Reshape not yet supported in Triton-MLIR")
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
@@ -303,13 +313,13 @@ class PallasCallTest(PallasTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Matmul only works on GPUs with capability >= sm70")
-    if (plgpu.get_compute_capability(0) <= 75
+    if not self.INTERPRET and (plgpu.get_compute_capability(0) <= 75
         and (bm > 128 or bn > 128 or bk > 32)):
       raise unittest.SkipTest("Block sizes too big for sm70.")
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
     out, expected = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
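The test changes also migrate random.PRNGKey(...) to random.key(...), i.e. from legacy uint32 key arrays to JAX's new-style typed PRNG keys; both forms feed random.split and the samplers the same way. A small sketch of the difference (shapes chosen arbitrarily):

import jax
from jax import random

legacy = random.PRNGKey(0)   # legacy key: uint32 array of shape (2,)
typed = random.key(0)        # new-style typed key: scalar array with a key dtype

k1, k2 = random.split(typed)            # splitting works identically
x = random.normal(k1, (4, 8))           # samplers accept either key flavor
y = random.normal(k2, (8, 2))
print(legacy.shape, typed.shape, x.shape, y.shape)  # (2,) () (4, 8) (8, 2)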
@@ -330,14 +340,14 @@ class PallasCallTest(PallasTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Matmul only works on GPUs with capability >= sm70")
-    if (plgpu.get_compute_capability(0) <= 75
+    if not self.INTERPRET and (plgpu.get_compute_capability(0) <= 75
         and (bm > 128 or bn > 128 or bk > 32)):
       raise unittest.SkipTest("Block sizes too big for sm70.")
 
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
     out, expected = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
@@ -350,7 +360,7 @@ class PallasCallTest(PallasTest):
       for dtype in ["float32", "float16"]
   ))
   def test_dot(self, size, dtype):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Matmul only works on GPUs with capability >= sm70")
 
@@ -363,7 +373,7 @@ class PallasCallTest(PallasTest):
       y = y_ref[:, :]
       o_ref[:, :] = pl.dot(x, y).astype(o_ref.dtype)
 
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (size, size), dtype=dtype)
     y = random.normal(k2, (size, size), dtype=dtype)
     out, expected = dot(x, y), jnp.dot(x, y)
@@ -394,7 +404,7 @@ class PallasCallTest(PallasTest):
       softmax_output = numerator / denominator
       pl.store(o_ref, row_idxs, softmax_output, mask=mask)
 
-    key = random.PRNGKey(0)
+    key = random.key(0)
     x = random.normal(key, [batch_size, size], dtype=dtype)
     np.testing.assert_allclose(softmax(x), jax.nn.softmax(x, axis=-1),
                                atol=1e-5, rtol=1e-5)
@@ -416,7 +426,7 @@ class PallasCallTest(PallasTest):
       x = pl.load(x_ref, (idx,), mask=mask)
       pl.store(o_ref, (idx,), x + 1., mask=mask)
 
-    key = random.PRNGKey(0)
+    key = random.key(0)
     x = random.normal(key, (size,))
     np.testing.assert_allclose(add_one(x), x + 1., atol=1e-5, rtol=1e-5)
 
@@ -431,7 +441,7 @@ class PallasCallTest(PallasTest):
       x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]))
       pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), x + 1.)
 
-    key = random.PRNGKey(0)
+    key = random.key(0)
     x = random.normal(key, (m, n))
     np.testing.assert_allclose(load(x), x + 1., atol=1e-5, rtol=1e-5)
 
@@ -449,8 +459,8 @@ class PallasCallTest(PallasTest):
       y = pl.swap(y_ref, (slice(None),), x)
       x_ref[:] = y
 
-    x = random.normal(random.PRNGKey(0), (m, n))
-    y = random.normal(random.PRNGKey(1), (m, n))
+    x = random.normal(random.key(0), (m, n))
+    y = random.normal(random.key(1), (m, n))
     out = swap(x, y)
     np.testing.assert_array_equal(out[0], y)
     np.testing.assert_array_equal(out[1], x)
@@ -469,9 +479,9 @@ class PallasCallTest(PallasTest):
       y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:])
       x_ref[:] = y
 
-    x = random.normal(random.PRNGKey(0), (m, n))
-    y = random.normal(random.PRNGKey(1), (m, n))
-    mask = random.bernoulli(random.PRNGKey(2), shape=(m, n))
+    x = random.normal(random.key(0), (m, n))
+    y = random.normal(random.key(1), (m, n))
+    mask = random.bernoulli(random.key(2), shape=(m, n))
     out = masked_swap(x, y, mask)
     np.testing.assert_array_equal(out[0], jnp.where(mask, y, x))
     np.testing.assert_array_equal(out[1], jnp.where(mask, x, y))
@@ -487,7 +497,7 @@ class PallasCallTest(PallasTest):
       pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]),
               jnp.ones_like(o_ref))
 
-    key = random.PRNGKey(0)
+    key = random.key(0)
     x = random.normal(key, (m, n))
     np.testing.assert_allclose(dummy(x), jnp.ones_like(x), atol=1e-5, rtol=1e-5)
 
@@ -505,7 +515,7 @@ class PallasCallTest(PallasTest):
     grid = (8,)
     size = 8
     dtype = "float32"
-    k1 = random.PRNGKey(0)
+    k1 = random.key(0)
     block_size = 1
     x = random.normal(k1, [size], dtype=dtype)
     kernel = functools.partial(add_inplace_kernel, block_size=block_size)
@@ -526,7 +536,7 @@ class PallasCallTest(PallasTest):
       ("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
   ])
   def test_scalar_atomic(self, op, value, numpy_op):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Atomic ops onl works on GPUs with capability >= sm70")
 
@@ -559,7 +569,7 @@ class PallasCallTest(PallasTest):
 
   @parameterized.parameters(*[(0,), (1,)])
   def test_array_atomic_add(self, axis):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Atomic ops onl works on GPUs with capability >= sm70")
 
@@ -582,7 +592,7 @@ class PallasCallTest(PallasTest):
       idx = (jnp.arange(m), i)
       x = pl.load(x_ref, idx)
       pl.atomic_add(y_ref, (jnp.arange(y.shape[0]),), x)
-    x = random.normal(random.PRNGKey(0), (m, n))
+    x = random.normal(random.key(0), (m, n))
     y = jnp.zeros(out_shape.shape, out_shape.dtype)
     y = reduce(x, y)
     y_ref = np.sum(x, axis=axis)
@@ -591,7 +601,7 @@ class PallasCallTest(PallasTest):
   @parameterized.parameters(False, True)
   def test_reduce_only_dim(self, use_store):
     m = 32
-    x = random.normal(random.PRNGKey(0), (m,), dtype=jnp.float32)
+    x = random.normal(random.key(0), (m,), dtype=jnp.float32)
     out_shape = jax.ShapeDtypeStruct((), x.dtype)
     @functools.partial(
         self.pallas_call,
@@ -634,7 +644,7 @@ class PallasCallTest(PallasTest):
       else:
         return random.normal(key, (m, n), dtype=dtype)
     out_shape = jax.ShapeDtypeStruct(
-        op(make_x(random.PRNGKey(0)), axis=axis).shape, out_dtype)
+        op(make_x(random.key(0)), axis=axis).shape, out_dtype)
     if isinstance(axis, int):
       grid = tuple(a for i, a in enumerate((m, n)) if i != axis)
     else:
@@ -647,7 +657,7 @@ class PallasCallTest(PallasTest):
       x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None]))
       y = op(x, axis=axis)
       pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y)
-    for i, key in enumerate(random.split(random.PRNGKey(0), 20)):
+    for i, key in enumerate(random.split(random.key(0), 20)):
       x = make_x(key)
       y = reduce(x)
       y_ref = op(x, axis=axis)
@@ -663,7 +673,7 @@ class PallasCallTest(PallasTest):
     def slice_kernel(x_ref, y_ref):
       x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4)))
       pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x)
-    x = random.normal(random.PRNGKey(0), (m, n))
+    x = random.normal(random.key(0), (m, n))
     y = slice_kernel(x)
     y_ref = x[:4]
     np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2)
@@ -734,8 +744,7 @@ class PallasCallTest(PallasTest):
   ])
   def test_atomic_counter(self, num_threads):
     if self.INTERPRET:
-      self.skipTest("While loop not supported in interpret mode yet.")
-
+      self.skipTest("While loop not supported in interpreter mode.")
     @functools.partial(
         self.pallas_call, out_shape=(
           jax.ShapeDtypeStruct((), jnp.int32),
@@ -766,7 +775,7 @@ class PallasCallTest(PallasTest):
       return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
 
     m, n = 16, 32
-    x = random.normal(random.PRNGKey(0), (m, n))
+    x = random.normal(random.key(0), (m, n))
 
     @functools.partial(self.pallas_call, out_shape=x, grid=1)
     def softmax_kernel(x_ref, y_ref):
@@ -1199,7 +1208,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
       x = x_ref[()]
       o_ref[()] = impl(x)
 
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
     x = random.normal(k1)
     t = random.normal(k2)
     out_primal, out_tangent = jax.jvp(pallas_impl, (x,), (t,))
@@ -1221,7 +1230,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
       x = x_ref[()]
       o_ref[()] = jax.grad(impl)(x)
 
-    x = random.normal(random.PRNGKey(0))
+    x = random.normal(random.key(0))
     out_grad = pallas_impl(x)
     out_grad_ref = jax.grad(impl)(x)
     np.testing.assert_allclose(out_grad, out_grad_ref, atol=1e-5, rtol=1e-5)
@@ -1237,7 +1246,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
       o_ref[jnp.arange(2)] = jnp.zeros(2)
       o_ref[2 + jnp.arange(2)] = impl(x)
 
-    k1, k2 = random.split(random.PRNGKey(0))
+    k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (8,))
     t = random.normal(k2, (8,))
     out_primal, out_tangent = jax.jvp(pallas_impl, (x,), (t,))
@@ -1250,7 +1259,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
 
   # TODO(sharadmv): enable this when we update Triton
   # def test_jvp_matmul(self):
-  #   k1, k2 = random.split(random.PRNGKey(0))
+  #   k1, k2 = random.split(random.key(0))
   #   x = random.normal(k1, (256, 128))
   #   y = random.normal(k2, (128, 64))
   #   bm, bn, bk, gm = 64, 128, 32, 8
@@ -1400,7 +1409,7 @@ class PallasCallVmapTest(PallasTest):
     add_one = jax.vmap(jax.vmap(add_one))
     add_one_ref = lambda x: x + 1
 
-    x = random.randint(random.PRNGKey(0), (4, 65536, 2), 0, 10000)
+    x = random.randint(random.key(0), (4, 65536, 2), 0, 10000)
 
     out = add_one(x)
     out_ref = add_one_ref(x)
@@ -1418,7 +1427,7 @@ class PallasCallVmapTest(PallasTest):
     add_one = jax.vmap(jax.vmap(jax.vmap(add_one)))
     add_one_ref = lambda x: x + 1
 
-    x = random.randint(random.PRNGKey(0), (2, 2, 65536, 2), 0, 10000)
+    x = random.randint(random.key(0), (2, 2, 65536, 2), 0, 10000)
 
     out = add_one(x)
     out_ref = add_one_ref(x)
@@ -1438,8 +1447,8 @@ class PallasOpsTest(PallasTest):
     def ne(x_ref, y_ref, o_ref):
       o_ref[:] = x_ref[...] != y_ref[...]
 
-    x = jnp.ones(8)
-    y = jnp.arange(8)
+    x = jnp.ones(8, dtype=jnp.int32)
+    y = jnp.arange(8, dtype=jnp.int32)
     not_equal = ne(x, y)
     np.testing.assert_allclose(not_equal, x != y)
 
@@ -1561,11 +1570,11 @@ class FusedAttentionTest(PallasTest):
       use_segment_ids,
       kwargs,
   ):
-    if plgpu.get_compute_capability(0) < 80:
+    if not self.check_gpu_capability_at_least(80):
       raise unittest.SkipTest(
           "Fused attention only works on GPUs with capability >= sm80")
 
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
         k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
     )
@@ -1638,10 +1647,10 @@ class FusedAttentionTest(PallasTest):
   def test_fused_attention_bwd(
       self, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
   ):
-    if plgpu.get_compute_capability(0) < 80:
+    if not self.check_gpu_capability_at_least(80):
       raise unittest.SkipTest(
           "Fused attention only works on GPUs with capability >= sm80")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
         k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16
     )
@@ -1671,6 +1680,9 @@ class FusedAttentionTest(PallasTest):
     np.testing.assert_allclose(dv, dv_ref, atol=0.05)
 
 
+class FusedAttentionInterpreterTest(PallasTest):
+  INTERPRET = True
+
 class FusedLayerNormTest(PallasTest):
 
   @parameterized.parameters(*[
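Each new *InterpreterTest subclass reruns the parameterized GPU tests with INTERPRET = True, exercising the CPU interpreter lowering added above. A hedged sketch of the kind of check those variants end up performing on a CPU-only machine (the kernel and shapes are illustrative, not taken from the test file):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
  # In interpret mode the kernel body runs as ordinary JAX ops.
  o_ref[...] = jnp.dot(x_ref[...], y_ref[...])

x = jnp.ones((16, 32), jnp.float32)
y = jnp.ones((32, 8), jnp.float32)
out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((16, 8), jnp.float32),
    interpret=True,
)(x, y)
np.testing.assert_allclose(out, jnp.dot(x, y))
print("interpreter path matches the reference")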
@@ -1678,10 +1690,10 @@ class FusedLayerNormTest(PallasTest):
       (2, 384, 192),
   ])
   def test_fused_layernorm_fwd(self, batch_size, seq_len, embed_dim):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Fused layernorm only works on GPUs with capability >= sm70")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
     x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
     w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
     b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1695,10 +1707,10 @@ class FusedLayerNormTest(PallasTest):
       (2, 384, 192),
   ])
   def test_fused_layernorm_bwd(self, batch_size, seq_len, embed_dim):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Fused layernorm only works on GPUs with capability >= sm70")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
     x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
     w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
     b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1716,6 +1728,10 @@ class FusedLayerNormTest(PallasTest):
     np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2)
 
 
+class FusedLayerNormInterpreterTest(PallasTest):
+  INTERPRET = True
+
+
 class RmsNormTest(PallasTest):
 
   @parameterized.parameters(*[
@@ -1723,10 +1739,10 @@ class RmsNormTest(PallasTest):
       (2, 384, 192),
   ])
   def test_rms_fwd(self, batch_size, seq_len, embed_dim):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Rms norm only works on GPUs with capability >= sm70")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
     x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
     w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
     b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1740,10 +1756,10 @@ class RmsNormTest(PallasTest):
      (2, 384, 192),
   ])
   def test_rms_norm_bwd(self, batch_size, seq_len, embed_dim):
-    if plgpu.get_compute_capability(0) < 70:
+    if not self.check_gpu_capability_at_least(70):
       raise unittest.SkipTest(
           "Rms norm only works on GPUs with capability >= sm70")
-    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
+    k1, k2, k3 = random.split(random.key(0), 3)
     x = random.normal(k1, (batch_size, seq_len, embed_dim), dtype=jnp.float32)
     w = jax.random.normal(k2, (embed_dim,), dtype=jnp.float32)
     b = jax.random.normal(k3, (embed_dim,), dtype=jnp.float32)
@@ -1760,6 +1776,8 @@ class RmsNormTest(PallasTest):
     np.testing.assert_allclose(dw, dw_ref, rtol=1e-2, atol=1e-2)
     np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2)
 
+class RmsNormInterpreterTest(PallasTest):
+  INTERPRET = True
 
 class SoftmaxTest(PallasTest):
 
@@ -1773,7 +1791,7 @@ class SoftmaxTest(PallasTest):
     if dtype == jnp.bfloat16:
       raise absltest.SkipTest("Disabled due to Triton lowering bug")
 
-    x = jax.random.normal(random.PRNGKey(0), shape, dtype=dtype)
+    x = jax.random.normal(random.key(0), shape, dtype=dtype)
 
     atol, rtol = {
         jnp.bfloat16: (1e-2, 1e-4),
@@ -1789,5 +1807,8 @@ class SoftmaxTest(PallasTest):
     )
 
 
+class SoftmaxInterpreterTest(PallasTest):
+  INTERPRET = True
+
 if __name__ == "__main__":
   absltest.main()