# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math
import re
import traceback

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np


jax.config.parse_flags_with_absl()

class PallasTest(jtu.JaxTestCase):

  def setUp(self):
    if config.enable_x64.value:
      self.skipTest("Only works on x32 at the moment")
    if not jtu.is_cuda_compute_capability_at_least("9.0"):
      self.skipTest("Only works on a GPU with capability >= sm90")

    super().setUp()

class PallasCallTest(PallasTest):

  @parameterized.named_parameters(
      ("add_one", lambda x: x + 1.),
      ("logistic", jax.lax.logistic),
      ("exp", jax.lax.exp),
      ("square", lambda x: x ** 2),
      ("rsqrt", jax.lax.rsqrt),
  )
  def test_unary_ops(self, unary):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = unary(x_ref[...])

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), unary(x))

  def test_add_first(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def kernel(x_ref, y_ref, o_ref):
      o_ref[...] = x_ref[...] + y_ref[0]

    x = jnp.arange(256).astype(jnp.float32)
    y = jnp.flip(x).reshape(1, 256)
    np.testing.assert_array_equal(kernel(x, y), x + y[0])

  def test_add_xy(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def kernel(x_ref, y_ref, o_ref):
      o_ref[...] = x_ref[...] + y_ref[...]

    x = jnp.arange(256).astype(jnp.float32)
    y = x + 1
    np.testing.assert_array_equal(kernel(x, y), x + y)

  def test_add_xy_indexed(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
    )
    def kernel(x_ref, y_ref, o_ref):
      idx = jnp.sum(y_ref[...])
      o_ref[...] = x_ref[idx]

    x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
    y = jnp.zeros(128, dtype=jnp.int32)
    np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)])

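  # The tests below use a grid: each BlockSpec's index map returns the block
  # index that a given program instance should read or write.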
  def test_add_one_grid(self):
    @functools.partial(
        pl.pallas_call,
        in_specs=[pl.BlockSpec((128,), lambda *i: i)],
        out_specs=pl.BlockSpec((128,), lambda *i: i),
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
        grid=2,
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(128 * 2).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  def test_add_one_grid_with_scratch(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32),
        in_specs=[pl.BlockSpec((128,), lambda *i: i)],
        out_specs=pl.BlockSpec((128,), lambda *i: i),
        scratch_shapes=[plgpu.SMEM((128,), jnp.float32)],
        grid=2,
    )
    def kernel(x_ref, o_ref, scratch_ref):
      scratch_ref[...] = x_ref[...] + 1
      o_ref[...] = scratch_ref[...]

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

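  # max_concurrent_steps bounds how many grid steps the software pipeline
  # keeps in flight at once; note that 16 exceeds the 8 steps of this grid.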
  @parameterized.product(max_concurrent_steps=[1, 2, 3, 4, 16])
  def test_add_one_grid_pipelined(self, max_concurrent_steps):
    @functools.partial(
        pl.pallas_call,
        in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
        out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "sequential"],
            max_concurrent_steps=max_concurrent_steps,
        ),
        grid=(2, 4),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  def test_add_one_grid_pipelined_program_id(self):
    @functools.partial(
        pl.pallas_call,
        out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "sequential"],
            max_concurrent_steps=2,
        ),
        grid=(4, 4),
    )
    def kernel(o_ref):
      o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape)

    np.testing.assert_array_equal(
        kernel(),
        jnp.repeat(jnp.repeat(jnp.arange(4), 16)[None], 16, axis=0),
    )

  def test_add_one_grid_pipelined_sequential_invariant_output(self):
    @functools.partial(
        pl.pallas_call,
        in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))],
        out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)),
        out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "sequential"],
            max_concurrent_steps=2,
        ),
        grid=(2, 4),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(32 * 2 * 64).reshape((32 * 2, 64)).astype(jnp.float32)
    y = jnp.empty_like(x)
    for i in range(2):
      i_slice = slice(32 * i, 32 * (i + 1))
      for j in range(4):
        j_slice = slice(16 * j, 16 * (j + 1))
        y = y.at[i_slice, :16].set(x[i_slice, j_slice] + 1)

    # We only compare the elements in the first 16 columns, because the rest
    # are never written to.
    np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16])

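  # The tests below exercise the explicit GMEM<->SMEM copy API instead of the
  # automatic pipelining driven by BlockSpecs.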
  @parameterized.product(indexer=[..., slice(128), slice(None, 128)])
  def test_copy_smem_to_gmem(self, indexer):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        scratch_shapes=[plgpu.SMEM((256,), jnp.float32)],
    )
    def kernel(x_ref, o_ref_gmem, scratch_ref):
      scratch_ref[...] = x_ref[...] + 1
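      # The copy is asynchronous; wait_smem_to_gmem(0) blocks until no
      # SMEM->GMEM copies remain in flight.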
      plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer])
      plgpu.wait_smem_to_gmem(0)

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)

  @parameterized.product(indexer=[..., slice(128), slice(None, 128)])
  def test_copy_gmem_to_smem(self, indexer):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
        scratch_shapes=[
            plgpu.SMEM((256,), jnp.float32),
            plgpu.Barrier(num_arrivals=1),
        ],
    )
    def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
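      # copy_gmem_to_smem arrives on the barrier once the copy has landed in
      # SMEM; barrier_wait blocks until that arrival.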
      plgpu.copy_gmem_to_smem(
          x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier=barrier_ref
      )
      plgpu.barrier_wait(barrier_ref)
      o_ref[...] = scratch_ref[...] + 1

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)

  @parameterized.product(indexer=[0, 1, 2, 3])
  def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
        scratch_shapes=[
            plgpu.SMEM((128,), jnp.float32),
            plgpu.Barrier(num_arrivals=1, num_barriers=4),
        ],
    )
    def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
      plgpu.copy_gmem_to_smem(
          x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer]
      )
      plgpu.barrier_wait(barrier_ref.at[indexer])
      o_ref[...] = scratch_ref[...] + 1

    x = jnp.arange(128).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

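  # TilingTransform and SwizzleTransform describe the physical SMEM layout of
  # a ref; the tests below check that copies honor those transforms.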
@parameterized.named_parameters(("_g2s", False), ("_s2g", True))
|
|
def test_copy_with_transforms(self, to_smem):
|
|
def kernel(x_ref, o_ref, barrier_ref):
|
|
if to_smem:
|
|
plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref)
|
|
plgpu.barrier_wait(barrier_ref)
|
|
else:
|
|
plgpu.copy_smem_to_gmem(x_ref, o_ref)
|
|
plgpu.wait_smem_to_gmem(0)
|
|
|
|
in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
|
|
out_spec = plgpu.GPUBlockSpec(
|
|
(128, 128),
|
|
lambda: (0, 0),
|
|
transforms=(
|
|
plgpu.TilingTransform((64, 32)),
|
|
plgpu.SwizzleTransform(128),
|
|
),
|
|
memory_space=plgpu.SMEM,
|
|
)
|
|
if not to_smem:
|
|
in_spec, out_spec = out_spec, in_spec
|
|
f = pl.pallas_call(
|
|
kernel,
|
|
out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
|
|
in_specs=(in_spec,),
|
|
out_specs=out_spec,
|
|
scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
|
|
)
|
|
x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
|
|
np.testing.assert_array_equal(f(x), x)
|
|
|
|
  def test_scoped_copy_with_transforms(self):
    ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))
    def kernel(x_ref, o_ref, barrier_ref):
      def body(tmp_ref):
        plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref)
        plgpu.barrier_wait(barrier_ref)
        o_ref[...] = tmp_ref[...] * 2
      pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts))

    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
    out_spec = plgpu.GPUBlockSpec(
        (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM,
    )
    f = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32),
        in_specs=(in_spec,),
        out_specs=out_spec,
        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
    )
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), x * 2)

  def test_copy_with_transforms_and_indexing(self):
    def kernel(x_ref, o_ref, barrier_ref):
      for i in range(2):
        plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier=barrier_ref)
        plgpu.barrier_wait(barrier_ref)

    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
    out_spec = plgpu.GPUBlockSpec(
        (2, 128, 128),
        lambda: (0, 0, 0),
        transforms=(
            plgpu.TilingTransform((64, 32)),
            plgpu.TransposeTransform((0, 2, 1, 3, 4)),
            plgpu.SwizzleTransform(128),
        ),
        memory_space=plgpu.SMEM,
    )
    f = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32),
        in_specs=(in_spec,),
        out_specs=out_spec,
        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
    )
    x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
    np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0))

  def test_indexing_before_transpose(self):
    def kernel(x_ref, o_ref, barrier_ref):
      for i in range(2):
        plgpu.copy_gmem_to_smem(
            x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier=barrier_ref
        )
        plgpu.barrier_wait(barrier_ref)

    in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)
    out_spec = plgpu.GPUBlockSpec(
        (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM,
    )
    f = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32),
        in_specs=(in_spec,),
        out_specs=out_spec,
        scratch_shapes=[plgpu.Barrier(num_arrivals=1)],
    )
    x = jnp.arange(2 * 64 * 128, dtype=jnp.float32).reshape(2, 64, 128)
    xt = x.transpose((1, 0, 2))
    np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0))

  def test_copy_gmem_to_smem_in_run_scoped(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
    )
    def kernel(x_ref_gmem, o_ref):
      def body(barrier_ref):
        def inner_body(scratch_ref):
          plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref)
          plgpu.barrier_wait(barrier_ref)
          o_ref[...] = scratch_ref[...] + 1
        pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32))
      pl.run_scoped(body, plgpu.Barrier(num_arrivals=1))

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 1.0)

  def test_add_doubled_sum(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + jnp.sum(x_ref[...]) + jnp.sum(x_ref[...])

    x = jnp.arange(128).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + x.sum() * 2)

  @parameterized.parameters(False, True)
  def test_rsqrt(self, approx_math):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
        compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = jax.lax.rsqrt(x_ref[...])

    x = jnp.arange(128).astype(jnp.float32)
    np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x))

  @parameterized.product(input_factor=[0.001, 1, 10, 100])
  def test_layer_norm(self, input_factor):
    eps = 1e-5
    gamma = 1.0
    beta = 1.0

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def layer_norm(x_ref, o_ref):
      x_mean = jnp.mean(x_ref[...])
      x_centered = x_ref[...] - x_mean
      o_ref[...] = (
          x_centered * jax.lax.rsqrt(jnp.mean(x_centered**2) + eps) * gamma
          + beta
      )

    def layer_norm_np(x):
      x_mean = np.mean(x)
      x_centered = x - x_mean
      return (x_centered / np.sqrt(np.mean(x_centered**2) + eps) * gamma) + beta

    # An all-ones input is always computed exactly.
    x = jnp.ones((256,)).astype(jnp.float32) * input_factor
    np.testing.assert_allclose(layer_norm(x), layer_norm_np(x))

    # A random input (like most others) is not.
    x = (
        jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32)
        * input_factor
    )
    # TODO(cperivol): find out why in this particular case we have a small-ish error.
    rtol = 1e-07 if input_factor > 10 else 5e-5
    np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol)

  def test_print(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def kernel(x_ref, o_ref):
      del x_ref, o_ref
      pl.debug_print("It works!")

    x = jnp.arange(256).astype(jnp.float32)
    with jtu.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertEqual(output(), "It works!\n")

  def test_print_wgmma_tiled_layout(self):
    shape = (128, 64)
    size = math.prod(shape)

    def kernel(x_ref, o_ref):
      pl.debug_print("{}", x_ref[...])

    spec = plgpu.GPUBlockSpec(
        shape,
        lambda: (0, 0),
        transforms=(
            plgpu.TilingTransform((64, 32)),
            plgpu.SwizzleTransform(128),
        ),
    )
    x = jnp.arange(size, dtype=jnp.float32).reshape(shape)
    f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec)

    with jtu.capture_stdout() as get_output:
      jax.block_until_ready(f(x))

    output = get_output()
    results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output)
    self.assertLen(results, size)
    for i, j, v in results:
      i, j, v = map(int, (i, j, v))
      self.assertEqual(v, i * shape[1] + j)

  def test_print_scalar(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    )
    def kernel(x_ref, o_ref):
      del o_ref
      pl.debug_print("x.sum() = {}", x_ref[...].sum())

    x = jnp.arange(256)
    with jtu.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertIn(f"x.sum() = {x.sum()}", output())

  def test_print_scalar_array(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    )
    def kernel(x_ref, o_ref):
      del o_ref
      pl.debug_print("x.sum() = {}", x_ref[...].sum() + 1)

    x = jnp.arange(256)
    with jtu.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertIn(f"x.sum() = {x.sum() + 1}", output())

  def test_print_array(self):
    in_shape = [2, 1, 64, 64]

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32),
    )
    def kernel(x_ref, o_ref):
      del o_ref
      pl.debug_print("x: {}", x_ref[...])

    x = jnp.arange(math.prod(in_shape)).reshape(in_shape)
    with jtu.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())

  def test_run_scoped(self):
    def kernel(x_ref, o_ref):
      def body(tmp_ref):
        self.assertEqual(tmp_ref.shape, (8, 128))
        tmp_ref[...] = x_ref[...] + 1.0
        return tmp_ref[...]

      tmp = pl.run_scoped(body, plgpu.SMEM((8, 128), jnp.float32))
      self.assertEqual(tmp.shape, (8, 128))
      o_ref[...] = tmp

    inp = np.ones((8, 128))
    f = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    o = f(inp)
    np.testing.assert_array_equal(o, inp + 1.0)

  def test_program_id(self):
    @functools.partial(
        pl.pallas_call,
        in_specs=(),
        out_specs=pl.BlockSpec((128,), lambda *i: i),
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
        grid=2,
    )
    def kernel(o_ref):
      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))

    np.testing.assert_array_equal(
        kernel(),
        jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
    )

  def test_program_id_in_block_spec(self):
    @functools.partial(
        pl.pallas_call,
        out_specs=pl.BlockSpec((128,), lambda *_: pl.program_id(0)),
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
        grid=2,
    )
    def kernel(o_ref):
      del o_ref

    # ``assertRaises`` has no way of asserting against the exception's cause,
    # so we have to inspect ``traceback.format_exception`` manually.
    with self.assertRaises(Exception) as exc_info:
      kernel()
    self.assertIn(
        "not supported in this context",
        "".join(traceback.format_exception(exc_info.exception)),
    )

  def test_num_programs(self):
    @functools.partial(
        pl.pallas_call,
        in_specs=(),
        out_specs=pl.BlockSpec((128,), lambda *i: i),
        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
        grid=2,
    )
    def kernel(o_ref):
      o_ref[...] = jnp.full(o_ref.shape, pl.num_programs(0))

    np.testing.assert_array_equal(
        kernel(),
        jnp.full([256], 2, dtype=jnp.int32),
    )

  def test_swizzled_blockspec_shapes(self):
    spec = plgpu.GPUBlockSpec(
        (128, 64),
        lambda *i: i,
        transforms=(
            plgpu.TilingTransform((64, 64)),
            plgpu.SwizzleTransform(128),
        ),
    )

    @functools.partial(
        pl.pallas_call,
        in_specs=[spec],
        out_specs=spec,
        out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16),
        grid=(2, 2),
    )
    def kernel(x_ref, o_ref):
      assert x_ref.shape == (128, 64), x_ref.shape
      o_ref[...] = x_ref[...]

    x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
    np.testing.assert_array_equal(kernel(x), x)

  def test_fori_loop_array(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def kernel(x_ref, o_ref):
      # Equivalent to x_ref[...] + 2 + 3.
      o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...])

    x = jnp.arange(256).astype(jnp.float32)
    np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0)

  def test_fori_loop_scalar(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
    )
    def kernel(o_ref):
      # Equivalent to 2 + 3.
      o_ref[...] = jax.lax.broadcast(
          jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0.0), o_ref.shape
      )

    np.testing.assert_array_equal(
        kernel(), jnp.full([256], 5.0, dtype=jnp.float32)
    )

  def test_fori_loop_indexed_store(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32),
    )
    def kernel(x_ref, y_ref, o_ref):
      def body(idx, _):
        o_ref[idx] = x_ref[idx] + y_ref[idx]
        return ()

      jax.lax.fori_loop(0, 4, body, ())

    x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
    y = x + 1
    np.testing.assert_array_equal(kernel(x, y), x + y)

  def test_cond(self):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    )
    def kernel(x_ref, o_ref):
      acc = x_ref[...].sum()
      jax.lax.cond(
          acc % 2 == 0,
          lambda: pl.debug_print("acc * 2: {}", acc * 2),
          lambda: pl.debug_print("acc: {}", acc),
      )
      o_ref[...] = jnp.broadcast_to(acc, o_ref.shape)

    x = jnp.arange(256)
    with jtu.capture_stdout() as output:
      jax.block_until_ready(kernel(x))

    self.assertIn("acc * 2:", output())

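  # The WGMMA tests below issue warpgroup matrix-multiply-accumulate
  # instructions into an accumulator allocated with plgpu.ACC.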
  @parameterized.parameters(jnp.float16, jnp.float32)
  def test_wgmma(self, dtype):
    # TensorCores can only fuse transposes of 16-bit values, and RHS
    # is expected to be column major by default.
    rhs_transpose = jnp.dtype(dtype).itemsize != 2
    swizzle = 128
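    # Number of elements of this dtype that fit in a 128-byte swizzle atom.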
    elems_128b = swizzle // jnp.dtype(dtype).itemsize

    def kernel(a_ref, b_ref, o_ref):
      if rhs_transpose:
        b_ref = plgpu.transpose_ref(b_ref, (1, 0))

      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref, b_ref)
        return acc_ref[...]

      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
    b_shape = (128, 192)
    if rhs_transpose:
      b_shape = b_shape[::-1]
    b = jax.random.uniform(key2, shape=b_shape, dtype=dtype)

    rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
    if rhs_transpose:
      rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),)
    res = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(
                (64, 128),
                lambda i, j: (i, j),
                transforms=(
                    plgpu.TilingTransform((64, elems_128b)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
            plgpu.GPUBlockSpec(
                b_shape,
                lambda *i: i,
                transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)),
            ),
        ],
        out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i),
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
        grid=(1, 1),
    )(a, b)
    np.testing.assert_allclose(
        res, a @ (b.T if rhs_transpose else b), rtol=1e-3
    )

  def test_wgmma_registers(self):
    def kernel(a_ref, b_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref[...], b_ref)
        return acc_ref[...]

      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
    b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16)

    transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
    res = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms),
            plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms),
        ],
        out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
    )(a, b)
    np.testing.assert_allclose(res, a @ b, rtol=1e-3)

  def test_wgmma_sliced_ref(self):
    def kernel(a_ref, b_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0])
        return acc_ref[...]

      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16)
    b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16)

    res = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(
                (2, 64, 128), lambda: (0, 0, 0),
                transforms=(
                    plgpu.TilingTransform((64, 64)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
            plgpu.GPUBlockSpec(
                (2, 128, 192), lambda: (0, 0, 0),
                transforms=(
                    plgpu.TilingTransform((64, 64)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
        ],
        out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
    )(a, b)
    np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3)

  def test_wgmma_sliced_acc(self):
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize

    def kernel(a_ref, b_ref, o_ref):
      def scope(acc_ref):
        plgpu.wgmma(acc_ref, a_ref, b_ref)
        return acc_ref[:, :64], acc_ref[:, 64:]

      o_ref[:, :64], o_ref[:, 64:] = pl.run_scoped(
          scope, plgpu.ACC((64, 128), jnp.float32)
      )

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
    b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16)
    res = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(
                (64, 128),
                lambda i, j: (i, j),
                transforms=(
                    plgpu.TilingTransform((64, elems_128b)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
            plgpu.GPUBlockSpec(
                (128, 128),
                lambda *i: i,
                transforms=(
                    plgpu.TilingTransform((elems_128b, elems_128b)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
        ],
        out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i),
        out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
        grid=(1, 1),
    )(a, b)
    np.testing.assert_allclose(res, a @ b, rtol=1e-3)

  def test_input_output_aliases(self):
    # Note that we're writing to the input pointer, which should alias b_ptr.
    def kernel(a_ref, b_ref):
      del b_ref
      a_ref[...] = jnp.ones_like(a_ref)

    a = np.zeros((64, 64), dtype=jnp.float32)
    b = pl.pallas_call(
        kernel,
        in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM),
        input_output_aliases={0: 0},
        out_shape=a,
    )(a)
    np.testing.assert_array_equal(b, np.ones_like(a))

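  # A pipelined matmul: the sequential k dimension is innermost, so the WGMMA
  # accumulator carries across k steps; delay_release=1 keeps the previous
  # step's buffers alive while the last unawaited WGMMA may still read them.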
  def test_realistic_matmul(self):
    dtype = jnp.float16
    swizzle = 128
    elems_128b = swizzle // jnp.dtype(dtype).itemsize
    grid_m, grid_k, grid_n = 132, 10, 4
    tile_m = tile_n = 128
    tile_k = elems_128b
    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n

    def kernel(a_ref, b_ref, o_ref, acc_ref):
      # Make sure tiling does not alter the shape of references.
      assert a_ref.shape == (tile_m, tile_k)
      assert b_ref.shape == (tile_k, tile_n)
      assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
      plgpu.wgmma(acc_ref, a_ref, b_ref)
      is_last_step = pl.program_id(2) == grid_k - 1

      @pl.when(is_last_step)
      def _epilogue():
        o_ref[...] = acc_ref[...].astype(dtype)

      plgpu.wgmma_wait(1)  # We don't await the last WGMMA, hence delay_release=1.

    key1, key2 = jax.random.split(jax.random.key(42), 2)
    a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
    b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)

    res = pl.pallas_call(
        kernel,
        in_specs=[
            plgpu.GPUBlockSpec(
                (tile_m, tile_k),
                lambda m, n, k: (m, k),
                transforms=(
                    plgpu.TilingTransform((64, elems_128b)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
            plgpu.GPUBlockSpec(
                (tile_k, tile_n),
                lambda m, n, k: (k, n),
                transforms=(
                    plgpu.TilingTransform((elems_128b, elems_128b)),
                    plgpu.SwizzleTransform(128),
                ),
            ),
        ],
        out_specs=plgpu.GPUBlockSpec(
            (tile_m, tile_n),
            lambda m, n, k: (m, n),
            transforms=(
                plgpu.TilingTransform((64, elems_128b)),
                plgpu.SwizzleTransform(128),
            ),
        ),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
        scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
        grid=(grid_m, grid_n, grid_k),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "parallel", "sequential"],
            max_concurrent_steps=2,
            delay_release=1,
        ),
    )(a, b)
    np.testing.assert_allclose(res, a @ b, rtol=1e-3)

  def test_slicing(self):
    left = upper = slice(None, 64)
    right = lower = slice(64, None)

    # We rotate the four quadrants of the input clockwise.
    def rotate(src, dst):
      dst[upper, left] = src[lower, left]
      dst[upper, right] = src[upper, left]
      dst[lower, right] = src[upper, right]
      dst[lower, left] = src[lower, right]

    x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
    spec = plgpu.GPUBlockSpec(
        (128, 128),
        lambda: (0, 0),
        transforms=(
            plgpu.TilingTransform((64, 64)),
            plgpu.SwizzleTransform(128),
        ),
    )
    f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec)
    expected = np.empty_like(x)
    rotate(x, expected)
    np.testing.assert_array_equal(f(x), expected)

  def test_layout_cast(self, shape=(256, 64)):
    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
    )
    def kernel(o_ref):
      o_ref[...] = plgpu.layout_cast(jnp.full(shape, 42.0), plgpu.Layout.WGMMA)

    x = jnp.full(shape, 42.0)
    np.testing.assert_array_equal(kernel(), x)


class PipelineTest(PallasTest):

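  # A hand-rolled double-buffered pipeline: GMEM->SMEM fetches run
  # max_concurrent_steps ahead of the compute step, synchronized via barriers.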
  def test_manual(self, max_concurrent_steps=2, num_steps=4):
    def kernel(x_gmem, o_gmem):
      return pl.run_scoped(
          functools.partial(scoped_kernel, x_gmem, o_gmem),
          plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32),
          plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32),
          plgpu.Barrier(1, num_barriers=max_concurrent_steps),
      )

    def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier):
      gmem_slice = pl.ds(pl.program_id(0) * 32, 32)

      def body(step, _):
        slot = step % max_concurrent_steps

        # Wait for the current GMEM->SMEM copy to complete.
        plgpu.barrier_wait(barrier.at[slot])
        # Wait for the previous output SMEM->GMEM copy to complete.
        plgpu.wait_smem_to_gmem(max_concurrent_steps - 1)

        # Only touch the current slot: other slots may have copies in flight.
        o_smem[slot] = x_smem[slot] + 1.0

        plgpu.copy_smem_to_gmem(
            o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)]
        )

        fetch_step = step + max_concurrent_steps
        fetch_slot = slot  # (x + y) % y == x % y
        jax.lax.cond(
            fetch_step < num_steps,
            lambda: plgpu.copy_gmem_to_smem(
                x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)],
                x_smem.at[fetch_slot],
                barrier=barrier.at[fetch_slot],
            ),
            lambda: None,
        )
        return ()

      # Initialize the pipeline.
      for slot in range(min(max_concurrent_steps, num_steps)):
        plgpu.copy_gmem_to_smem(
            x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)],
            x_smem.at[slot],
            barrier=barrier.at[slot],
        )

      jax.lax.fori_loop(0, num_steps, body, ())

      # Finalize the pipeline.
      plgpu.wait_smem_to_gmem(0)

    x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32)
    kernel_fn = pl.pallas_call(
        kernel,
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(4, 1),
    )
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)


class CoreMapTest(PallasTest):

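  # pl.core_map launches the body over a plgpu.GPUMesh; each of the
  # num_threads Pallas threads recovers its coordinates via jax.lax.axis_index.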
  def test_multiple_wg(self):
    mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",))

    @jax.jit
    def f():
      @pl.run_state
      def inner(y_ref):
        @pl.core_map(mesh)
        def kernel():
          wg_idx = jax.lax.axis_index("y")
          y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))

      y_init = jnp.zeros((2, 128), np.int32)
      return inner(y_init)

    np.testing.assert_array_equal(
        f(), np.repeat(np.arange(2), 128).reshape(2, 128)
    )

  def test_multiple_wg_with_grid(self):
    mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg"))

    @jax.jit
    def f():
      @pl.run_state
      def inner(y_ref):
        @pl.core_map(mesh)
        def kernel():
          xy_idx = jax.lax.axis_index(("x", "y"))
          yx_idx = jax.lax.axis_index(("y", "x"))
          wg_idx = jax.lax.axis_index("wg")
          num_wgs = jax.lax.psum(1, "wg")
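          # Indexing rows by xy_idx while storing yx_idx-derived values
          # permutes the rows: ("x", "y") vs ("y", "x") linearization.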
          y_ref[xy_idx, wg_idx] = jnp.broadcast_to(
              yx_idx * num_wgs + wg_idx, (128,)
          )

      y_init = jnp.zeros((4, 2, 128), np.int32)
      return inner(y_init)

    np.testing.assert_array_equal(
        f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128)
    )


if __name__ == "__main__":
  absltest.main()