mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[mosaic:gpu] Minor cleanup in FragmentedArray.transfer_tile
.
- Remove redundant line. - Use `ConstantOp.create_index`. - Use `BoolAttr`. PiperOrigin-RevId: 638616982
This commit is contained in:
parent
cfe64cd5ce
commit
a5fc31e425
@ -25,10 +25,11 @@ from jaxlib.mlir.dialects import math as mlir_math
|
||||
from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.dialects import vector
|
||||
from jaxlib.mlir.extras import types
|
||||
import numpy as np
|
||||
|
||||
from . import utils
|
||||
from . import dsl as mgpu
|
||||
from . import utils
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
@ -596,18 +597,14 @@ class FragmentedArray:
|
||||
@staticmethod
|
||||
def transfer_tiled(shape, dtype, swizzle: int | None):
|
||||
bw = mgpu.bytewidth(dtype)
|
||||
cols_per_tile = 128 // bw
|
||||
m, n = shape
|
||||
if n % 32 != 0:
|
||||
raise NotImplementedError
|
||||
cols_per_tile = 128 // bw
|
||||
if swizzle != 128:
|
||||
raise NotImplementedError("Only 128B swizzle supported")
|
||||
index = ir.IndexType.get()
|
||||
|
||||
def c(x):
|
||||
return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
|
||||
|
||||
c = arith.ConstantOp.create_index
|
||||
tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE))
|
||||
lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
|
||||
warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
|
||||
@ -618,8 +615,7 @@ class FragmentedArray:
|
||||
)
|
||||
else:
|
||||
# We rely on canonicalization to clean up the selects.
|
||||
i1 = ir.IntegerType.get_signless(1)
|
||||
is_even_row = arith.constant(i1, ir.IntegerAttr.get(i1, 1))
|
||||
is_even_row = arith.constant(types.bool(), ir.BoolAttr.get(True))
|
||||
row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16)))
|
||||
col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
|
||||
# The swizzle pattern is constant for a given thread.
|
||||
|
Loading…
x
Reference in New Issue
Block a user