mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mgpu] FragentedArray.foreach() can now optionally return a new array
PiperOrigin-RevId: 700708119
This commit is contained in:
parent
03b6945ee7
commit
f3acfa93bb
@ -1243,15 +1243,29 @@ class FragmentedArray:
|
||||
lambda t, p, f: arith.select(p, t, f), self, on_false,
|
||||
)
|
||||
|
||||
def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]):
|
||||
def foreach(
|
||||
self,
|
||||
fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None],
|
||||
*,
|
||||
create_array=False,
|
||||
is_signed=None,
|
||||
):
|
||||
"""Call a function for each value and index."""
|
||||
index = ir.IndexType.get()
|
||||
for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True):
|
||||
assert len(idx) == len(self.shape), (idx, self.shape)
|
||||
new_regs = None
|
||||
if create_array:
|
||||
new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type))
|
||||
for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True):
|
||||
reg = self.registers[reg_idx]
|
||||
assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape)
|
||||
[elems] = ir.VectorType(reg.type).shape
|
||||
for i in range(elems):
|
||||
i = c(i, index)
|
||||
fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i)))
|
||||
val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i)))
|
||||
if create_array:
|
||||
new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i)
|
||||
|
||||
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
|
||||
|
||||
def store_untiled(self, ref: ir.Value):
|
||||
if not ir.MemRefType.isinstance(ref.type):
|
||||
|
@ -1361,6 +1361,39 @@ class FragmentedArrayTest(TestCase):
|
||||
rhs = rhs = 0 if rhs_is_literal else iota + 1
|
||||
np.testing.assert_array_equal(result, op(iota, rhs))
|
||||
|
||||
def test_foreach(self):
|
||||
dtype = jnp.int32
|
||||
swizzle = 128
|
||||
tile = 64, swizzle // jnp.dtype(dtype).itemsize
|
||||
shape = 128, 192
|
||||
tiled_shape = mgpu.tile_shape(shape, tile)
|
||||
mlir_dtype = utils.dtype_to_ir_type(dtype)
|
||||
cst = 9999
|
||||
def causal(val, idx):
|
||||
row, col = idx
|
||||
mask = arith.cmpi(arith.CmpIPredicate.uge, row, col)
|
||||
return arith.select(mask, val, c(cst, mlir_dtype))
|
||||
|
||||
tiling = mgpu.TileTransform(tile)
|
||||
def kernel(ctx, dst, smem):
|
||||
x = iota_tensor(shape[0], shape[1], dtype)
|
||||
x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem)
|
||||
mgpu.commit_shared()
|
||||
ctx.async_copy(src_ref=smem, dst_ref=dst)
|
||||
ctx.await_async_copy(0)
|
||||
|
||||
iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape)
|
||||
result = mgpu.as_gpu_kernel(
|
||||
kernel,
|
||||
(1, 1, 1),
|
||||
(128, 1, 1),
|
||||
(),
|
||||
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
|
||||
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
|
||||
)()
|
||||
expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
@parameterized.product(
|
||||
op=[operator.and_, operator.or_, operator.xor],
|
||||
dtype=[jnp.uint32],
|
||||
|
Loading…
x
Reference in New Issue
Block a user