mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Added another boxDim check to mosaic_gpu_init_tma_desc
PiperOrigin-RevId: 660314586
This commit is contained in:
parent
803453ed74
commit
28ca734d9b
@ -442,6 +442,16 @@ class LaunchContext:
|
||||
rank = len(slice_shape)
|
||||
if rank > 5: # TODO: apaszke - Implement stride compression
|
||||
raise ValueError("Async copies only support striding up to 5 dimensions")
|
||||
if max(slice_shape) > 256:
|
||||
raise ValueError(
|
||||
"Async copies only support copying <=256 elements along each"
|
||||
" dimension"
|
||||
)
|
||||
if (zeroth_bw := slice_shape[-1] * element_bytewidth) % 16 != 0:
|
||||
raise ValueError(
|
||||
"Async copies require the number of bytes copied along the last"
|
||||
f" dimension to be divisible by 16, but got {zeroth_bw}"
|
||||
)
|
||||
if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth:
|
||||
raise ValueError(
|
||||
f"Async copies with {swizzle=} require last dimension of the slice to"
|
||||
|
@ -88,6 +88,13 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
|
||||
tma_window_shape_i, rank - i - 1);
|
||||
abort();
|
||||
}
|
||||
if (i == 0 && (tma_window_shape_i * elem_bytewidth) % 16 != 0) {
|
||||
fprintf(stderr,
|
||||
"The last dimension of window shape must have a bytewidth "
|
||||
"divisible by 16, but got %d*%ld at index %ld\n",
|
||||
tma_window_shape_i, elem_bytewidth, rank - i - 1);
|
||||
abort();
|
||||
}
|
||||
tma_window_shape[i] = tma_window_shape_i;
|
||||
}
|
||||
cuuint32_t element_strides[5] = {1, 1, 1, 1, 1};
|
||||
|
@ -961,6 +961,24 @@ class TMATest(TestCase):
|
||||
y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_tma_invalid(self):
|
||||
def kernel(ctx, src, dst, tmp):
|
||||
copy(src, tmp)
|
||||
ctx.async_copy(src_ref=tmp, dst_ref=dst)
|
||||
ctx.await_async_copy(0)
|
||||
|
||||
def run_kernel(shape):
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
_ = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "only support striding up to 5"):
|
||||
run_kernel([1] * 6)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "last dimension to be divisible by 16"
|
||||
):
|
||||
run_kernel([23])
|
||||
|
||||
|
||||
class FragmentedArrayTest(TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user