[mgpu] FragentedArray.foreach() can now optionally return a new array

PiperOrigin-RevId: 700708119
This commit is contained in:
Christos Perivolaropoulos 2024-11-27 08:20:11 -08:00 committed by jax authors
parent 03b6945ee7
commit f3acfa93bb
2 changed files with 51 additions and 4 deletions

View File

@ -1243,15 +1243,29 @@ class FragmentedArray:
lambda t, p, f: arith.select(p, t, f), self, on_false,
)
def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]):
def foreach(
self,
fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None],
*,
create_array=False,
is_signed=None,
):
"""Call a function for each value and index."""
index = ir.IndexType.get()
for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True):
assert len(idx) == len(self.shape), (idx, self.shape)
new_regs = None
if create_array:
new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type))
for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True):
reg = self.registers[reg_idx]
assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape)
[elems] = ir.VectorType(reg.type).shape
for i in range(elems):
i = c(i, index)
fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i)))
val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i)))
if create_array:
new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i)
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
def store_untiled(self, ref: ir.Value):
if not ir.MemRefType.isinstance(ref.type):

View File

@ -1361,6 +1361,39 @@ class FragmentedArrayTest(TestCase):
rhs = rhs = 0 if rhs_is_literal else iota + 1
np.testing.assert_array_equal(result, op(iota, rhs))
def test_foreach(self):
dtype = jnp.int32
swizzle = 128
tile = 64, swizzle // jnp.dtype(dtype).itemsize
shape = 128, 192
tiled_shape = mgpu.tile_shape(shape, tile)
mlir_dtype = utils.dtype_to_ir_type(dtype)
cst = 9999
def causal(val, idx):
row, col = idx
mask = arith.cmpi(arith.CmpIPredicate.uge, row, col)
return arith.select(mask, val, c(cst, mlir_dtype))
tiling = mgpu.TileTransform(tile)
def kernel(ctx, dst, smem):
x = iota_tensor(shape[0], shape[1], dtype)
x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem, dst_ref=dst)
ctx.await_async_copy(0)
iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape)
result = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
(),
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
)()
expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst
np.testing.assert_array_equal(result, expected)
@parameterized.product(
op=[operator.and_, operator.or_, operator.xor],
dtype=[jnp.uint32],