Added another boxDim check to mosaic_gpu_init_tma_desc

PiperOrigin-RevId: 660314586
This commit is contained in:
Sergei Lebedev 2024-08-07 03:15:56 -07:00 committed by jax authors
parent 803453ed74
commit 28ca734d9b
3 changed files with 35 additions and 0 deletions

View File

@ -442,6 +442,16 @@ class LaunchContext:
rank = len(slice_shape)
if rank > 5: # TODO: apaszke - Implement stride compression
raise ValueError("Async copies only support striding up to 5 dimensions")
if max(slice_shape) > 256:
raise ValueError(
"Async copies only support copying <=256 elements along each"
" dimension"
)
if (zeroth_bw := slice_shape[-1] * element_bytewidth) % 16 != 0:
raise ValueError(
"Async copies require the number of bytes copied along the last"
f" dimension to be divisible by 16, but got {zeroth_bw}"
)
if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth:
raise ValueError(
f"Async copies with {swizzle=} require last dimension of the slice to"

View File

@ -88,6 +88,13 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
tma_window_shape_i, rank - i - 1);
abort();
}
if (i == 0 && (tma_window_shape_i * elem_bytewidth) % 16 != 0) {
fprintf(stderr,
"The last dimension of window shape must have a bytewidth "
"divisible by 16, but got %d*%ld at index %ld\n",
tma_window_shape_i, elem_bytewidth, rank - i - 1);
abort();
}
tma_window_shape[i] = tma_window_shape_i;
}
cuuint32_t element_strides[5] = {1, 1, 1, 1, 1};

View File

@ -961,6 +961,24 @@ class TMATest(TestCase):
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
np.testing.assert_array_equal(y, x)
def test_tma_invalid(self):
  """Checks that unsupported async copy shapes are rejected with ValueError.

  Each case builds an input of the given shape, launches a kernel that
  issues an async copy over it, and expects the error message to match
  the paired pattern.
  """

  def kernel(ctx, src, dst, tmp):
    copy(src, tmp)
    ctx.async_copy(src_ref=tmp, dst_ref=dst)
    ctx.await_async_copy(0)

  def launch_with_shape(shape):
    # The same array serves as the input, output and scratch description.
    x = np.arange(np.prod(shape)).reshape(shape)
    compiled = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)
    compiled(x)

  # (shape, expected error pattern) pairs, checked in order.
  rejected_cases = (
      ([1] * 6, "only support striding up to 5"),
      ([23], "last dimension to be divisible by 16"),
  )
  for bad_shape, error_pattern in rejected_cases:
    with self.assertRaisesRegex(ValueError, error_pattern):
      launch_with_shape(bad_shape)
class FragmentedArrayTest(TestCase):