[mosaic:gpu] Minor cleanup in FragmentedArray.transfer_tile.

- Remove redundant line.
- Use `ConstantOp.create_index`.
- Use `BoolAttr`.

PiperOrigin-RevId: 638616982
This commit is contained in:
Chris Jones 2024-05-30 05:23:48 -07:00 committed by jax authors
parent cfe64cd5ce
commit a5fc31e425

View File

@ -25,10 +25,11 @@ from jaxlib.mlir.dialects import math as mlir_math
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
from jaxlib.mlir.extras import types
import numpy as np
from . import utils
from . import dsl as mgpu
from . import utils
# mypy: ignore-errors
@ -596,18 +597,14 @@ class FragmentedArray:
@staticmethod
def transfer_tiled(shape, dtype, swizzle: int | None):
bw = mgpu.bytewidth(dtype)
cols_per_tile = 128 // bw
m, n = shape
if n % 32 != 0:
raise NotImplementedError
cols_per_tile = 128 // bw
if swizzle != 128:
raise NotImplementedError("Only 128B swizzle supported")
index = ir.IndexType.get()
def c(x):
return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
c = arith.ConstantOp.create_index
tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE))
lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
@ -618,8 +615,7 @@ class FragmentedArray:
)
else:
# We rely on canonicalization to clean up the selects.
i1 = ir.IntegerType.get_signless(1)
is_even_row = arith.constant(i1, ir.IntegerAttr.get(i1, 1))
is_even_row = arith.constant(types.bool(), ir.BoolAttr.get(True))
row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16)))
col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
# The swizzle pattern is constant for a given thread.