diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 6f3f4abc8..12ca2d74a 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -25,10 +25,11 @@ from jaxlib.mlir.dialects import math as mlir_math from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import vector +from jaxlib.mlir.extras import types import numpy as np -from . import utils from . import dsl as mgpu +from . import utils # mypy: ignore-errors @@ -596,18 +597,14 @@ class FragmentedArray: @staticmethod def transfer_tiled(shape, dtype, swizzle: int | None): bw = mgpu.bytewidth(dtype) - cols_per_tile = 128 // bw m, n = shape if n % 32 != 0: raise NotImplementedError cols_per_tile = 128 // bw if swizzle != 128: raise NotImplementedError("Only 128B swizzle supported") - index = ir.IndexType.get() - - def c(x): - return arith.ConstantOp(index, ir.IntegerAttr.get(index, x)) + c = arith.ConstantOp.create_index tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE)) lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31} warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3} @@ -618,8 +615,7 @@ class FragmentedArray: ) else: # We rely on canonicalization to clean up the selects. - i1 = ir.IntegerType.get_signless(1) - is_even_row = arith.constant(i1, ir.IntegerAttr.get(i1, 1)) + is_even_row = arith.constant(types.bool(), ir.BoolAttr.get(True)) row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16))) col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6} # The swizzle pattern is constant for a given thread.