mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mgpu] Debug print for mlir vectors.
PiperOrigin-RevId: 700714031
This commit is contained in:
parent
d449f12a2e
commit
df8ecb971a
@ -107,28 +107,46 @@ def c(val: int | float, ty):
|
||||
raise NotImplementedError(ty)
|
||||
return arith.constant(ty, attr)
|
||||
|
||||
def _debug_scalar_ty_format(arg):
|
||||
ty_format = None
|
||||
if ir.IndexType.isinstance(arg.type):
|
||||
return "%llu"
|
||||
if ir.IntegerType.isinstance(arg.type):
|
||||
width = ir.IntegerType(arg.type).width
|
||||
ty_format = "%llu"
|
||||
if width < 64:
|
||||
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
|
||||
if ir.F32Type.isinstance(arg.type):
|
||||
ty_format = "%f"
|
||||
if ir.F16Type.isinstance(arg.type):
|
||||
ty_format = "%f"
|
||||
arg = arith.extf(ir.F32Type.get(), arg)
|
||||
|
||||
return ty_format, arg
|
||||
|
||||
def debug_print(fmt, *args, uniform=True):
|
||||
type_formats = []
|
||||
new_args = []
|
||||
for arg in args:
|
||||
ty_format = None
|
||||
if ir.IndexType.isinstance(arg.type):
|
||||
ty_format = "%llu"
|
||||
if ir.IntegerType.isinstance(arg.type):
|
||||
width = ir.IntegerType(arg.type).width
|
||||
ty_format = "%llu"
|
||||
if width < 64:
|
||||
arg = arith.extui(ir.IntegerType.get_signless(64), arg)
|
||||
if ir.F32Type.isinstance(arg.type):
|
||||
ty_format = "%f"
|
||||
if ir.F16Type.isinstance(arg.type):
|
||||
ty_format = "%f"
|
||||
arg = arith.extf(ir.F32Type.get(), arg)
|
||||
if ir.VectorType.isinstance(arg.type):
|
||||
index = ir.IndexType.get()
|
||||
vec_ty = ir.VectorType(arg.type)
|
||||
if len(vec_ty.shape) > 1:
|
||||
raise NotImplementedError(vec_ty)
|
||||
vec_args = [
|
||||
vector.extractelement(arg, position=c(i, index))
|
||||
for i in range(vec_ty.shape[0])
|
||||
]
|
||||
ty_formats, args = zip(*map(_debug_scalar_ty_format,vec_args))
|
||||
ty_format = f"[{','.join(ty_formats)}]"
|
||||
new_args += args
|
||||
else:
|
||||
ty_format, arg = _debug_scalar_ty_format(arg)
|
||||
new_args.append(arg)
|
||||
|
||||
if ty_format is None:
|
||||
raise NotImplementedError(arg.type)
|
||||
type_formats.append(ty_format)
|
||||
new_args.append(arg)
|
||||
ctx = (
|
||||
functools.partial(single_thread, per_block=False)
|
||||
if uniform
|
||||
|
Loading…
x
Reference in New Issue
Block a user