[Pallas] Add non-square pl.dot test cases.
PiperOrigin-RevId: 704788500
commit e418e88321
parent 593143e17e
@@ -17,6 +17,7 @@
 from collections.abc import Sequence
 import functools
 import itertools
+import math
 import sys
 from typing import Any
 import unittest
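The only change in this hunk is the new `import math`; it backs the `math.prod`-based shared-memory estimate added further down. A one-line illustration of what `math.prod` computes for a shape tuple:

import math

# Element count of a (rows, cols) shape tuple, as used in the skip logic below.
assert math.prod((128, 256)) == 32768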
@@ -62,6 +63,10 @@ intx = dtypes.canonicalize_dtype(jnp.int64)
 floatx = dtypes.canonicalize_dtype(jnp.float64)
 
 
+def is_power_of_two(n: int) -> bool:
+  return (n > 0) and (n & (n - 1) == 0)
+
+
 def smem_on_tpu():
   if jtu.test_device_matches(["tpu"]):
     return pltpu.SMEM
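The new `is_power_of_two` helper relies on the standard bit trick: `n & (n - 1)` clears the lowest set bit, so the result is zero exactly when `n` has a single set bit. A quick sanity check covering the cases the GPU skip logic below cares about:

assert is_power_of_two(128)     # 0b10000000 & 0b01111111 == 0
assert not is_power_of_two(24)  # 0b11000 & 0b10111 == 0b10000 != 0
assert not is_power_of_two(0)   # guarded by the (n > 0) term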
@@ -1410,12 +1415,45 @@ class OpsTest(PallasBaseTest):
     np.testing.assert_allclose(f(x), expected)
 
   @parameterized.product(
-      size=[16, 32, 64, 128, 256],
+      lhs_and_rhs_shape=[
+          ((16, 16), (16, 16)),
+          ((32, 32), (32, 32)),
+          ((64, 64), (64, 64)),
+          ((128, 128), (128, 128)),
+          ((256, 256), (256, 256)),
+          ((8, 128), (128, 256)),
+          ((8, 128), (256, 128)),
+          ((8, 256), (256, 128)),
+          ((16, 128), (128, 256)),
+          ((16, 128), (256, 128)),
+          ((16, 256), (256, 128)),
+          ((24, 128), (128, 256)),
+          ((24, 128), (256, 128)),
+          ((24, 256), (256, 128)),
+          ((128, 8), (128, 256)),
+          ((128, 8), (256, 128)),
+          ((256, 8), (256, 128)),
+          ((128, 16), (128, 256)),
+          ((128, 16), (256, 128)),
+          ((256, 16), (256, 128)),
+          ((128, 24), (128, 256)),
+          ((128, 24), (256, 128)),
+          ((256, 24), (256, 128)),
+      ],
       dtype=[jnp.float32, jnp.float16, jnp.bfloat16],
       trans_x=[False, True],
       trans_y=[False, True],
   )
-  def test_dot(self, size, dtype, trans_x, trans_y):
+  def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
+    lhs_shape, rhs_shape = lhs_and_rhs_shape
+
+    final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape
+    final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape
+    if final_lhs_shape[1] != final_rhs_shape[0]:
+      self.skipTest("Contraction dimensions do not match")
+
+    out_shape = (final_lhs_shape[0], final_rhs_shape[1])
+
     if jtu.test_device_matches(["tpu"]):
       if dtype == jnp.float16:
         self.skipTest("float16 type is not supported on TPU")
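Because every (lhs, rhs) pair is crossed with both transpose flags, many parameter points are inconsistent; the skip above filters them by comparing contraction dimensions after applying the transposes. A worked example for one parameter point (values chosen for illustration, not a specific test case):

lhs_shape, rhs_shape, trans_x, trans_y = (128, 8), (128, 256), True, False

final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape  # (8, 128)
final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape  # (128, 256)

assert final_lhs_shape[1] == final_rhs_shape[0]              # dims match, not skipped
assert (final_lhs_shape[0], final_rhs_shape[1]) == (8, 256)  # resulting out_shape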
@@ -1427,12 +1465,19 @@ class OpsTest(PallasBaseTest):
     if jtu.test_device_matches(["gpu"]):
       if dtype == jnp.bfloat16:
         self.skipTest("bfloat16 type are not supported on GPU")
-      if size > 128:
+      if (
+          math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape)
+          > (256 * 256) * 2
+      ):
         self.skipTest("Shared memory size limit exceeded")
+      if min(*lhs_shape, *rhs_shape) < 16:
+        self.skipTest("All dimensions of lhs and rhs must be >= 16")
+      if any(not is_power_of_two(x) for x in lhs_shape + rhs_shape):
+        self.skipTest("All dimensions of lhs and rhs must be power of two")
 
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((size, size), dtype),
+        out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
         grid=1,
     )
     def dot(x_ref, y_ref, o_ref):
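The old `size > 128` guard becomes an element-count budget now that the operands can differ in shape. The `(256 * 256) * 2` bound reads as an empirical shared-memory budget rather than a derived limit; a worked instance showing why the largest square case is skipped on GPU:

import math

lhs, rhs, out = (256, 256), (256, 256), (256, 256)
total = math.prod(lhs) + math.prod(rhs) + math.prod(out)  # 65536 * 3 == 196608 elements
assert total > (256 * 256) * 2                            # 196608 > 131072, skipTest fires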
@@ -1441,8 +1486,8 @@ class OpsTest(PallasBaseTest):
       o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype)
 
     k1, k2 = random.split(random.key(0))
-    x = random.normal(k1, (size, size), dtype=dtype)
-    y = random.normal(k2, (size, size), dtype=dtype)
+    x = random.normal(k1, lhs_shape, dtype=dtype)
+    y = random.normal(k2, rhs_shape, dtype=dtype)
     out = dot(x, y)
     expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
     np.testing.assert_allclose(
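For readers outside the test harness, here is a minimal standalone sketch of the pattern this test exercises: a `pl.pallas_call` kernel computing `pl.dot` on a non-square operand pair, checked against `jnp.dot`. The kernel name, shapes, dtype, and tolerances are illustrative and not taken from the test.

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.experimental import pallas as pl


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((16, 256), jnp.float32),
    grid=1,
)
def dot_kernel(x_ref, y_ref, o_ref):
  # Load both operands and write the (16, 128) @ (128, 256) matmul result.
  o_ref[:, :] = pl.dot(x_ref[:, :], y_ref[:, :])


k1, k2 = random.split(random.key(0))
x = random.normal(k1, (16, 128), dtype=jnp.float32)
y = random.normal(k2, (128, 256), dtype=jnp.float32)
np.testing.assert_allclose(dot_kernel(x, y), jnp.dot(x, y), atol=5e-5, rtol=5e-5)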