[mgpu] Debug print for mlir vectors.

PiperOrigin-RevId: 700714031
This commit is contained in:
Christos Perivolaropoulos 2024-11-27 08:44:13 -08:00 committed by jax authors
parent d449f12a2e
commit df8ecb971a

View File

@ -107,28 +107,46 @@ def c(val: int | float, ty):
raise NotImplementedError(ty)
return arith.constant(ty, attr)
def _debug_scalar_ty_format(arg):
ty_format = None
if ir.IndexType.isinstance(arg.type):
return "%llu"
if ir.IntegerType.isinstance(arg.type):
width = ir.IntegerType(arg.type).width
ty_format = "%llu"
if width < 64:
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
if ir.F32Type.isinstance(arg.type):
ty_format = "%f"
if ir.F16Type.isinstance(arg.type):
ty_format = "%f"
arg = arith.extf(ir.F32Type.get(), arg)
return ty_format, arg
def debug_print(fmt, *args, uniform=True):
type_formats = []
new_args = []
for arg in args:
ty_format = None
if ir.IndexType.isinstance(arg.type):
ty_format = "%llu"
if ir.IntegerType.isinstance(arg.type):
width = ir.IntegerType(arg.type).width
ty_format = "%llu"
if width < 64:
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
if ir.F32Type.isinstance(arg.type):
ty_format = "%f"
if ir.F16Type.isinstance(arg.type):
ty_format = "%f"
arg = arith.extf(ir.F32Type.get(), arg)
if ir.VectorType.isinstance(arg.type):
index = ir.IndexType.get()
vec_ty = ir.VectorType(arg.type)
if len(vec_ty.shape) > 1:
raise NotImplementedError(vec_ty)
vec_args = [
vector.extractelement(arg, position=c(i, index))
for i in range(vec_ty.shape[0])
]
ty_formats, args = zip(*map(_debug_scalar_ty_format,vec_args))
ty_format = f"[{','.join(ty_formats)}]"
new_args += args
else:
ty_format, arg = _debug_scalar_ty_format(arg)
new_args.append(arg)
if ty_format is None:
raise NotImplementedError(arg.type)
type_formats.append(ty_format)
new_args.append(arg)
ctx = (
functools.partial(single_thread, per_block=False)
if uniform