[Pallas TPU] Temporarily strengthen restrictions on Pallas tests

Mosaic is not more aggressive in its inference of large 2nd minor layouts,
which causes slight problems for Pallas pipelines. This will be addressed
shortly.

PiperOrigin-RevId: 715714752
This commit is contained in:
Adam Paszke 2025-01-15 02:31:22 -08:00 committed by jax authors
parent c4406d2759
commit aa19f9c4c4

View File

@ -1519,6 +1519,9 @@ if CAN_USE_HYPOTHESIS:
if dtype == 'int8' and jtu.is_device_tpu_at_least(6):
self.skipTest('Not implemented for TPU v6.')
def align_up_to(x, y):
return (x + y - 1) // y * y
hp.assume(bm <= m)
hp.assume(bn <= n)
hp.assume(bk <= k)
@ -1528,6 +1531,11 @@ if CAN_USE_HYPOTHESIS:
if not jtu.is_device_tpu_at_least(5):
self.skipTest('Only TPU v5+ allowed for int8.')
hp.assume(bm >= 32)
# TODO(apaszke): Relax DMA restrictions and remove this.
packing = 4 // jnp.dtype(dtype).itemsize
if packing != 1:
m = align_up_to(m, 8 * packing)
k = align_up_to(k, 8 * packing)
k1, k2 = jax.random.split(jax.random.key(seed))
x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype)
y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype)