From 8acf597ebac658fdfe12580ecedfcc37f5e3cbe1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 14 Sep 2023 02:50:08 -0700 Subject: [PATCH] [Mosaic] Add support for specifying estimated costs for Mosaic kernels PiperOrigin-RevId: 565310871 --- jax/_src/tpu_custom_call.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index d30a6b991..95786539e 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -69,6 +69,19 @@ tpu_custom_call_p.def_impl( tpu_custom_call_p.multiple_results = True +@dataclasses.dataclass(frozen=True) +class CostEstimate: + flops: int + transcendentals: int + bytes_accessed: int + + def to_json(self) -> bytes: + return ( + f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},' + f' "bytes_accessed": {self.bytes_accessed}}}' + ).encode('ascii') + + @dataclasses.dataclass(frozen=True) class CustomCallBackendConfig: """Represents an unserialized backend config for custom calls.""" @@ -76,13 +89,14 @@ class CustomCallBackendConfig: has_communication: bool collective_id: int | None device_type: str | None + cost_estimate: CostEstimate | None # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. def __repr__(self): return "CustomCallBackendConfig()" - def to_json(self): + def to_json(self) -> bytes: """Serializes the backend config into JSON.""" # We format the JSON ourselves, because json.dumps seems to be overly slow. config = io.BytesIO() @@ -95,6 +109,9 @@ class CustomCallBackendConfig: if self.collective_id is not None: config.write(b', "collective_id": ') config.write(str(self.collective_id).encode("ascii")) + if self.cost_estimate is not None: + config.write(b', "cost_estimate": ') + config.write(self.cost_estimate.to_json()) config.write(b"}") if self.device_type is not None: config.write(b', "device_type": ') @@ -303,6 +320,7 @@ def as_tpu_kernel( module: ir.Module, out_type: Any, *, + cost_estimate: CostEstimate | None = None, backend: str | xla_client.Client = "tpu", device_type: str | None = None, kernel_name: str | None = None, @@ -333,6 +351,7 @@ def as_tpu_kernel( has_custom_barrier=has_custom_barrier, kernel_name=kernel_name, kernel_regeneration_metadata=kernel_regeneration_metadata, + cost_estimate=cost_estimate, ) @@ -341,6 +360,7 @@ def _lowered_as_tpu_kernel( out_type: Any, constants: Sequence[Any] = (), *, + cost_estimate: CostEstimate | None = None, device_type: str | None = None, has_communication: bool = False, has_custom_barrier: bool = False, @@ -369,6 +389,7 @@ def _lowered_as_tpu_kernel( has_communication, collective_id, device_type, + cost_estimate, ) result = tpu_custom_call_p.bind( *args,