# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX bindings for Mosaic."""
# mypy: ignore-errors

from __future__ import annotations

import base64
import collections.abc
from collections.abc import Callable, Sequence
import dataclasses
import functools
import io
import os
import time
from typing import Any

import jax
from jax import core
from jax._src import config
from jax._src import sharding_impls
from jax._src.interpreters import mlir
from jax._src.lib import tpu
from jax._src.lib import xla_client
from jax.interpreters import xla
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
from jaxlib.mlir.passmanager import PassManager

try:
  from absl import flags
  FLAGS = flags.FLAGS
except ImportError:
  FLAGS = {}

_MOSAIC_USE_PYTHON_PIPELINE = config.bool_state(
    name="mosaic_use_python_pipeline",
    default=False,
    help=(
        "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel"
        " is called (for Pallas, this happens at JAX lowering time), instead of"
        " later within XLA."
    ),
)

_MOSAIC_ALLOW_HLO = config.bool_state(
    name="jax_mosaic_allow_hlo",
    default=False,
    help="Allow hlo dialects in Mosaic",
)

tpu_custom_call_p = core.Primitive("tpu_custom_call")
tpu_custom_call_p.def_impl(
    functools.partial(xla.apply_primitive, tpu_custom_call_p))
tpu_custom_call_p.multiple_results = True

@dataclasses.dataclass(frozen=True)
class CostEstimate:
  flops: int
  transcendentals: int
  bytes_accessed: int

  def to_json(self) -> bytes:
    return (
        f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},'
        f' "bytes_accessed": {self.bytes_accessed}}}'
    ).encode('ascii')
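
# Illustrative example (values are made up): CostEstimate(flops=4,
# transcendentals=0, bytes_accessed=1024).to_json() returns
# b'{"flops": 4, "transcendentals": 0, "bytes_accessed": 1024}'.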

@dataclasses.dataclass(frozen=True)
class CustomCallBackendConfig:
  """Represents an unserialized backend config for custom calls."""
  lowered_module_asm: bytes
  has_communication: bool
  collective_id: int | None
  device_type: str | None
  cost_estimate: CostEstimate | None
  needs_hlo_passes: bool
  needs_layout_passes: bool
  vmem_limit_bytes: int | None
  flags: dict[str, bool | int | float] | None
  allow_input_fusion: list[bool] | None
  serialization_format: int | None
  internal_scratch_in_bytes: int | None

  # We omit the body while printing, because primitive params get embedded
  # in HLO metadata, and the body blows up its size.
  def __repr__(self):
    return "CustomCallBackendConfig(<omitted>)"

  def to_json(self) -> bytes:
    """Serializes the backend config into JSON."""
    # We format the JSON ourselves, because json.dumps seems to be overly slow.
    config = io.BytesIO()
    config.write(b'{"custom_call_config": {"body": "')
    config.write(base64.b64encode(self.lowered_module_asm))
    config.write(b'"')
    if self.has_communication:
      config.write(b', "has_communication": ')
      config.write(str(self.has_communication).lower().encode("ascii"))
    if self.collective_id is not None:
      config.write(b', "collective_id": ')
      config.write(str(self.collective_id).encode("ascii"))
    if self.cost_estimate is not None:
      config.write(b', "cost_estimate": ')
      config.write(self.cost_estimate.to_json())
    if self.needs_hlo_passes:
      config.write(b', "needs_hlo_passes": ')
      config.write(str(self.needs_hlo_passes).lower().encode("ascii"))
    if self.serialization_format is not None:
      config.write(b', "serialization_format": ')
      config.write(str(self.serialization_format).lower().encode("ascii"))
    if self.needs_layout_passes:
      config.write(b', "needs_layout_passes": ')
      config.write(str(self.needs_layout_passes).lower().encode("ascii"))
    if self.allow_input_fusion is not None:
      config.write(b', "allow_input_fusion": [')
      for i, value in enumerate(self.allow_input_fusion):
        config.write(b"true" if value else b"false")
        # config.write(str(value).lower().encode("ascii"))
        if i + 1 != len(self.allow_input_fusion):
          config.write(b",")
      config.write(b"]")
    if self.internal_scratch_in_bytes is not None:
      config.write(b', "internal_scratch_in_bytes": ')
      config.write(str(self.internal_scratch_in_bytes).encode("ascii"))
    config.write(b"}")  # End of custom_call_config.
    if self.device_type is not None:
      config.write(b', "device_type": ')
      config.write(
          ('"DEVICE_TYPE_' + self.device_type.upper() + '"').encode("ascii")
      )
    if self.vmem_limit_bytes is not None:
      config.write(
          b', "scoped_memory_configs": [{"memory_space":1, "offset": 0,'
          b' "size": '
      )
      config.write(str(self.vmem_limit_bytes).encode("ascii"))
      config.write(b'}]')
    if self.flags is not None:
      config.write(b', "flag_configs": [')
      for i, (flag, value) in enumerate(self.flags.items()):
        config.write(b'{"flag_type": "')
        config.write(flag.encode("ascii"))
        config.write(b'", value: {')
        if isinstance(value, bool):
          config.write(b'"boolean_value": ')
          config.write(b"true" if value else b"false")
        elif isinstance(value, int):
          config.write(b'"integer_value": ')
          config.write(str(value).encode("ascii"))
        elif isinstance(value, float):
          config.write(b'"double_value": ')
          config.write(str(value).encode("ascii"))
        else:
          raise ValueError("invalid flag value: " + str(value))
        config.write(b"}}")
        if i + 1 != len(self.flags):
          config.write(b",")
      config.write(b"]")
    # Prevent the compiler from sharding the custom call beyond what Mosaic does
    # based on user annotations
    config.write(b', "implicit_sharding": {"type": "MANUAL"}')
    config.write(b"}")
    return config.getvalue()
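
# For reference, CustomCallBackendConfig.to_json() with every optional field
# left as None/False serializes to
#   {"custom_call_config": {"body": "<base64-encoded module>"},
#    "implicit_sharding": {"type": "MANUAL"}}
# and each populated field appends its corresponding key.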

@tpu_custom_call_p.def_abstract_eval
def _tpu_custom_call_abstract_eval(*_, out_avals, **__):
  return out_avals


def _avals_to_layouts(avals) -> Sequence[Sequence[int]]:
  return [tuple(range(a.ndim - 1, -1, -1)) for a in avals]
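
# Note: _avals_to_layouts produces XLA's default minor-to-major layout for each
# aval, e.g. a rank-3 aval maps to (2, 1, 0).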

def _tpu_custom_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,  # pylint: disable=missing-function-docstring
    config: CustomCallBackendConfig,
    kernel_name: str | None,
    out_avals: Any,
    input_output_aliases: tuple[tuple[int, int], ...],
) -> ...:
  i32_type = ir.IntegerType.get_signless(32)
  result_types = [mlir.aval_to_ir_type(aval) for aval in out_avals]
  axis_context = ctx.module_context.axis_context
  if isinstance(axis_context, sharding_impls.SPMDAxisContext):
    if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
      raise NotImplementedError(
          "Mosaic kernels cannot be automatically partitioned. Please wrap the"
          " call in a shard_map or xmap."
      )
  elif isinstance(axis_context, sharding_impls.ShardingContext):
    if axis_context.num_devices != 1:
      raise NotImplementedError(
          "Mosaic kernels cannot be automatically partitioned. Please wrap the"
          " call in a shard_map or xmap."
      )
  elif config.has_communication:
    raise NotImplementedError(
        "Replica lowering for Mosaic kernels not implemented."
    )

  extra_attributes = {}
  # Add kernel_name and kernel_metadata as attributes to the custom call op.
  # This is because we do not want to pollute the backend_config with this
  # information.
  if kernel_name is not None:
    extra_attributes = dict(kernel_name=ir.StringAttr.get(kernel_name))

  call = mlir.custom_call(
      "tpu_custom_call",
      result_types=result_types,
      operands=in_nodes,
      backend_config=config.to_json(),
      api_version=1,
      operand_output_aliases=dict(input_output_aliases),
      operand_layouts=_avals_to_layouts(ctx.avals_in),
      result_layouts=_avals_to_layouts(ctx.avals_out),
      extra_attributes=extra_attributes)
  return call.results


mlir.register_lowering(tpu_custom_call_p, _tpu_custom_call_lowering,
                       platform="tpu")
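
# The lowering above emits a "tpu_custom_call" custom call whose backend_config
# is the JSON produced by CustomCallBackendConfig.to_json(); kernel_name, when
# provided, travels as an op attribute rather than being embedded in the config.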

def _lower_tpu_kernel(
    module: ir.Module,
    hardware_generation: int,
) -> ir.Module:
  """Runs MLIR passes lowering the given module to an MLIR module.

  Uses Python versions of infer-memref-layout and apply-vector-layout.

  Args:
    module: The MLIR module to lower.
    hardware_generation: The TPU hardware generation to target.

  Returns:
    An MLIR module implementing the kernel.
  """
  try:
    module.operation.verify()
  except ir.MLIRError as e:
    raise ValueError("The compiled module fails MLIR verification") from e

  with module.context as ctx, module.operation.location as _:
    ctx.append_dialect_registry(mlir.upstream_dialects)
    ctx.load_all_available_dialects()
    tpu.register_dialect(ctx)
    mhlo.register_mhlo_dialect(ctx)
    mhlo.register_mhlo_passes()

    dump_mlir(module, "original")

    if _MOSAIC_ALLOW_HLO.value:
      # Run hlo dialect conversion: hlo -> linalg -> vector.
      pipeline = [
          "hlo-legalize-to-arithmetic",
          "func.func(hlo-legalize-to-linalg)",
          "func.func(linalg-vectorization)",
      ]
      pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
      pipeline.run(module.operation)
      dump_mlir(module, "post-hlo-conversion")

    pipeline = [
        f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
    ]
    pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
    pipeline.run(module.operation)
    dump_mlir(module, "post-infer-memref-layout")

    pipeline = [
        "canonicalize",
        "cse",
    ]
    pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
    pipeline.run(module.operation)
    dump_mlir(module, "post-simplify")

    try:
      on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value
    except KeyError:
      on_device_checks = False

    if checks := on_device_checks:
      checks = set(checks.split(","))
      if checks == {"bounds"}:  # We only support one kind of checks now.
        pipeline = PassManager.parse(
            "builtin.module(func.func(debug-assert-insertion))"
        )
        pipeline.run(module.operation)
        dump_mlir(module, "post-assert-insertion")
      elif checks:
        checks.discard("bounds")
        raise ValueError(
            f"Unrecognized on-device check categories: {', '.join(checks)}"
        )

    pipeline = [
        "func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})",
    ]
    pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
    pipeline.run(module.operation)
    dump_mlir(module, "post-infer-vector-layout")

    mxu_size = 128 if hardware_generation < 6 else 256
    pipeline = [
        "func.func(tpu-apply-vector-layout{sublane-count=8 lane-count=128"
        f" hardware-generation={hardware_generation}"
        f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}"
        "})"
    ]
    pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
    pipeline.run(module.operation)
    dump_mlir(module, "post-apply-vector-layout")

    pipeline = PassManager.parse("builtin.module(canonicalize)")
    pipeline.run(module.operation)
    dump_mlir(module, "pre-lower-to-llo")

    return module
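
# Note: the Python pipeline above mirrors the initial Mosaic passes that would
# otherwise run later within XLA; it is only invoked from as_tpu_kernel when
# mosaic_use_python_pipeline is enabled.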

def as_tpu_kernel(
    module: ir.Module,
    out_type: Any,
    *,
    cost_estimate: CostEstimate | None = None,
    backend: str | xla_client.Client = "tpu",
    device_type: str | None = None,
    kernel_name: str | None = None,
    vmem_limit_bytes: int | None = None,
    flags: dict[str, bool | int | float] | None = None,
    allow_input_fusion: list[bool] | None = None,
    input_output_aliases: tuple[tuple[int, int], ...] = (),
    internal_scratch_in_bytes: int | None = None,
) -> Callable[..., Any]:
  """Turns an MLIR Mosaic kernel into a JAX-compatible function."""
  # We use jax.jit to make sure we hit the fast compilation cache.
  if vmem_limit_bytes is not None and not isinstance(vmem_limit_bytes, int):
    raise ValueError(
        "vmem_limit_bytes must be an int: provided with a"
        f" {type(vmem_limit_bytes)}."
    )

  has_communication, has_custom_barrier = tpu.private_has_communication(
      module.operation
  )
  needs_hlo_passes = _MOSAIC_ALLOW_HLO.value
  needs_layout_passes = not device_type

  # We'll mutate the module, so clone it
  with module.context as ctx, module.operation.location as _:
    module = ir.Module.parse(
        module.operation.get_asm(binary=True, enable_debug_info=True)
    )

    if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value:
      some_tpu = jax.devices(backend)[0]
      device_kind = some_tpu.device_kind
      if not device_kind.startswith("TPU v"):
        raise ValueError(
            f"Unrecognized TPU device kind: {device_kind}. "
            "tpu_custom_call cannot be lowered on a machine without TPUs "
            "when mosaic_use_python_pipeline=True.")
      hardware_generation = int(device_kind[len("TPU v")])
      module = _lower_tpu_kernel(module, hardware_generation)
      needs_hlo_passes = False
      needs_layout_passes = False

    prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects
    ctx.allow_unregistered_dialects = True
    try:
      pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})")
      pipeline.run(module.operation)
    finally:
      ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects

    bytecode_buffer = io.BytesIO()
    module.operation.write_bytecode(bytecode_buffer, desired_version=0)
    asm = bytecode_buffer.getvalue()

  # TODO(amagni): Kernel name and regeneration metadata could alternatively be
  # added as a custom attribute to the MLIR call op rather than including them
  # in the backend_config.
  return _lowered_as_tpu_kernel(
      asm,
      out_type,
      needs_hlo_passes=needs_hlo_passes,
      needs_layout_passes=needs_layout_passes,
      device_type=device_type,
      has_communication=has_communication,
      has_custom_barrier=has_custom_barrier,
      kernel_name=kernel_name,
      cost_estimate=cost_estimate,
      vmem_limit_bytes=vmem_limit_bytes,
      flags=flags,
      allow_input_fusion=allow_input_fusion,
      input_output_aliases=input_output_aliases,
      internal_scratch_in_bytes=internal_scratch_in_bytes,
  )
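
# Usage sketch for as_tpu_kernel (illustrative only; `mosaic_module`, the
# shapes, and the dtype below are hypothetical, not taken from this file):
#   kernel = as_tpu_kernel(
#       mosaic_module, jax.ShapeDtypeStruct((8, 128), jnp.float32))
#   y = kernel(x)
# out_type only needs .shape and .dtype attributes, so jax.ShapeDtypeStruct
# (or a sequence of them for multiple outputs) works.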

def _lowered_as_tpu_kernel(
    lowered_module_asm: bytes,
    out_type: Any,
    *,
    cost_estimate: CostEstimate | None = None,
    needs_hlo_passes: bool = False,
    needs_layout_passes: bool = False,
    device_type: str | None = None,
    has_communication: bool = False,
    has_custom_barrier: bool = False,
    kernel_name: str | None = None,
    vmem_limit_bytes: int | None = None,
    flags: dict[str, bool | int | float] | None = None,
    allow_input_fusion: list[bool] | None = None,
    input_output_aliases: tuple[tuple[int, int], ...] = (),
    serialization_format: int | None = 1,
    internal_scratch_in_bytes: int | None = None,
):
  """Turns a low-level MLIR Mosaic kernel into a JAX-compatible function."""
  unpack = False
  if not isinstance(out_type, collections.abc.Iterable):
    out_type = (out_type,)
    unpack = True
  out_avals = tuple(core.ShapedArray(ty.shape, ty.dtype) for ty in out_type)

  def apply_kernel(*args, collective_id: int | None = None):
    if has_custom_barrier:
      if collective_id is None:
        raise ValueError(
            "collective_id has to be specified when using a custom barrier"
        )
    elif collective_id is not None:
      raise ValueError(
          "collective_id has to be unspecified or None when not using a custom"
          " barrier"
      )
    config = CustomCallBackendConfig(
        lowered_module_asm,
        has_communication,
        collective_id,
        device_type,
        cost_estimate,
        needs_hlo_passes,
        needs_layout_passes,
        vmem_limit_bytes,
        flags,
        allow_input_fusion,
        serialization_format,
        internal_scratch_in_bytes,
    )
    result = tpu_custom_call_p.bind(
        *args,
        config=config,
        kernel_name=kernel_name,
        out_avals=out_avals,
        input_output_aliases=input_output_aliases,
    )
    return result[0] if unpack else result

  return jax.jit(apply_kernel, static_argnames=["collective_id"])
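
# Note: the function returned by _lowered_as_tpu_kernel is wrapped in jax.jit
# with collective_id as a static argument, so each distinct collective_id value
# compiles separately.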

def dump_mlir(module: ir.Module, name: str):
  """A helper function to dump the Mosaic MLIR module for debugging."""
  try:
    should_dump = FLAGS["xla_mosaic_dump_to"].value
  except KeyError:
    return
  if should_dump == "sponge":
    outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None)
    if outdir:
      path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}.txt")
      with open(path, "w") as f:
        f.write(str(module))