From 4f00249aa8bff45b379f76304e79e293273f9ad6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Apr 2025 22:30:45 -0700 Subject: [PATCH] [pallas:mosaic_gpu] Do not specify the default `index_map` in tests PiperOrigin-RevId: 743816110 --- tests/pallas/mosaic_gpu_test.py | 108 ++++++++++---------------------- 1 file changed, 32 insertions(+), 76 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d35446359..0cfe9197d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -644,8 +644,6 @@ class PallasCallTest(PallasTest): in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) out_spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), transforms=( plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), @@ -676,9 +674,7 @@ class PallasCallTest(PallasTest): pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM, - ) + out_spec = plgpu.GPUBlockSpec(transforms=ts, memory_space=plgpu.SMEM) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), @@ -719,8 +715,6 @@ class PallasCallTest(PallasTest): in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) out_spec = plgpu.GPUBlockSpec( - (2, 128, 128), - lambda: (0, 0, 0), transforms=( plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), @@ -750,11 +744,7 @@ class PallasCallTest(PallasTest): self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec( - (2, 128), - lambda: (0, 0), - memory_space=plgpu.SMEM, - ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -776,11 +766,7 @@ class PallasCallTest(PallasTest): self.pallas_call, out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), in_specs=[pl.BlockSpec(memory_space=src_memory_space)], - out_specs=plgpu.GPUBlockSpec( - (2, m), - lambda: (0, 0), - memory_space=plgpu.SMEM, - ), + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) def kernel(x_ref, o_ref): for i in range(2): @@ -819,24 +805,19 @@ class PallasCallTest(PallasTest): out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) o_ref[...] = out - - out_spec = plgpu.GPUBlockSpec( - (m, n), lambda: (0, 0), memory_space=plgpu.SMEM, - ) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), in_specs=( pl.BlockSpec(memory_space=src_memory_space), plgpu.GPUBlockSpec( - (k, n), - lambda: (0, 0), transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), ), ), ), - out_specs=out_spec, + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM), ) out_ref = ( @@ -855,9 +836,7 @@ class PallasCallTest(PallasTest): plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM, - ) + out_spec = plgpu.GPUBlockSpec(memory_space=plgpu.SMEM) f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), @@ -960,11 +939,10 @@ class PallasCallTest(PallasTest): out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=[ plgpu.GPUBlockSpec( - shape, - lambda: (0, 0), transforms=( - plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), - ), + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(128), + ) ) ], ) @@ -1143,7 +1121,8 @@ class PallasCallTest(PallasTest): (128, 64), lambda *i: i, transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), ), ) @functools.partial( @@ -1327,9 +1306,7 @@ class PallasCallTest(PallasTest): shape = (256, 128) block_spec = plgpu.GPUBlockSpec( - transforms=( - plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), - ) + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) @functools.partial( self.pallas_call, @@ -1380,12 +1357,7 @@ class PallasCallTest(PallasTest): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), - transforms=( - plgpu.TilingTransform((8, 64)), - plgpu.SwizzleTransform(128), - ), + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) expected = np.empty_like(x) @@ -1560,11 +1532,9 @@ class PallasCallSm90ATest(PallasSm90ATest): transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) @functools.partial( self.pallas_call, - in_specs=[ - plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms) - ], + in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)], out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), - out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)), + out_specs=plgpu.GPUBlockSpec((64, 64)), ) def kernel(i_ref, o_ref): def scope(acc_ref): @@ -1613,25 +1583,28 @@ class PallasCallSm90ATest(PallasSm90ATest): if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane: lhs_spec = plgpu.GPUBlockSpec( - lhs_spec.block_shape, lhs_spec.index_map, + lhs_spec.block_shape, + lhs_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) rhs_spec = plgpu.GPUBlockSpec( - rhs_spec.block_shape, rhs_spec.index_map, + rhs_spec.block_shape, + rhs_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) out_spec = plgpu.GPUBlockSpec( - out_spec.block_shape, out_spec.index_map, + out_spec.block_shape, + out_spec.index_map, transforms=( plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), - ) + ), ) res = self.pallas_call( @@ -1717,14 +1690,9 @@ class PallasCallSm90ATest(PallasSm90ATest): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (128, 192), lambda: (0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) @@ -1747,17 +1715,10 @@ class PallasCallSm90ATest(PallasSm90ATest): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (128, 192), lambda: (0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (64, 192), lambda: (0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), )(a, b, i) np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) @@ -1783,14 +1744,9 @@ class PallasCallSm90ATest(PallasSm90ATest): res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (2, 64, 128), lambda: (0, 0, 0), transforms=transforms - ), - plgpu.GPUBlockSpec( - (2, 128, 192), lambda: (0, 0, 0), transforms=transforms - ), + plgpu.GPUBlockSpec(transforms=transforms), + plgpu.GPUBlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3)