mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Set TPU CustomCall device type based on the core_type attribute
This CL deprecates the device_type parameter of `tpu_custom_call.as_tpu_kernel()` in favour of the `tpu.core_type` annotation. The latter is more fine-grained: it is applied to individual `func.FuncOp`s instead of the entire module, and it supports `tc`, `sc_scalar_subcore` and `sc_vector_subcore`. The `device_type` of the TPU CustomCall HLO is set to `sparsecore` if an `sc_scalar_subcore` or `sc_vector_subcore` annotation is provided. Otherwise, `device_type` is not set and the CustomCall targets TC. PiperOrigin-RevId: 692212644
This commit is contained in:
parent
bd7c301968
commit
f462d7e586
@ -453,6 +453,44 @@ def _lower_mosaic_module_to_asm(
|
||||
)
|
||||
|
||||
|
||||
def _get_device_type(module: ir.Module) -> str | None:
  """Determines the device type based on the `tpu.core_type` annotations.

  Walks the operations of `module` in pre-order and classifies every
  `func.func` op that carries a `tpu.core_type` attribute. `func.func` ops
  without that attribute do not influence the result.

  Args:
    module: The MLIR module holding the Mosaic kernel.

  Returns:
    "sparsecore" if at least one function is annotated with the
    `sc_scalar_subcore` or `sc_vector_subcore` core type, None otherwise.

  Raises:
    ValueError: If an annotated function has a core type other than `tc`,
      `sc_scalar_subcore` or `sc_vector_subcore`, or if the module contains
      both TensorCore and SparseCore functions.
  """
  # Textual forms of the recognized attributes, built once per call instead
  # of once per visited function.
  sparsecore_core_types = frozenset({
      "#tpu.core_type<sc_scalar_subcore>",
      "#tpu.core_type<sc_vector_subcore>",
  })
  tensorcore_core_type = "#tpu.core_type<tc>"

  sparsecore_func_found = False
  tensorcore_func_found = False

  def assign_device_type_based_on_core_type(op: ir.Operation) -> ir.WalkResult:
    nonlocal sparsecore_func_found
    nonlocal tensorcore_func_found
    # Anything that is not an annotated function is irrelevant here; keep
    # walking into it so that nested functions are still visited.
    if op.name != "func.func" or "tpu.core_type" not in op.attributes:
      return ir.WalkResult.ADVANCE

    core_type = op.attributes["tpu.core_type"]
    core_type_str = str(core_type)
    if core_type_str in sparsecore_core_types:
      sparsecore_func_found = True
    elif core_type_str == tensorcore_core_type:
      tensorcore_func_found = True
    else:
      raise ValueError(f"Unknown core type: {core_type}")

    # Once both kinds have been seen the outcome is decided (an error is
    # raised below), so stop the walk regardless of which kind came first.
    if sparsecore_func_found and tensorcore_func_found:
      return ir.WalkResult.INTERRUPT
    # The annotated function has been classified; do not descend into it.
    return ir.WalkResult.SKIP

  module.operation.walk(
      assign_device_type_based_on_core_type, walk_order=ir.WalkOrder.PRE_ORDER
  )
  if tensorcore_func_found and sparsecore_func_found:
    raise ValueError(
        "A single Mosaic kernel cannot contain both "
        "TensorCore and SparseCore functions."
    )
  if sparsecore_func_found:
    return "sparsecore"
  return None
|
||||
|
||||
|
||||
def _lower_to_custom_call_config(
|
||||
module: ir.Module,
|
||||
*,
|
||||
@ -592,7 +630,6 @@ def as_tpu_kernel(
|
||||
*,
|
||||
cost_estimate: CostEstimate | None = None,
|
||||
backend: str | xla_client.Client = "tpu",
|
||||
device_type: str | None = None,
|
||||
kernel_name: str | None = None,
|
||||
vmem_limit_bytes: int | None = None,
|
||||
flags: dict[str, bool | int | float] | None = None,
|
||||
@ -604,6 +641,7 @@ def as_tpu_kernel(
|
||||
output_memory_spaces: tuple[MemorySpace | None, ...] | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
|
||||
device_type = _get_device_type(module)
|
||||
config = _lower_to_custom_call_config(
|
||||
module,
|
||||
backend=backend,
|
||||
|
Loading…
x
Reference in New Issue
Block a user