[Pallas] Remove grid=1 in tests

Remove `grid=1` in tests because it's the same as not specifying `grid`.

PiperOrigin-RevId: 705077047
Ayaka 2024-12-11 05:55:54 -08:00 committed by jax authors
parent 0d7eaeb5d8
commit 13ce51785d
2 changed files with 26 additions and 50 deletions
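For context on why the removal is safe: `pl.pallas_call` treats an omitted `grid` as the trivial grid, so the kernel body runs exactly once either way. Below is a minimal sketch of the equivalence; the kernel, shapes, and values are illustrative and not taken from the changed tests:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def kernel(x_ref, o_ref):
  # Trivial element-wise kernel: one program instance handles the whole block.
  o_ref[...] = x_ref[...] + 1.0


x = jnp.arange(8, dtype=jnp.float32)
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)

# Omitting `grid` launches a single kernel invocation...
y_default = pl.pallas_call(kernel, out_shape=out_shape)(x)
# ...and so does `grid=1`, which is why the tests can drop it.
y_grid1 = pl.pallas_call(kernel, out_shape=out_shape, grid=1)(x)

assert jnp.array_equal(y_default, y_grid1)
```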

tests/pallas/ops_test.py

@@ -859,7 +859,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
-        grid=1,
     )
     def kernel(x_ref, o_ref):
       o_ref[:] = fn(x_ref[...])
@@ -939,7 +938,7 @@ class OpsTest(PallasBaseTest):
       self.skipTest("64-bit types require x64_enabled")
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), x_dtype), grid=1
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), x_dtype),
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[:] = lax.pow(x_ref[...], y_ref[...])
@@ -1015,7 +1014,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_),
-        grid=1,
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[:] = fn(x_ref[...], y_ref[...])
@@ -1051,7 +1049,6 @@ class OpsTest(PallasBaseTest):
         ),
         out_specs=pl.BlockSpec(memory_space=smem_on_tpu()),
         out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_),
-        grid=1,
     )
     def kernel(x_ref, y_ref, o_ref):
       for i in range(8):
@@ -1066,7 +1063,7 @@
   def test_isnan(self):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_),
-        grid=1)
+    )
     def isnan(x_ref, o_ref):
       o_ref[:] = jnp.isnan(x_ref[...])
@@ -1098,7 +1095,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((8,), out_dtype),
-        grid=1,
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
@@ -1115,7 +1111,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), dtype),
-        grid=1,
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
@@ -1151,7 +1146,7 @@
       self.skipTest("16-bit types are not supported on TPU")
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype),
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[...] = f(x_ref[...], y_ref[...])
@@ -1181,7 +1176,7 @@
             pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
         ],
         out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
-        out_shape=jax.ShapeDtypeStruct((1,), dtype), grid=1
+        out_shape=jax.ShapeDtypeStruct((1,), dtype),
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[0] = f(x_ref[0], y_ref[0])
@@ -1203,7 +1198,7 @@
     f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension)
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, dtype), grid=1
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, dtype),
     )
     def kernel(o_ref):
       o_ref[...] = f()
@@ -1223,7 +1218,7 @@
       self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype),
     )
     def kernel(x_ref, o_ref):
       o_ref[...] = plgpu.approx_tanh(x_ref[...])
@@ -1250,7 +1245,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((256,), jnp.float16),
-        grid=1,
     )
     def kernel(x_ref, o_ref):
       [o_ref[...]] = plgpu.elementwise_inline_asm(
@@ -1274,7 +1268,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
-        grid=1,
     )
     def kernel(x_ref, o_ref):
       o_ref[...] = x_ref[...]
@@ -1300,7 +1293,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
-        grid=1,
         compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1)
     )
     def kernel(x_ref, o_ref):
@@ -1328,7 +1320,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
-        grid=1,
         compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1)
     )
     def kernel(x_ref, o_ref):
@@ -1354,7 +1345,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
-        grid=1,
     )
     def f(x_ref, o_ref):
       o_ref[...] = x_ref[...].reshape(out_shape)
@@ -1387,7 +1377,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
-        grid=1,
     )
     def f(x_ref, o_ref):
       o_ref[...] = x_ref[...].reshape(out_shape)
@@ -1418,7 +1407,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((4, 2, 2), floatx),
-        grid=1,
     )
     def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref):
       mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])[:, None, None]
@@ -1447,7 +1435,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
-        grid=1,
     )
     def f(x_ref, o_ref):
       x = x_ref[...]
@@ -1521,7 +1508,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
-        grid=1,
     )
     def dot(x_ref, y_ref, o_ref):
       x = x_ref[:, :]
@@ -1574,7 +1560,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=(jax.ShapeDtypeStruct((n,), floatx)),
-        grid=1,
     )
     def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref):
       x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)),
@@ -1610,7 +1595,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=(jax.ShapeDtypeStruct((m, n), floatx)),
-        grid=1,
     )
     def load(x_ref, o_ref):
       x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]))
@@ -1661,7 +1645,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2,
-        grid=1,
         input_output_aliases={0: 0, 1: 1},
     )
     def swap(_, _2, x_ref, y_ref):
@@ -1684,7 +1667,6 @@ class OpsTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2,
-        grid=1,
         input_output_aliases={0: 0, 1: 1},
     )
     def masked_swap(_, _2, mask_ref, x_ref, y_ref):
@@ -1710,7 +1692,6 @@
         self.pallas_call,
         out_shape=(jax.ShapeDtypeStruct((n,), floatx),
                    jax.ShapeDtypeStruct((m,), floatx)),
-        grid=1,
         input_output_aliases={0: 0, 1: 1},
     )
     def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref):
@@ -1876,9 +1857,7 @@ class OpsTest(PallasBaseTest):
     x = random.normal(random.key(0), (m,), dtype=jnp.float32)
     out_shape = jax.ShapeDtypeStruct((), x.dtype)
-    @functools.partial(
-        self.pallas_call, out_shape=out_shape, grid=1
-    )
+    @functools.partial(self.pallas_call, out_shape=out_shape)
     def reduce(x_ref, y_ref):
       x = pl.load(x_ref, (jnp.arange(m),))
       y = jnp.sum(x, axis=-1)
@@ -2126,7 +2105,6 @@ class OpsInterpretTest(OpsTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
-        grid=1,
     )
     def kernel(x_ref, o_ref):
       jax.debug.print("x = {}", x_ref)

tests/pallas/pallas_test.py

@@ -170,7 +170,7 @@ class PallasCallTest(PallasBaseTest):
       self.skipTest("On TPU the test works only in interpret mode")
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((1,), jnp.float32),
-        grid=1)
+    )
     def add_one(x_ref, o_ref):
       o_ref[0] = x_ref[0] + 1.
@@ -224,7 +224,7 @@ class PallasCallTest(PallasBaseTest):
       self.skipTest("On TPU the test works only in interpret mode")
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx),
-        grid=1)
+    )
     def index(x_ref, i_ref, o_ref):
       o_ref[()] = x_ref[i_ref[()]]
@@ -518,7 +518,7 @@ class PallasCallTest(PallasBaseTest):
       self.skipTest("On TPU the test works only in interpret mode")
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), floatx),
-        grid=1)
+    )
     def index(x_ref, idx_ref, o_ref):
       idx = idx_ref[()]
       o_ref[:] = x_ref[idx]
@@ -613,9 +613,8 @@ class PallasCallTest(PallasBaseTest):
     m, n = 16, 32
     @functools.partial(
         self.pallas_call,
-        out_shape=(
-            jax.ShapeDtypeStruct((m, n), jnp.float32)
-        ), grid=1)
+        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
+    )
     def dummy(_, o_ref):
       pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]),
                jnp.ones_like(o_ref))
@@ -658,7 +657,7 @@ class PallasCallTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=out_shape,
-        grid=1)
+    )
     def slice_kernel(x_ref, y_ref):
       x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4)))
       pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x)
@@ -673,7 +672,7 @@ class PallasCallTest(PallasBaseTest):
     trace_count = 0
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
-        grid=1)
+    )
     def add_one(x_ref, o_ref):
       nonlocal trace_count
       o_ref[()] = x_ref[()] + 1.
@@ -711,7 +710,6 @@ class PallasCallTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((32, 64), jnp.float32),
-        grid=1,
     )
     def dot_kernel(x_ref, y_ref, o_ref):
       o_ref[()] = pl.dot(x_ref[()], y_ref[()], precision=precision)
@@ -1158,10 +1156,10 @@ class PallasControlFlowTest(PallasBaseTest):
     # control flow edge from Region #0 to Region #0: source type #0
     # 'tensor<4xf64>' should match input type #0 'tensor<4xf32>'
     with config.enable_x64(True):
-      @functools.partial(self.pallas_call,
-                         out_shape=jax.ShapeDtypeStruct((4,), jnp.float64),
-                         grid=1,
-                         )
+      @functools.partial(
+          self.pallas_call,
+          out_shape=jax.ShapeDtypeStruct((4,), jnp.float64),
+      )
       def f(x_ref, y_ref):
         def body(i, acc):
           # TODO(sharadmv): DCE loop index but retain carry breaks scan pattern.
@@ -1196,10 +1194,10 @@ class PallasControlFlowTest(PallasBaseTest):
       self.skipTest("TODO: error on TPU")
     arg = jnp.float32(0.)
-    @functools.partial(self.pallas_call,
-                       out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32),
-                       grid=1,
-                       )
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct(arg.shape, jnp.float32),
+    )
     def f(branch_ref, x_ref, y_ref):
       y_ref[...] = lax.switch(
           branch_ref[...],
@@ -1904,7 +1902,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx),
-        grid=1)
+    )
     def pallas_impl(x_ref, o_ref):
       x = x_ref[()]
       o_ref[()] = impl(x)
@@ -1927,7 +1925,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((), floatx),
         name=self.id().split(".")[-1],
-        grid=1)
+    )
     def pallas_impl(x_ref, o_ref):
       x = x_ref[()]
       o_ref[()] = jax.grad(impl)(x)
@@ -1945,7 +1943,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), floatx),
-        grid=1)
+    )
     def pallas_impl(x_ref, o_ref):
       x = x_ref[jnp.arange(2)]
       o_ref[jnp.arange(2)] = jnp.zeros(2)
@@ -1979,7 +1977,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
     m, n = 16, 32
     x = random.normal(random.key(0), (m, n))
-    @functools.partial(self.pallas_call, out_shape=x, grid=1)
+    @functools.partial(self.pallas_call, out_shape=x)
     def softmax_kernel(x_ref, y_ref):
       y_ref[:] = softmax(x_ref[:])