[Mosaic TPU] Allow specifying a priority in enqueueDMA.

For now, only priority 0 (on-demand thread) and priority 1 (background thread) are supported, and a non-zero priority is only allowed on local DMAs.

PiperOrigin-RevId: 743780185
This commit is contained in:
Jevin Jiang 2025-04-03 19:39:03 -07:00 committed by jax authors
parent 5b3e419515
commit c1bdd1a234
5 changed files with 37 additions and 5 deletions

View File

@ -1073,6 +1073,7 @@ pytype_strict_library(
srcs = ["_src/tpu_custom_call.py"],
visibility = [":internal"],
deps = [
":cloud_tpu_init",
":config",
":core",
":jax",

View File

@ -32,6 +32,7 @@ import jax
from jax._src import config
from jax._src import core
from jax._src import sharding_impls
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
from jax._src.interpreters import mlir
from jax._src.lib import tpu
from jax._src.lib import xla_client
@ -64,7 +65,14 @@ _MOSAIC_ALLOW_HLO = config.bool_state(
# This tracks the latest Mosaic IR version with a monthly delay.
FWD_COMPAT_IR_VERSION = 3
FWD_COMPAT_IR_VERSION = 4
DEFAULT_IR_VERSION = None
# TODO(jevinjiang): Remove this once both jaxlib and libtpu are up to date.
if is_cloud_tpu_older_than(2025, 4, 5) or jax.version._version_as_tuple(
jax.lib.__version__
) < (0, 5, 4):
FWD_COMPAT_IR_VERSION = 3
DEFAULT_IR_VERSION = 3
tpu_custom_call_p = core.Primitive("tpu_custom_call")
@ -671,7 +679,9 @@ def lower_module_to_custom_call(
serialization_format=serialization_format,
output_memory_spaces=output_memory_spaces,
kernel_name=kernel_name,
ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None,
ir_version=FWD_COMPAT_IR_VERSION
if ctx.is_forward_compat()
else DEFAULT_IR_VERSION,
)
return _tpu_custom_call_lowering(
ctx,

View File

@ -752,7 +752,9 @@ def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
AnyMemRef:$target,
MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore,
Optional<I32>:$device_id, // For remote DMAs
Optional<I32>:$core_id // For megacore
Optional<I32>:$core_id, // For megacore
// Smaller number means higher priority. 0 is the highest and the default.
DefaultValuedAttr<I32Attr, "0">:$priority
);
let hasVerifier = 1;
}

View File

@ -955,13 +955,24 @@ LogicalResult EnqueueDMAOp::verify() {
"device_id or core_id is specified");
}
}
bool is_remote = getDeviceId() || getCoreId();
if (getSourceSemaphore()) {
if (!getDeviceId() && !getCoreId()) {
if (!is_remote) {
return emitOpError(
"DMA destination device_id or core_id must be specified when source "
"semaphore is specified");
}
}
int priority = getPriority();
if (priority < 0 || priority > 1) {
return emitOpError(
"Not implemented: only support priority 0 or 1, but got ")
<< priority;
}
if (priority != 0 && is_remote) {
return emitOpError(
"Not implemented: non-zero priority is not supported for remote DMA");
}
return success();
}

View File

@ -40,7 +40,7 @@ constexpr StringRef kMangledDialect = "stable_mosaic.";
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
// When this is bumped, we should file a TODO to update the forward-compatible
// version in tpu_custom_call.py in a month!
constexpr int kVersion = 3;
constexpr int kVersion = 4;
using SerdeRuleType = jaxlib::mosaic::SerdeRuleType;
@ -62,6 +62,11 @@ LogicalResult enqueue_dma_upgrade(Operation* op, int version) {
<< op->getNumOperands();
}
}
if (version < 4) {
op->setAttr("priority",
mlir::IntegerAttr::get(
mlir::IntegerType::get(op->getContext(), 32), 0));
}
return success();
}
@ -69,6 +74,9 @@ LogicalResult enqueue_dma_downgrade(Operation* op, int version) {
if (version < 2) {
return op->emitError("Downgrade to version ") << version << " unsupported";
}
if (version < 4) {
op->removeAttr("priority");
}
return success();
}