[Mosaic TPU] Allow overriding memory space assignment of kernel outputs

PiperOrigin-RevId: 666400770
This commit is contained in:
Adam Paszke 2024-08-22 10:22:59 -07:00 committed by jax authors
parent a247058327
commit 498ddd50ef
2 changed files with 32 additions and 0 deletions

View File

@ -208,6 +208,7 @@ def pallas_call_tpu_lowering_rule(
device_type=mosaic_params.get("device_type"),
internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"),
collective_id=mosaic_params.get("collective_id", None),
output_memory_spaces=None, # TODO(apaszke,sharadmv): Implement this.
)
_maybe_cast_to_bool = lambda x, aval: x.astype(
jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x

View File

@ -21,6 +21,7 @@ import base64
import collections.abc
from collections.abc import Callable, Sequence
import dataclasses
import enum
import functools
import io
import os
@ -67,6 +68,20 @@ tpu_custom_call_p.def_impl(
tpu_custom_call_p.multiple_results = True
class MemorySpace(enum.Enum):
HBM = enum.auto()
VMEM = enum.auto()
@property
def color(self) -> int:
if self == MemorySpace.HBM:
return 0
elif self == MemorySpace.VMEM:
return 1
else:
raise ValueError("invalid memory space: " + str(self))
@dataclasses.dataclass(frozen=True)
class CostEstimate:
flops: int
@ -95,6 +110,7 @@ class CustomCallBackendConfig:
allow_input_fusion: list[bool] | None
serialization_format: int | None
internal_scratch_in_bytes: int | None
output_memory_spaces: tuple[MemorySpace, ...] | None
# We omit the body while printing, because primitive params get embedded
# in HLO metadata, and the body blows up its size.
@ -137,6 +153,13 @@ class CustomCallBackendConfig:
if self.internal_scratch_in_bytes is not None:
config.write(b', "internal_scratch_in_bytes": ')
config.write(str(self.internal_scratch_in_bytes).encode("ascii"))
if self.output_memory_spaces is not None:
config.write(b', "output_memory_colors": [')
for i, memory_space in enumerate(self.output_memory_spaces):
if i:
config.write(b",")
config.write(str(memory_space.color).encode("ascii"))
config.write(b"]")
config.write(b"}") # End of custom_call_config.
if self.device_type is not None:
config.write(b', "device_type": ')
@ -420,6 +443,7 @@ def _lower_to_custom_call_config(
internal_scratch_in_bytes: int | None,
collective_id: int | None,
serialization_format: int | None,
output_memory_spaces: tuple[MemorySpace, ...] | None = None,
) -> CustomCallBackendConfig:
lowered_module_asm, (
has_communication,
@ -445,6 +469,7 @@ def _lower_to_custom_call_config(
has_communication=has_communication,
needs_hlo_passes=needs_hlo_passes,
needs_layout_passes=needs_layout_passes,
output_memory_spaces=output_memory_spaces,
)
@ -463,6 +488,7 @@ def _lowered_to_custom_call_config(
needs_hlo_passes: bool,
needs_layout_passes: bool,
device_type: str | None,
output_memory_spaces: tuple[MemorySpace, ...] | None = None,
):
if has_custom_barrier:
if collective_id is None:
@ -492,6 +518,7 @@ def _lowered_to_custom_call_config(
allow_input_fusion,
serialization_format,
internal_scratch_in_bytes,
output_memory_spaces,
)
return config
@ -511,6 +538,7 @@ def lower_module_to_custom_call(
internal_scratch_in_bytes: int | None,
collective_id: int | None,
serialization_format: int | None,
output_memory_spaces: tuple[MemorySpace, ...] | None,
device_type: str | None,
) -> Sequence[ir.Value]:
config = _lower_to_custom_call_config(
@ -524,6 +552,7 @@ def lower_module_to_custom_call(
collective_id=collective_id,
device_type=device_type,
serialization_format=serialization_format,
output_memory_spaces=output_memory_spaces,
)
return _tpu_custom_call_lowering(
ctx,
@ -550,6 +579,7 @@ def as_tpu_kernel(
internal_scratch_in_bytes: int | None = None,
collective_id: int | None = None,
serialization_format: int | None = 1,
output_memory_spaces: tuple[MemorySpace, ...] | None = None,
) -> Callable[..., Any]:
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
config = _lower_to_custom_call_config(
@ -563,6 +593,7 @@ def as_tpu_kernel(
internal_scratch_in_bytes=internal_scratch_in_bytes,
collective_id=collective_id,
serialization_format=serialization_format,
output_memory_spaces=output_memory_spaces,
)
return _as_jax_callable(
config,