[Pallas] Add non-square pl.dot test cases.

PiperOrigin-RevId: 704788500
Tzu-Wei Sung 2024-12-10 11:37:29 -08:00 committed by jax authors
parent 593143e17e
commit e418e88321
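
For context on what the diff below exercises: pl.dot is the Pallas matrix-multiply primitive, and this change extends its test grid from square-only shapes to non-square ones. Here is a minimal sketch of a pl.dot kernel on one of the new shapes, assuming only the public jax.experimental.pallas API (the matmul_kernel/matmul names are illustrative, not part of this commit; pass interpret=True to pl.pallas_call to run without a TPU/GPU):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
  # pl.dot contracts the last dim of x with the first dim of y, like jnp.dot.
  o_ref[:, :] = pl.dot(x_ref[:, :], y_ref[:, :]).astype(o_ref.dtype)

def matmul(x, y):
  m, n = x.shape[0], y.shape[1]
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
  )(x, y)

# One of the new non-square cases: (16, 128) @ (128, 256) -> (16, 256).
out = matmul(jnp.ones((16, 128), jnp.float32), jnp.ones((128, 256), jnp.float32))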


@@ -17,6 +17,7 @@
 from collections.abc import Sequence
 import functools
 import itertools
+import math
 import sys
 from typing import Any
 import unittest
@@ -62,6 +63,10 @@ intx = dtypes.canonicalize_dtype(jnp.int64)
 floatx = dtypes.canonicalize_dtype(jnp.float64)
 
 
+def is_power_of_two(n: int) -> bool:
+  return (n > 0) and (n & (n - 1) == 0)
+
+
 def smem_on_tpu():
   if jtu.test_device_matches(["tpu"]):
     return pltpu.SMEM
@@ -1410,12 +1415,45 @@ class OpsTest(PallasBaseTest):
     np.testing.assert_allclose(f(x), expected)
 
   @parameterized.product(
-      size=[16, 32, 64, 128, 256],
+      lhs_and_rhs_shape=[
+          ((16, 16), (16, 16)),
+          ((32, 32), (32, 32)),
+          ((64, 64), (64, 64)),
+          ((128, 128), (128, 128)),
+          ((256, 256), (256, 256)),
+          ((8, 128), (128, 256)),
+          ((8, 128), (256, 128)),
+          ((8, 256), (256, 128)),
+          ((16, 128), (128, 256)),
+          ((16, 128), (256, 128)),
+          ((16, 256), (256, 128)),
+          ((24, 128), (128, 256)),
+          ((24, 128), (256, 128)),
+          ((24, 256), (256, 128)),
+          ((128, 8), (128, 256)),
+          ((128, 8), (256, 128)),
+          ((256, 8), (256, 128)),
+          ((128, 16), (128, 256)),
+          ((128, 16), (256, 128)),
+          ((256, 16), (256, 128)),
+          ((128, 24), (128, 256)),
+          ((128, 24), (256, 128)),
+          ((256, 24), (256, 128)),
+      ],
       dtype=[jnp.float32, jnp.float16, jnp.bfloat16],
       trans_x=[False, True],
       trans_y=[False, True],
   )
-  def test_dot(self, size, dtype, trans_x, trans_y):
+  def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
+    lhs_shape, rhs_shape = lhs_and_rhs_shape
+
+    final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape
+    final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape
+    if final_lhs_shape[1] != final_rhs_shape[0]:
+      self.skipTest("Contraction dimensions do not match")
+
+    out_shape = (final_lhs_shape[0], final_rhs_shape[1])
+
     if jtu.test_device_matches(["tpu"]):
       if dtype == jnp.float16:
         self.skipTest("float16 type is not supported on TPU")
@@ -1427,12 +1465,19 @@ class OpsTest(PallasBaseTest):
     if jtu.test_device_matches(["gpu"]):
       if dtype == jnp.bfloat16:
         self.skipTest("bfloat16 type are not supported on GPU")
-      if size > 128:
+      if (
+          math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape)
+          > (256 * 256) * 2
+      ):
         self.skipTest("Shared memory size limit exceeded")
+      if min(*lhs_shape, *rhs_shape) < 16:
+        self.skipTest("All dimensions of lhs and rhs must be >= 16")
+      if any(not is_power_of_two(x) for x in lhs_shape + rhs_shape):
+        self.skipTest("All dimensions of lhs and rhs must be power of two")
 
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((size, size), dtype),
+        out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
         grid=1,
     )
     def dot(x_ref, y_ref, o_ref):
@@ -1441,8 +1486,8 @@ class OpsTest(PallasBaseTest):
       o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype)
 
     k1, k2 = random.split(random.key(0))
-    x = random.normal(k1, (size, size), dtype=dtype)
-    y = random.normal(k2, (size, size), dtype=dtype)
+    x = random.normal(k1, lhs_shape, dtype=dtype)
+    y = random.normal(k2, rhs_shape, dtype=dtype)
     out = dot(x, y)
     expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
     np.testing.assert_allclose(
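
A note on the arithmetic in the new GPU guard above: the old size > 128 check becomes an element-count budget of (256 * 256) * 2 = 131072 across lhs, rhs, and out. For example, the (256, 256) x (256, 256) case needs 3 * 65536 = 196608 elements and is skipped, while (16, 128) @ (128, 256) needs 2048 + 32768 + 4096 = 38912 and runs.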