mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas TPU] Temporarily strengthen restrictions on Pallas tests
Mosaic is not more aggressive in its inference of large 2nd minor layouts, which causes slight problems for Pallas pipelines. This will be addressed shortly. PiperOrigin-RevId: 715714752
This commit is contained in:
parent
c4406d2759
commit
aa19f9c4c4
@ -1519,6 +1519,9 @@ if CAN_USE_HYPOTHESIS:
|
||||
if dtype == 'int8' and jtu.is_device_tpu_at_least(6):
|
||||
self.skipTest('Not implemented for TPU v6.')
|
||||
|
||||
def align_up_to(x, y):
|
||||
return (x + y - 1) // y * y
|
||||
|
||||
hp.assume(bm <= m)
|
||||
hp.assume(bn <= n)
|
||||
hp.assume(bk <= k)
|
||||
@ -1528,6 +1531,11 @@ if CAN_USE_HYPOTHESIS:
|
||||
if not jtu.is_device_tpu_at_least(5):
|
||||
self.skipTest('Only TPU v5+ allowed for int8.')
|
||||
hp.assume(bm >= 32)
|
||||
# TODO(apaszke): Relax DMA restrictions and remove this.
|
||||
packing = 4 // jnp.dtype(dtype).itemsize
|
||||
if packing != 1:
|
||||
m = align_up_to(m, 8 * packing)
|
||||
k = align_up_to(k, 8 * packing)
|
||||
k1, k2 = jax.random.split(jax.random.key(seed))
|
||||
x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype)
|
||||
y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user