Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[Mosaic TPU] Allow overriding memory space assignment of kernel outputs
PiperOrigin-RevId: 666400770
This commit is contained in:
parent
a247058327
commit
498ddd50ef
@@ -208,6 +208,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
device_type=mosaic_params.get("device_type"),
|
||||
internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"),
|
||||
collective_id=mosaic_params.get("collective_id", None),
|
||||
output_memory_spaces=None, # TODO(apaszke,sharadmv): Implement this.
|
||||
)
|
||||
_maybe_cast_to_bool = lambda x, aval: x.astype(
|
||||
jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x
|
||||
|
@@ -21,6 +21,7 @@ import base64
|
||||
import collections.abc
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
@@ -67,6 +68,20 @@ tpu_custom_call_p.def_impl(
|
||||
tpu_custom_call_p.multiple_results = True
|
||||
|
||||
|
||||
class MemorySpace(enum.Enum):
  """Memory spaces that can be requested for kernel outputs.

  Each member maps to an integer "color" through the `color` property.
  """

  HBM = enum.auto()
  VMEM = enum.auto()

  @property
  def color(self) -> int:
    """Return the integer color for this memory space.

    Returns:
      0 for `HBM`, 1 for `VMEM`.

    Raises:
      ValueError: if the member has no color assigned (e.g. a member added
        later without extending this mapping).
    """
    # Guard clauses instead of an if/elif/else ladder; enum members are
    # singletons, so an identity check is equivalent to the equality check.
    if self is MemorySpace.HBM:
      return 0
    if self is MemorySpace.VMEM:
      return 1
    raise ValueError("invalid memory space: " + str(self))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CostEstimate:
|
||||
flops: int
|
||||
@@ -95,6 +110,7 @@ class CustomCallBackendConfig:
|
||||
allow_input_fusion: list[bool] | None
|
||||
serialization_format: int | None
|
||||
internal_scratch_in_bytes: int | None
|
||||
output_memory_spaces: tuple[MemorySpace, ...] | None
|
||||
|
||||
# We omit the body while printing, because primitive params get embedded
|
||||
# in HLO metadata, and the body blows up its size.
|
||||
@@ -137,6 +153,13 @@ class CustomCallBackendConfig:
|
||||
if self.internal_scratch_in_bytes is not None:
|
||||
config.write(b', "internal_scratch_in_bytes": ')
|
||||
config.write(str(self.internal_scratch_in_bytes).encode("ascii"))
|
||||
if self.output_memory_spaces is not None:
|
||||
config.write(b', "output_memory_colors": [')
|
||||
for i, memory_space in enumerate(self.output_memory_spaces):
|
||||
if i:
|
||||
config.write(b",")
|
||||
config.write(str(memory_space.color).encode("ascii"))
|
||||
config.write(b"]")
|
||||
config.write(b"}") # End of custom_call_config.
|
||||
if self.device_type is not None:
|
||||
config.write(b', "device_type": ')
|
||||
@@ -420,6 +443,7 @@ def _lower_to_custom_call_config(
|
||||
internal_scratch_in_bytes: int | None,
|
||||
collective_id: int | None,
|
||||
serialization_format: int | None,
|
||||
output_memory_spaces: tuple[MemorySpace, ...] | None = None,
|
||||
) -> CustomCallBackendConfig:
|
||||
lowered_module_asm, (
|
||||
has_communication,
|
||||
@@ -445,6 +469,7 @@ def _lower_to_custom_call_config(
|
||||
has_communication=has_communication,
|
||||
needs_hlo_passes=needs_hlo_passes,
|
||||
needs_layout_passes=needs_layout_passes,
|
||||
output_memory_spaces=output_memory_spaces,
|
||||
)
|
||||
|
||||
|
||||
@@ -463,6 +488,7 @@ def _lowered_to_custom_call_config(
|
||||
needs_hlo_passes: bool,
|
||||
needs_layout_passes: bool,
|
||||
device_type: str | None,
|
||||
output_memory_spaces: tuple[MemorySpace, ...] | None = None,
|
||||
):
|
||||
if has_custom_barrier:
|
||||
if collective_id is None:
|
||||
@@ -492,6 +518,7 @@ def _lowered_to_custom_call_config(
|
||||
allow_input_fusion,
|
||||
serialization_format,
|
||||
internal_scratch_in_bytes,
|
||||
output_memory_spaces,
|
||||
)
|
||||
return config
|
||||
|
||||
@@ -511,6 +538,7 @@ def lower_module_to_custom_call(
|
||||
internal_scratch_in_bytes: int | None,
|
||||
collective_id: int | None,
|
||||
serialization_format: int | None,
|
||||
output_memory_spaces: tuple[MemorySpace, ...] | None,
|
||||
device_type: str | None,
|
||||
) -> Sequence[ir.Value]:
|
||||
config = _lower_to_custom_call_config(
|
||||
@@ -524,6 +552,7 @@ def lower_module_to_custom_call(
|
||||
collective_id=collective_id,
|
||||
device_type=device_type,
|
||||
serialization_format=serialization_format,
|
||||
output_memory_spaces=output_memory_spaces,
|
||||
)
|
||||
return _tpu_custom_call_lowering(
|
||||
ctx,
|
||||
@@ -550,6 +579,7 @@ def as_tpu_kernel(
|
||||
internal_scratch_in_bytes: int | None = None,
|
||||
collective_id: int | None = None,
|
||||
serialization_format: int | None = 1,
|
||||
output_memory_spaces: tuple[MemorySpace, ...] | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
|
||||
config = _lower_to_custom_call_config(
|
||||
@@ -563,6 +593,7 @@ def as_tpu_kernel(
|
||||
internal_scratch_in_bytes=internal_scratch_in_bytes,
|
||||
collective_id=collective_id,
|
||||
serialization_format=serialization_format,
|
||||
output_memory_spaces=output_memory_spaces,
|
||||
)
|
||||
return _as_jax_callable(
|
||||
config,
|
||||
|
Loading…
x
Reference in New Issue
Block a user