[Pallas TPU] Add vector support to pl.debug_print
PiperOrigin-RevId: 715085454
commit 9ba1fd2801
parent f69592ae78
@@ -11,6 +11,12 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
 Remember to align the itemized text with the first line of an item within a list.
 -->
 
+## Released with jax 0.5.0
+
+* New functionality
+
+  * Added vector support for {func}`jax.experimental.pallas.debug_print` on TPU.
+
 ## Released with jax 0.4.37
 
 * New functionality
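A minimal usage sketch of the vector support noted in the 0.5.0 entry above. The kernel, shapes, and the `pallas_call` wrapper here are illustrative assumptions rather than code from this commit; on TPU the printed output is emitted through the log recorder, as exercised by the test added at the end of this change.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # New on TPU: print a whole vector; the format string must end with "{}".
  pl.debug_print("x = {}", x_ref[...])
  o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(8, 128)
y = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
```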
@@ -3140,9 +3140,17 @@ lowering_rules[tpu_primitives.delay_p] = _delay_rule
 def _debug_print_rule(
     ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
 ):
-  if any(aval.shape for aval in ctx.avals_in):
-    raise NotImplementedError("Only scalar values are supported")
+  is_scalar_inputs = [aval.shape == () for aval in ctx.avals_in]
+  is_all_scalars = all(is_scalar_inputs)
+  is_single_vector = len(is_scalar_inputs) == 1 and not is_scalar_inputs[0]
+  if not (is_all_scalars or is_single_vector):
+    raise ValueError(
+        "All inputs to debug_print must be all scalars or a single vector, but"
+        f" got {ctx.avals_in}"
+    )
+
+  # Scalar case.
+  if is_all_scalars:
     primitives.check_debug_print_format(fmt, *args)
     if has_placeholders:
       if not all(
@@ -3156,13 +3164,47 @@ def _debug_print_rule(
         )
 
       # TPU expects $0, $1 etc as placeholders.
-      tpu_fmt = "".join(
+      fmt = "".join(
           f"{text}${idx}"
           for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt))
       )
-  else:
-    tpu_fmt = fmt
-  tpu.log(args, tpu_fmt, formatted=has_placeholders)
+
+    tpu.log(args, fmt, formatted=has_placeholders)
+    return ()
+
+  # Vector case.
+  # Copy the array to vmem for logging.
+  # Note that the shape of the array must be explicitly provided here. This is
+  # because the underlying implementation aligns shapes to tile boundaries,
+  # potentially altering the original shape and making it unrecoverable.
+  if len(ctx.avals_in) != 1:
+    raise ValueError(
+        "Only one vector input to debug_print is supported."
+    )
+  (aval,) = ctx.avals_in
+  (arg,) = args
+
+  if not has_placeholders or not fmt.endswith("{}"):
+    raise ValueError("For vector input, the format string must end with {}.")
+
+  fmt = fmt[:-2]
+
+  region = tpu.RegionOp(())
+  with ir.InsertionPoint(region.body):
+    element_type = _dtype_to_ir_type(aval.dtype)
+    ref_type = ir.MemRefType.get(
+        aval.shape,
+        element_type,
+        memory_space=ir.Attribute.parse("#tpu.memory_space<vmem>"),
+    )
+    ref = memref.alloca(ref_type, [], [])
+
+    index_type = ir.IndexType.get()
+    zero = arith.constant(index_type, 0)
+    indices = [zero] * len(aval.shape)
+    vector.store(arg, ref, indices)
+    tpu.log_buffer(ref, aval.shape, fmt)
+    tpu.yield_([])
   return ()
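As a reading aid for `_debug_print_rule` above, a hedged sketch of the argument combinations it accepts and rejects; the kernel and shapes are illustrative assumptions, not part of this commit.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  n = x_ref[0, 0]                        # a 32-bit integer scalar
  pl.debug_print("n = {}", n)            # all-scalar arguments: accepted
  pl.debug_print("x = {}", x_ref[...])   # a single vector argument: accepted
  # pl.debug_print("n = {}, x = {}", n, x_ref[...])  # scalar mixed with a vector: ValueError
  # pl.debug_print("{} done", x_ref[...])            # vector fmt not ending in "{}": ValueError
  o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(8, 128)
pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
```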
@@ -732,9 +732,12 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike):
       * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
         contain a placeholder for each value to be printed. Format specs and
         conversions are not supported. All values must be scalars.
-      * In TPU, if ``fmt`` contains placeholders, all values must be 32-bit
-        integers. If there are no placeholders, the values are printed after
-        the format string. All values must be scalars.
+      * On TPU, if all inputs are scalars: If ``fmt`` contains placeholders,
+        all values must be 32-bit integers. If there are no placeholders, the
+        values are printed after the format string.
+      * On TPU, if the input is a single vector, the vector is printed after
+        the format string. The format string must end with a single placeholder
+        ``{}``.
     *args: The values to print.
   """ # fmt: skip
   has_placeholders = False
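A short sketch of the two scalar forms described in the docstring above, under the same illustrative assumptions as the earlier examples (this kernel is not part of the commit).

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  i = x_ref[0, 0]  # an int32 scalar
  # With a placeholder, every scalar argument must be a 32-bit integer.
  pl.debug_print("i = {}", i)
  # Without placeholders, the values are printed after the format string.
  pl.debug_print("value of i:", i)
  o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(8, 128)
pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
```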
@@ -781,6 +781,17 @@ def TPU_LogOp : TPU_Op<"log"> {
   let hasVerifier = 1;
 }
 
+def TPU_LogBufferOp : TPU_Op<"log_buffer"> {
+  let arguments = (ins
+    AnyMemRef:$input,
+    DenseI64ArrayAttr:$shape,
+    StrAttr:$tag
+  );
+  let results = (outs);
+  let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }];
+  let hasVerifier = 1;
+}
+
 def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> {
   let dependentDialects = [
     "::mlir::func::FuncDialect",
@@ -1134,6 +1134,15 @@ LogicalResult WeirdOp::verify() {
   return success();
 }
 
+LogicalResult LogBufferOp::verify() {
+  const MemRefType input_type = getInput().getType();
+  if (input_type.getRank() != getShape().size()) {
+    return emitOpError(
+        "Shape must have the same length as the rank of the input");
+  }
+  return success();
+}
+
 void PackSubelementsOp::build(OpBuilder &builder, OperationState &state,
                               const VectorType output_type,
                               const ArrayRef<Value> padded_sources,
@@ -1339,10 +1339,11 @@ class OpsTest(PallasBaseTest):
       "plgpu.TritonCompilerParams unavailable on Windows",
   )
   def test_debug_print(self):
+    if jtu.test_device_matches(["tpu"]):
+      self.skipTest("Test for TPU is covered in tpu_pallas_test.py")
+
     if config.use_shardy_partitioner.value:
       self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
-    if jtu.test_device_matches(["tpu"]):
-      self.skipTest("Not supported on TPU")
 
     # TODO: this test flakes on gpu
     if jtu.test_device_matches(["gpu"]):
@@ -1369,7 +1370,7 @@ class OpsTest(PallasBaseTest):
   )
   def test_debug_print_with_values(self):
     if jtu.test_device_matches(["tpu"]):
-      self.skipTest("Not supported on TPU")
+      self.skipTest("Test for TPU is covered in tpu_pallas_test.py")
 
     # TODO: this test flakes on gpu
     if jtu.test_device_matches(["gpu"]):
@@ -2114,6 +2114,51 @@ class PallasCallPrintTest(PallasBaseTest):
       jax.block_until_ready(compiled_kernel(x))
     self.assertIn('x[0] == 42', get_output())
 
+  @parameterized.named_parameters(
+      (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype)
+      for shape in (
+          (2, 8, 128),
+          # test unaligned shapes
+          (3,),
+          (3, 4),
+          (2, 3, 4),
+          (2, 9, 129),
+      )
+      for dtype in (jnp.int32, jnp.uint32, jnp.float32)
+  )
+  def test_debug_print_vector(self, shape, dtype):
+    # TODO(ayx): Remove after this date.
+    if not jtu.if_cloud_tpu_at_least(2025, 1, 16):
+      self.skipTest("Requires libtpu built after 2025-01-16")
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct(shape, dtype),
+    )
+    def kernel(x_ref, o_ref):
+      pl.debug_print("{}", x_ref[...])
+      o_ref[...] = x_ref[...]
+
+    n = np.prod(shape)
+    x = jnp.arange(n, dtype=dtype).reshape(shape)
+    compiled_kernel = (
+        jax.jit(kernel)
+        .lower(x)
+        .compile({"xla_tpu_enable_log_recorder": "true"})
+    )
+    with jtu.capture_stderr() as get_output:
+      jax.block_until_ready(compiled_kernel(x))
+    output = get_output()
+    numbers = [
+        int(num)
+        for line in output.splitlines()
+        if (match := re.search(r"\{(.*)", line))  # extract contents after `{`
+        for num in re.findall(r"\d+", match.group(1))
+    ]
+    # Check if the numbers in the output match the values generated by `arange`.
+    self.assertLen(numbers, n)
+    self.assertTrue(all(num == i for i, num in enumerate(numbers)))
+
+
 class PallasCallTraceTest(PallasBaseTest):