Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
[pallas:mosaic_gpu] Do not specify the default index_map in tests
PiperOrigin-RevId: 743816110
This commit is contained in:
parent a9bd1e3f9d
commit 4f00249aa8
@ -644,8 +644,6 @@ class PallasCallTest(PallasTest):
|
||||
|
||||
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
|
||||
out_spec = plgpu.GPUBlockSpec(
|
||||
(128, 128),
|
||||
lambda: (0, 0),
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, 32)),
|
||||
plgpu.SwizzleTransform(128),
|
||||
@ -676,9 +674,7 @@ class PallasCallTest(PallasTest):
|
||||
pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))
|
||||
|
||||
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
|
||||
out_spec = plgpu.GPUBlockSpec(
|
||||
(128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
|
||||
)
|
||||
out_spec = plgpu.GPUBlockSpec(transforms=ts, memory_space=plgpu.SMEM)
|
||||
f = self.pallas_call(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
|
||||
@ -719,8 +715,6 @@ class PallasCallTest(PallasTest):
|
||||
|
||||
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
|
||||
out_spec = plgpu.GPUBlockSpec(
|
||||
(2, 128, 128),
|
||||
lambda: (0, 0, 0),
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, 32)),
|
||||
plgpu.TransposeTransform((0, 2, 1, 3, 4)),
|
||||
@ -750,11 +744,7 @@ class PallasCallTest(PallasTest):
|
||||
self.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32),
|
||||
in_specs=[pl.BlockSpec(memory_space=src_memory_space)],
|
||||
out_specs=plgpu.GPUBlockSpec(
|
||||
(2, 128),
|
||||
lambda: (0, 0),
|
||||
memory_space=plgpu.SMEM,
|
||||
),
|
||||
out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
|
||||
)
|
||||
def kernel(x_ref, o_ref):
|
||||
for i in range(2):
|
||||
@ -776,11 +766,7 @@ class PallasCallTest(PallasTest):
|
||||
self.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32),
|
||||
in_specs=[pl.BlockSpec(memory_space=src_memory_space)],
|
||||
out_specs=plgpu.GPUBlockSpec(
|
||||
(2, m),
|
||||
lambda: (0, 0),
|
||||
memory_space=plgpu.SMEM,
|
||||
),
|
||||
out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
|
||||
)
|
||||
def kernel(x_ref, o_ref):
|
||||
for i in range(2):
|
||||
@ -819,24 +805,19 @@ class PallasCallTest(PallasTest):
|
||||
|
||||
out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32))
|
||||
o_ref[...] = out
|
||||
|
||||
out_spec = plgpu.GPUBlockSpec(
|
||||
(m, n), lambda: (0, 0), memory_space=plgpu.SMEM,
|
||||
)
|
||||
f = self.pallas_call(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32),
|
||||
in_specs=(
|
||||
pl.BlockSpec(memory_space=src_memory_space),
|
||||
plgpu.GPUBlockSpec(
|
||||
(k, n),
|
||||
lambda: (0, 0),
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128),
|
||||
plgpu.TilingTransform((8, 64)),
|
||||
plgpu.SwizzleTransform(128),
|
||||
),
|
||||
),
|
||||
),
|
||||
out_specs=out_spec,
|
||||
out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
|
||||
)
|
||||
|
||||
out_ref = (
|
||||
@ -855,9 +836,7 @@ class PallasCallTest(PallasTest):
|
||||
plgpu.barrier_wait(barrier_ref)
|
||||
|
||||
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
|
||||
out_spec = plgpu.GPUBlockSpec(
|
||||
(2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM,
|
||||
)
|
||||
out_spec = plgpu.GPUBlockSpec(memory_space=plgpu.SMEM)
|
||||
f = self.pallas_call(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32),
|
||||
@ -960,11 +939,10 @@ class PallasCallTest(PallasTest):
|
||||
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
|
||||
in_specs=[
|
||||
plgpu.GPUBlockSpec(
|
||||
shape,
|
||||
lambda: (0, 0),
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128),
|
||||
),
|
||||
plgpu.TilingTransform((8, 32)),
|
||||
plgpu.SwizzleTransform(128),
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -1143,7 +1121,8 @@ class PallasCallTest(PallasTest):
|
||||
(128, 64),
|
||||
lambda *i: i,
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128),
|
||||
plgpu.TilingTransform((8, 64)),
|
||||
plgpu.SwizzleTransform(128),
|
||||
),
|
||||
)
|
||||
@functools.partial(
|
||||
@ -1327,9 +1306,7 @@ class PallasCallTest(PallasTest):
|
||||
|
||||
shape = (256, 128)
|
||||
block_spec = plgpu.GPUBlockSpec(
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128),
|
||||
)
|
||||
transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
|
||||
)
|
||||
@functools.partial(
|
||||
self.pallas_call,
|
||||
@ -1380,12 +1357,7 @@ class PallasCallTest(PallasTest):
|
||||
|
||||
x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
|
||||
spec = plgpu.GPUBlockSpec(
|
||||
(128, 128),
|
||||
lambda: (0, 0),
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, 64)),
|
||||
plgpu.SwizzleTransform(128),
|
||||
),
|
||||
transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
|
||||
)
|
||||
f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec)
|
||||
expected = np.empty_like(x)
|
||||
@ -1560,11 +1532,9 @@ class PallasCallSm90ATest(PallasSm90ATest):
|
||||
transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
|
||||
@functools.partial(
|
||||
self.pallas_call,
|
||||
in_specs=[
|
||||
plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)
|
||||
],
|
||||
in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)],
|
||||
out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16),
|
||||
out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)),
|
||||
out_specs=plgpu.GPUBlockSpec((64, 64)),
|
||||
)
|
||||
def kernel(i_ref, o_ref):
|
||||
def scope(acc_ref):
|
||||
@ -1613,25 +1583,28 @@ class PallasCallSm90ATest(PallasSm90ATest):
|
||||
|
||||
if self.THREAD_SEMANTICS == plgpu.ThreadSemantics.Lane:
|
||||
lhs_spec = plgpu.GPUBlockSpec(
|
||||
lhs_spec.block_shape, lhs_spec.index_map,
|
||||
lhs_spec.block_shape,
|
||||
lhs_spec.index_map,
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, elems_128b)),
|
||||
plgpu.SwizzleTransform(128),
|
||||
)
|
||||
),
|
||||
)
|
||||
rhs_spec = plgpu.GPUBlockSpec(
|
||||
rhs_spec.block_shape, rhs_spec.index_map,
|
||||
rhs_spec.block_shape,
|
||||
rhs_spec.index_map,
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, elems_128b)),
|
||||
plgpu.SwizzleTransform(128),
|
||||
)
|
||||
),
|
||||
)
|
||||
out_spec = plgpu.GPUBlockSpec(
|
||||
out_spec.block_shape, out_spec.index_map,
|
||||
out_spec.block_shape,
|
||||
out_spec.index_map,
|
||||
transforms=(
|
||||
plgpu.TilingTransform((8, elems_128b)),
|
||||
plgpu.SwizzleTransform(128),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
res = self.pallas_call(
|
||||
@ -1717,14 +1690,9 @@ class PallasCallSm90ATest(PallasSm90ATest):
|
||||
res = self.pallas_call(
|
||||
kernel,
|
||||
in_specs=[
|
||||
plgpu.GPUBlockSpec(
|
||||
(64, 128), lambda: (0, 0), transforms=transforms
|
||||
),
|
||||
plgpu.GPUBlockSpec(
|
||||
(128, 192), lambda: (0, 0), transforms=transforms
|
||||
),
|
||||
plgpu.GPUBlockSpec(transforms=transforms),
|
||||
plgpu.GPUBlockSpec(transforms=transforms),
|
||||
],
|
||||
out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
|
||||
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
|
||||
)(a, b)
|
||||
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
|
||||
@ -1747,17 +1715,10 @@ class PallasCallSm90ATest(PallasSm90ATest):
|
||||
res = self.pallas_call(
|
||||
kernel,
|
||||
in_specs=[
|
||||
plgpu.GPUBlockSpec(
|
||||
(64, 128), lambda: (0, 0), transforms=transforms
|
||||
),
|
||||
plgpu.GPUBlockSpec(
|
||||
(128, 192), lambda: (0, 0), transforms=transforms
|
||||
),
|
||||
plgpu.GPUBlockSpec(
|
||||
(64, 192), lambda: (0, 0), transforms=transforms
|
||||
),
|
||||
plgpu.GPUBlockSpec(transforms=transforms),
|
||||
plgpu.GPUBlockSpec(transforms=transforms),
|
||||
plgpu.GPUBlockSpec(transforms=transforms),
|
||||
],
|
||||
out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
|
||||
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16),
|
||||
)(a, b, i)
|
||||
np.testing.assert_allclose(res, i + a @ b, rtol=2e-3)
|
||||
@ -1783,14 +1744,9 @@ class PallasCallSm90ATest(PallasSm90ATest):
|
||||
res = self.pallas_call(
|
||||
kernel,
|
||||
in_specs=[
|
||||
plgpu.GPUBlockSpec(
|
||||
(2, 64, 128), lambda: (0, 0, 0), transforms=transforms
|
||||
),
|
||||
plgpu.GPUBlockSpec(
|
||||
(2, 128, 192), lambda: (0, 0, 0), transforms=transforms
|
||||
),
|
||||
plgpu.GPUBlockSpec(transforms=transforms),
|
||||
plgpu.GPUBlockSpec(transforms=transforms),
|
||||
],
|
||||
out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
|
||||
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
|
||||
)(a, b)
|
||||
np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user