[Mosaic] Add support for specifying estimated costs for Mosaic kernels

PiperOrigin-RevId: 565310871
This commit is contained in:
Adam Paszke 2023-09-14 02:50:08 -07:00 committed by jax authors
parent cab68db752
commit 8acf597eba

View File

@ -69,6 +69,19 @@ tpu_custom_call_p.def_impl(
tpu_custom_call_p.multiple_results = True
@dataclasses.dataclass(frozen=True)
class CostEstimate:
  """Estimated cost figures that can be specified for a Mosaic kernel.

  Instances are immutable (frozen dataclass). The three counters are emitted
  verbatim by `to_json`, in declaration order.
  """
  flops: int
  transcendentals: int
  bytes_accessed: int

  def to_json(self) -> bytes:
    """Returns the estimate as an ASCII-encoded JSON object."""
    # The JSON is assembled by hand; keys appear in field declaration order.
    entries = (
        ("flops", self.flops),
        ("transcendentals", self.transcendentals),
        ("bytes_accessed", self.bytes_accessed),
    )
    body = ", ".join(f'"{key}": {value}' for key, value in entries)
    return ("{" + body + "}").encode("ascii")
@dataclasses.dataclass(frozen=True)
class CustomCallBackendConfig:
"""Represents an unserialized backend config for custom calls."""
@ -76,13 +89,14 @@ class CustomCallBackendConfig:
has_communication: bool
collective_id: int | None
device_type: str | None
cost_estimate: CostEstimate | None
# We omit the body while printing, because primitive params get embedded
# in HLO metadata, and the body blows up its size.
def __repr__(self):
return "CustomCallBackendConfig(<omitted>)"
def to_json(self):
def to_json(self) -> bytes:
"""Serializes the backend config into JSON."""
# We format the JSON ourselves, because json.dumps seems to be overly slow.
config = io.BytesIO()
@ -95,6 +109,9 @@ class CustomCallBackendConfig:
if self.collective_id is not None:
config.write(b', "collective_id": ')
config.write(str(self.collective_id).encode("ascii"))
if self.cost_estimate is not None:
config.write(b', "cost_estimate": ')
config.write(self.cost_estimate.to_json())
config.write(b"}")
if self.device_type is not None:
config.write(b', "device_type": ')
@ -303,6 +320,7 @@ def as_tpu_kernel(
module: ir.Module,
out_type: Any,
*,
cost_estimate: CostEstimate | None = None,
backend: str | xla_client.Client = "tpu",
device_type: str | None = None,
kernel_name: str | None = None,
@ -333,6 +351,7 @@ def as_tpu_kernel(
has_custom_barrier=has_custom_barrier,
kernel_name=kernel_name,
kernel_regeneration_metadata=kernel_regeneration_metadata,
cost_estimate=cost_estimate,
)
@ -341,6 +360,7 @@ def _lowered_as_tpu_kernel(
out_type: Any,
constants: Sequence[Any] = (),
*,
cost_estimate: CostEstimate | None = None,
device_type: str | None = None,
has_communication: bool = False,
has_custom_barrier: bool = False,
@ -369,6 +389,7 @@ def _lowered_as_tpu_kernel(
has_communication,
collective_id,
device_type,
cost_estimate,
)
result = tpu_custom_call_p.bind(
*args,