mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic TPU] Allow specify priority in enqueueDMA.
For now we only support priority 0 (on-demand thread) and priority 1 (background thread) on local DMA. PiperOrigin-RevId: 743780185
This commit is contained in:
parent
5b3e419515
commit
c1bdd1a234
@ -1073,6 +1073,7 @@ pytype_strict_library(
|
||||
srcs = ["_src/tpu_custom_call.py"],
|
||||
visibility = [":internal"],
|
||||
deps = [
|
||||
":cloud_tpu_init",
|
||||
":config",
|
||||
":core",
|
||||
":jax",
|
||||
|
@ -32,6 +32,7 @@ import jax
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import sharding_impls
|
||||
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import tpu
|
||||
from jax._src.lib import xla_client
|
||||
@ -64,7 +65,14 @@ _MOSAIC_ALLOW_HLO = config.bool_state(
|
||||
|
||||
|
||||
# This tracks the latest Mosaic IR version with a monthly delay.
|
||||
FWD_COMPAT_IR_VERSION = 3
|
||||
FWD_COMPAT_IR_VERSION = 4
|
||||
DEFAULT_IR_VERSION = None
|
||||
# TODO(jevinjiang): Remove this once both jaxlib and libtpu are up to date.
|
||||
if is_cloud_tpu_older_than(2025, 4, 5) or jax.version._version_as_tuple(
|
||||
jax.lib.__version__
|
||||
) < (0, 5, 4):
|
||||
FWD_COMPAT_IR_VERSION = 3
|
||||
DEFAULT_IR_VERSION = 3
|
||||
|
||||
|
||||
tpu_custom_call_p = core.Primitive("tpu_custom_call")
|
||||
@ -671,7 +679,9 @@ def lower_module_to_custom_call(
|
||||
serialization_format=serialization_format,
|
||||
output_memory_spaces=output_memory_spaces,
|
||||
kernel_name=kernel_name,
|
||||
ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None,
|
||||
ir_version=FWD_COMPAT_IR_VERSION
|
||||
if ctx.is_forward_compat()
|
||||
else DEFAULT_IR_VERSION,
|
||||
)
|
||||
return _tpu_custom_call_lowering(
|
||||
ctx,
|
||||
|
@ -752,7 +752,9 @@ def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
|
||||
AnyMemRef:$target,
|
||||
MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore,
|
||||
Optional<I32>:$device_id, // For remote DMAs
|
||||
Optional<I32>:$core_id // For megacore
|
||||
Optional<I32>:$core_id, // For megacore
|
||||
// Smaller number means higher priority. 0 is the highest and the default.
|
||||
DefaultValuedAttr<I32Attr, "0">:$priority
|
||||
);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
@ -955,13 +955,24 @@ LogicalResult EnqueueDMAOp::verify() {
|
||||
"device_id or core_id is specified");
|
||||
}
|
||||
}
|
||||
bool is_remote = getDeviceId() || getCoreId();
|
||||
if (getSourceSemaphore()) {
|
||||
if (!getDeviceId() && !getCoreId()) {
|
||||
if (!is_remote) {
|
||||
return emitOpError(
|
||||
"DMA destination device_id or core_id must be specified when source "
|
||||
"semaphore is specified");
|
||||
}
|
||||
}
|
||||
int priority = getPriority();
|
||||
if (priority < 0 || priority > 1) {
|
||||
return emitOpError(
|
||||
"Not implemented: only support priority 0 or 1, but got ")
|
||||
<< priority;
|
||||
}
|
||||
if (priority != 0 && is_remote) {
|
||||
return emitOpError(
|
||||
"Not implemented: non-zero priority is not supported for remote DMA");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ constexpr StringRef kMangledDialect = "stable_mosaic.";
|
||||
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
|
||||
// When this is bumped, we should file a TODO to update the forward-compatible
|
||||
// version in tpu_custom_call.py in a month!
|
||||
constexpr int kVersion = 3;
|
||||
constexpr int kVersion = 4;
|
||||
|
||||
using SerdeRuleType = jaxlib::mosaic::SerdeRuleType;
|
||||
|
||||
@ -62,6 +62,11 @@ LogicalResult enqueue_dma_upgrade(Operation* op, int version) {
|
||||
<< op->getNumOperands();
|
||||
}
|
||||
}
|
||||
if (version < 4) {
|
||||
op->setAttr("priority",
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(op->getContext(), 32), 0));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -69,6 +74,9 @@ LogicalResult enqueue_dma_downgrade(Operation* op, int version) {
|
||||
if (version < 2) {
|
||||
return op->emitError("Downgrade to version ") << version << " unsupported";
|
||||
}
|
||||
if (version < 4) {
|
||||
op->removeAttr("priority");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user