mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic] Add support for specifying estimated costs for Mosaic kernels
PiperOrigin-RevId: 565310871
This commit is contained in:
parent
cab68db752
commit
8acf597eba
@ -69,6 +69,19 @@ tpu_custom_call_p.def_impl(
|
||||
tpu_custom_call_p.multiple_results = True
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CostEstimate:
  """Estimated cost figures that can be attached to a Mosaic kernel.

  The three fields are serialized verbatim by `to_json`; no validation or
  unit conversion is performed here.
  """

  flops: int
  transcendentals: int
  bytes_accessed: int

  def to_json(self) -> bytes:
    """Returns the estimate as an ASCII-encoded JSON object.

    Keys are emitted in the fixed order "flops", "transcendentals",
    "bytes_accessed", separated by ", ".
    """
    # The JSON text is assembled by hand rather than through the json module.
    entries = (
        ("flops", self.flops),
        ("transcendentals", self.transcendentals),
        ("bytes_accessed", self.bytes_accessed),
    )
    payload = ", ".join(f'"{key}": {value}' for key, value in entries)
    return ("{" + payload + "}").encode('ascii')
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CustomCallBackendConfig:
|
||||
"""Represents an unserialized backend config for custom calls."""
|
||||
@ -76,13 +89,14 @@ class CustomCallBackendConfig:
|
||||
has_communication: bool
|
||||
collective_id: int | None
|
||||
device_type: str | None
|
||||
cost_estimate: CostEstimate | None
|
||||
|
||||
# We omit the body while printing, because primitive params get embedded
|
||||
# in HLO metadata, and the body blows up its size.
|
||||
def __repr__(self):
|
||||
return "CustomCallBackendConfig(<omitted>)"
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> bytes:
|
||||
"""Serializes the backend config into JSON."""
|
||||
# We format the JSON ourselves, because json.dumps seems to be overly slow.
|
||||
config = io.BytesIO()
|
||||
@ -95,6 +109,9 @@ class CustomCallBackendConfig:
|
||||
if self.collective_id is not None:
|
||||
config.write(b', "collective_id": ')
|
||||
config.write(str(self.collective_id).encode("ascii"))
|
||||
if self.cost_estimate is not None:
|
||||
config.write(b', "cost_estimate": ')
|
||||
config.write(self.cost_estimate.to_json())
|
||||
config.write(b"}")
|
||||
if self.device_type is not None:
|
||||
config.write(b', "device_type": ')
|
||||
@ -303,6 +320,7 @@ def as_tpu_kernel(
|
||||
module: ir.Module,
|
||||
out_type: Any,
|
||||
*,
|
||||
cost_estimate: CostEstimate | None = None,
|
||||
backend: str | xla_client.Client = "tpu",
|
||||
device_type: str | None = None,
|
||||
kernel_name: str | None = None,
|
||||
@ -333,6 +351,7 @@ def as_tpu_kernel(
|
||||
has_custom_barrier=has_custom_barrier,
|
||||
kernel_name=kernel_name,
|
||||
kernel_regeneration_metadata=kernel_regeneration_metadata,
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
|
||||
|
||||
@ -341,6 +360,7 @@ def _lowered_as_tpu_kernel(
|
||||
out_type: Any,
|
||||
constants: Sequence[Any] = (),
|
||||
*,
|
||||
cost_estimate: CostEstimate | None = None,
|
||||
device_type: str | None = None,
|
||||
has_communication: bool = False,
|
||||
has_custom_barrier: bool = False,
|
||||
@ -369,6 +389,7 @@ def _lowered_as_tpu_kernel(
|
||||
has_communication,
|
||||
collective_id,
|
||||
device_type,
|
||||
cost_estimate,
|
||||
)
|
||||
result = tpu_custom_call_p.bind(
|
||||
*args,
|
||||
|
Loading…
x
Reference in New Issue
Block a user