[Pallas TPU] Add vector support to pl.debug_print

PiperOrigin-RevId: 715085454
Authored by Ayaka on 2025-01-13 13:21:45 -08:00; committed by jax authors
parent f69592ae78
commit 9ba1fd2801
7 changed files with 142 additions and 25 deletions

View File

@@ -11,6 +11,12 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
 Remember to align the itemized text with the first line of an item within a list.
 -->
 
+## Released with jax 0.5.0
+
+* New functionality
+
+  * Added vector support for {func}`jax.experimental.pallas.debug_print` on TPU.
+
 ## Released with jax 0.4.37
 
 * New functionality

View File

@@ -3140,29 +3140,71 @@ lowering_rules[tpu_primitives.delay_p] = _delay_rule
 def _debug_print_rule(
     ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
 ):
-  if any(aval.shape for aval in ctx.avals_in):
-    raise NotImplementedError("Only scalar values are supported")
-
-  primitives.check_debug_print_format(fmt, *args)
-  if has_placeholders:
-    if not all(
-        isinstance(arg.type, ir.IntegerType) and arg.type.width == 32
-        for arg in args
-    ):
-      raise TypeError(
-          "All arguments must be 32-bit integers when using"
-          " placeholders (`{...}`). If you need to print values of other types,"
-          " remove placeholders from the format string."
-      )
-
-    # TPU expects $0, $1 etc as placeholders.
-    tpu_fmt = "".join(
-        f"{text}${idx}"
-        for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
-    )
-  else:
-    tpu_fmt = fmt
-  tpu.log(args, tpu_fmt, formatted=has_placeholders)
+  is_scalar_inputs = [aval.shape == () for aval in ctx.avals_in]
+  is_all_scalars = all(is_scalar_inputs)
+  is_single_vector = len(is_scalar_inputs) == 1 and not is_scalar_inputs[0]
+  if not (is_all_scalars or is_single_vector):
+    raise ValueError(
+        "All inputs to debug_print must be all scalars or a single vector, but"
+        f" got {ctx.avals_in}"
+    )
+
+  # Scalar case.
+  if is_all_scalars:
+    primitives.check_debug_print_format(fmt, *args)
+    if has_placeholders:
+      if not all(
+          isinstance(arg.type, ir.IntegerType) and arg.type.width == 32
+          for arg in args
+      ):
+        raise TypeError(
+            "All arguments must be 32-bit integers when using"
+            " placeholders (`{...}`). If you need to print values of other types,"
+            " remove placeholders from the format string."
+        )
+
+      # TPU expects $0, $1 etc as placeholders.
+      fmt = "".join(
+          f"{text}${idx}"
+          for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
+      )
+
+    tpu.log(args, fmt, formatted=has_placeholders)
+    return ()
+
+  # Vector case.
+  # Copy the array to vmem for logging.
+  # Note that the shape of the array must be explicitly provided here. This is
+  # because the underlying implementation aligns shapes to tile boundaries,
+  # potentially altering the original shape and making it unrecoverable.
+  if len(ctx.avals_in) != 1:
+    raise ValueError(
+        "Only one vector input to debug_print is supported."
+    )
+  (aval,) = ctx.avals_in
+  (arg,) = args
+
+  if not has_placeholders or not fmt.endswith("{}"):
+    raise ValueError("For vector input, the format string must end with {}.")
+  fmt = fmt[:-2]
+
+  region = tpu.RegionOp(())
+  with ir.InsertionPoint(region.body):
+    element_type = _dtype_to_ir_type(aval.dtype)
+    ref_type = ir.MemRefType.get(
+        aval.shape,
+        element_type,
+        memory_space=ir.Attribute.parse("#tpu.memory_space<vmem>"),
+    )
+    ref = memref.alloca(ref_type, [], [])
+
+    index_type = ir.IndexType.get()
+    zero = arith.constant(index_type, 0)
+    indices = [zero] * len(aval.shape)
+    vector.store(arg, ref, indices)
+    tpu.log_buffer(ref, aval.shape, fmt)
+    tpu.yield_([])
   return ()
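As an aside, the $-placeholder rewrite in the scalar path is easy to check in isolation: string.Formatter().parse splits a format string into (literal_text, field_name, format_spec, conversion) tuples, one per placeholder. Below is a minimal, runnable sketch of the same transformation; the helper name to_tpu_placeholders is ours, not part of the change.

import string

def to_tpu_placeholders(fmt: str) -> str:
  # Mirror the join in _debug_print_rule: each "{}" placeholder becomes
  # "$0", "$1", ... which is the form the TPU log op expects.
  return "".join(
      f"{text}${idx}"
      for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
  )

print(to_tpu_placeholders("x = {}, y = {}"))  # prints: x = $0, y = $1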

View File

@@ -732,9 +732,12 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike):
       * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
         contain a placeholder for each value to be printed. Format specs and
         conversions are not supported. All values must be scalars.
-      * In TPU, if ``fmt`` contains placeholders, all values must be 32-bit
-        integers. If there are no placeholders, the values are printed after
-        the format string. All values must be scalars.
+      * On TPU, if all inputs are scalars: If ``fmt`` contains placeholders,
+        all values must be 32-bit integers. If there are no placeholders, the
+        values are printed after the format string.
+      * On TPU, if the input is a single vector, the vector is printed after
+        the format string. The format string must end with a single placeholder
+        ``{}``.
     *args: The values to print.
   """  # fmt: skip
   has_placeholders = False
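For illustration, a minimal kernel exercising the new vector path might look like the sketch below; it mirrors the test added by this commit, with an arbitrary shape and dtype. On TPU the printed output goes to the log recorder rather than stdout, which the test enables via the xla_tpu_enable_log_recorder compile option.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # Vector input: exactly one array argument, and the format string
  # must end with a single "{}" placeholder.
  pl.debug_print("x_ref = {}", x_ref[...])
  o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(8, 128)
y = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x)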

View File

@@ -781,6 +781,17 @@ def TPU_LogOp : TPU_Op<"log"> {
   let hasVerifier = 1;
 }
 
+def TPU_LogBufferOp : TPU_Op<"log_buffer"> {
+  let arguments = (ins
+    AnyMemRef:$input,
+    DenseI64ArrayAttr:$shape,
+    StrAttr:$tag
+  );
+  let results = (outs);
+  let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }];
+  let hasVerifier = 1;
+}
+
 def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> {
   let dependentDialects = [
     "::mlir::func::FuncDialect",

View File

@@ -1134,6 +1134,15 @@ LogicalResult WeirdOp::verify() {
   return success();
 }
 
+LogicalResult LogBufferOp::verify() {
+  const MemRefType input_type = getInput().getType();
+  if (input_type.getRank() != getShape().size()) {
+    return emitOpError(
+        "Shape must have the same length as the rank of the input");
+  }
+  return success();
+}
+
 void PackSubelementsOp::build(OpBuilder &builder, OperationState &state,
                               const VectorType output_type,
                               const ArrayRef<Value> padded_sources,

View File

@@ -1339,10 +1339,11 @@ class OpsTest(PallasBaseTest):
       "plgpu.TritonCompilerParams unavailable on Windows",
   )
   def test_debug_print(self):
+    if jtu.test_device_matches(["tpu"]):
+      self.skipTest("Test for TPU is covered in tpu_pallas_test.py")
     if config.use_shardy_partitioner.value:
       self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
-    if jtu.test_device_matches(["tpu"]):
-      self.skipTest("Not supported on TPU")
     # TODO: this test flakes on gpu
     if jtu.test_device_matches(["gpu"]):
@@ -1369,7 +1370,7 @@ class OpsTest(PallasBaseTest):
   )
   def test_debug_print_with_values(self):
     if jtu.test_device_matches(["tpu"]):
-      self.skipTest("Not supported on TPU")
+      self.skipTest("Test for TPU is covered in tpu_pallas_test.py")
     # TODO: this test flakes on gpu
     if jtu.test_device_matches(["gpu"]):

View File

@@ -2114,6 +2114,51 @@ class PallasCallPrintTest(PallasBaseTest):
       jax.block_until_ready(compiled_kernel(x))
     self.assertIn('x[0] == 42', get_output())
 
+  @parameterized.named_parameters(
+      (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype)
+      for shape in (
+          (2, 8, 128),
+          # test unaligned shapes
+          (3,),
+          (3, 4),
+          (2, 3, 4),
+          (2, 9, 129),
+      )
+      for dtype in (jnp.int32, jnp.uint32, jnp.float32)
+  )
+  def test_debug_print_vector(self, shape, dtype):
+    # TODO(ayx): Remove after this date.
+    if not jtu.if_cloud_tpu_at_least(2025, 1, 16):
+      self.skipTest("Requires libtpu built after 2025-01-16")
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct(shape, dtype),
+    )
+    def kernel(x_ref, o_ref):
+      pl.debug_print("{}", x_ref[...])
+      o_ref[...] = x_ref[...]
+
+    n = np.prod(shape)
+    x = jnp.arange(n, dtype=dtype).reshape(shape)
+    compiled_kernel = (
+        jax.jit(kernel)
+        .lower(x)
+        .compile({"xla_tpu_enable_log_recorder": "true"})
+    )
+    with jtu.capture_stderr() as get_output:
+      jax.block_until_ready(compiled_kernel(x))
+    output = get_output()
+
+    numbers = [
+        int(num)
+        for line in output.splitlines()
+        if (match := re.search(r"\{(.*)", line))  # extract contents after `{`
+        for num in re.findall(r"\d+", match.group(1))
+    ]
+    # Check if the numbers in the output match the values generated by `arange`.
+    self.assertLen(numbers, n)
+    self.assertTrue(all(num == i for i, num in enumerate(numbers)))
+
 
 class PallasCallTraceTest(PallasBaseTest):
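A note on the extraction in test_debug_print_vector: it keys off the `{` that the format string contributes to each emitted log line, then collects every run of digits after it. A standalone illustration on a hypothetical recorder line follows; the exact log format shown here is an assumption, not taken from the commit.

import re

line = "I0113 13:21:45 log_recorder: {0 1 2 3}"  # hypothetical log line
if match := re.search(r"\{(.*)", line):  # contents after `{`
  print([int(num) for num in re.findall(r"\d+", match.group(1))])
# prints: [0, 1, 2, 3]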