[Mosaic] Set TPU CustomCall device type based on the core_type attribute

This CL deprecates the device_type parameter of `tpu_custom_call.as_tpu_kernel()` in favour of the `tpu.core_type` annotation.
The latter is more fine-grained: it is applied on `func.FuncOp` instead of the entire module, and supports `tc`, `sc_scalar_subcore` and `sc_vector_subcore`.

`device_type` of the TPU CustomCall HLO is set to `sparsecore` if `sc_scalar_subcore` or `sc_vector_subcore` annotation is provided. Otherwise, `device_type` is not set and the CustomCall targets TC.

PiperOrigin-RevId: 692212644
This commit is contained in:
Naums Mogers 2024-11-01 10:02:07 -07:00 committed by jax authors
parent bd7c301968
commit f462d7e586

View File

@ -453,6 +453,44 @@ def _lower_mosaic_module_to_asm(
)
def _get_device_type(module: ir.Module) -> str | None:
  """Determines the TPU device type from `tpu.core_type` annotations.

  Walks every `func.func` op in `module` (pre-order) and inspects its
  `tpu.core_type` attribute, when present. The SparseCore subcore
  annotations (`sc_scalar_subcore`, `sc_vector_subcore`) select the
  "sparsecore" device type; the `tc` annotation targets the TensorCore.

  Args:
    module: The Mosaic kernel module to inspect.

  Returns:
    "sparsecore" if any function is annotated with a SparseCore subcore
    type, otherwise None (the CustomCall then targets the TensorCore).

  Raises:
    ValueError: If an unknown `tpu.core_type` value is encountered, or if
      the module mixes TensorCore and SparseCore functions.
  """
  sparsecore_types = frozenset(
      f"#tpu.core_type<{c}>"
      for c in ("sc_scalar_subcore", "sc_vector_subcore")
  )
  sparsecore_func_found = False
  tensorcore_func_found = False

  def assign_device_type_based_on_core_type(op: ir.Operation) -> ir.WalkResult:
    nonlocal sparsecore_func_found
    nonlocal tensorcore_func_found
    if op.name == "func.func" and "tpu.core_type" in op.attributes:
      core_type = str(op.attributes["tpu.core_type"])
      if core_type in sparsecore_types:
        sparsecore_func_found = True
      elif core_type == "#tpu.core_type<tc>":
        tensorcore_func_found = True
      else:
        raise ValueError(f"Unknown core type: {op.attributes['tpu.core_type']}")
      # Stop the walk as soon as the invalid TC+SC mix is detected (in either
      # discovery order); the post-walk check below reports the error.
      if sparsecore_func_found and tensorcore_func_found:
        return ir.WalkResult.INTERRUPT
      # No need to descend into the function body: the annotation lives on
      # the func.func op itself.
      return ir.WalkResult.SKIP
    return ir.WalkResult.ADVANCE

  module.operation.walk(
      assign_device_type_based_on_core_type, walk_order=ir.WalkOrder.PRE_ORDER
  )
  if tensorcore_func_found and sparsecore_func_found:
    raise ValueError(
        "A single Mosaic kernel cannot contain both "
        "TensorCore and SparseCore functions."
    )
  return "sparsecore" if sparsecore_func_found else None
def _lower_to_custom_call_config(
module: ir.Module,
*,
@ -592,7 +630,6 @@ def as_tpu_kernel(
*,
cost_estimate: CostEstimate | None = None,
backend: str | xla_client.Client = "tpu",
device_type: str | None = None,
kernel_name: str | None = None,
vmem_limit_bytes: int | None = None,
flags: dict[str, bool | int | float] | None = None,
@ -604,6 +641,7 @@ def as_tpu_kernel(
output_memory_spaces: tuple[MemorySpace | None, ...] | None = None,
) -> Callable[..., Any]:
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
device_type = _get_device_type(module)
config = _lower_to_custom_call_config(
module,
backend=backend,