rocm_jax/jax/_src/tpu_custom_call.py
2024-04-26 08:36:32 -07:00

515 lines
18 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX bindings for Mosaic."""
# mypy: ignore-errors
from __future__ import annotations

import base64
import collections.abc
import dataclasses
import functools
import io
import os
import re
import time
from typing import Any, Callable

import jax
from jax import core
from jax._src import config
from jax._src import sharding_impls
from jax._src.interpreters import mlir
from jax._src.lib import tpu
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import hlo
from jax.interpreters import xla
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager
import numpy as np
# absl is optional. Without it, FLAGS falls back to an empty dict, so any
# `FLAGS[name]` lookup raises KeyError (readers in this module catch it).
try:
  from absl import flags
except ImportError:
  FLAGS = {}
else:
  FLAGS = flags.FLAGS
# When enabled, as_tpu_kernel runs the initial Mosaic passes itself (see
# _lower_tpu_kernel) instead of leaving them to XLA.
# NOTE(review): unlike jax_mosaic_allow_hlo, this flag name has no "jax_"
# prefix; renaming it would break existing users, so it is left as is.
_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state(
    name="mosaic_use_python_pipeline",
    default=False,
    help=(
        "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel is"
        " called (for Pallas, this happens at JAX lowering time), instead of"
        " later within XLA."
    ),
)

# When enabled, Mosaic kernels may contain hlo-dialect ops, which are
# converted before the TPU-specific passes run.
_MOSAIC_ALLOW_HLO = config.define_bool_state(
    name="jax_mosaic_allow_hlo",
    default=False,
    help="Allow hlo dialects in Mosaic",
)
# Primitive representing one invocation of a Mosaic kernel. It produces a list
# of outputs, and eager binds are dispatched through xla.apply_primitive.
tpu_custom_call_p = core.Primitive("tpu_custom_call")
tpu_custom_call_p.multiple_results = True
tpu_custom_call_p.def_impl(
    functools.partial(xla.apply_primitive, tpu_custom_call_p)
)
@dataclasses.dataclass(frozen=True)
class CostEstimate:
  """Cost hints for a Mosaic kernel, embedded in the custom call's config."""

  flops: int
  transcendentals: int
  bytes_accessed: int

  def to_json(self) -> bytes:
    """Returns the estimate as an ASCII-encoded JSON object."""
    entries = (
        ("flops", self.flops),
        ("transcendentals", self.transcendentals),
        ("bytes_accessed", self.bytes_accessed),
    )
    body = ", ".join(f'"{key}": {val}' for key, val in entries)
    return ("{" + body + "}").encode("ascii")
@dataclasses.dataclass(frozen=True)
class CustomCallBackendConfig:
  """Represents an unserialized backend config for custom calls.

  ``to_json`` renders it into the JSON string used as the ``backend_config``
  of the ``tpu_custom_call`` custom call op.
  """

  # Serialized Mosaic module; base64-encoded into the "body" field.
  lowered_module_asm: bytes
  # Emitted only when truthy.
  has_communication: bool
  # Emitted only when not None.
  collective_id: int | None
  # Emitted as "DEVICE_TYPE_<UPPERCASED VALUE>" when not None.
  device_type: str | None
  # Serialized via CostEstimate.to_json when not None.
  cost_estimate: CostEstimate | None
  # Emitted only when truthy.
  needs_hlo_passes: bool
  # Emitted only when truthy.
  needs_layout_passes: bool
  # Emitted as a scoped memory config (memory_space 1, offset 0) when not None.
  vmem_limit_bytes: int | None
  # Flag name -> value; each value must be a bool, int or float.
  flags: dict[str, bool | int | float] | None
  # One boolean per entry, emitted as a JSON array when not None.
  allow_input_fusion: list[bool] | None
  # Emitted only when not None.
  serialization_format: int | None

  # We omit the body while printing, because primitive params get embedded
  # in HLO metadata, and the body blows up its size.
  def __repr__(self):
    return "CustomCallBackendConfig(<omitted>)"

  @staticmethod
  def _flag_config_json(flag: str, value: bool | int | float) -> bytes:
    """Renders one ``flag_configs`` entry as a JSON object.

    Raises:
      ValueError: if ``value`` is not a bool, int or float.
    """
    # bool must be checked before int, because bool is a subclass of int.
    if isinstance(value, bool):
      typed_value = b'"boolean_value": ' + (b"true" if value else b"false")
    elif isinstance(value, int):
      typed_value = b'"integer_value": ' + str(value).encode("ascii")
    elif isinstance(value, float):
      typed_value = b'"double_value": ' + str(value).encode("ascii")
    else:
      raise ValueError("invalid flag value: " + str(value))
    # The "value" key must be quoted for the document to be valid JSON.
    return (
        b'{"flag_type": "' + flag.encode("ascii") + b'", "value": {'
        + typed_value + b"}}"
    )

  def to_json(self) -> bytes:
    """Serializes the backend config into JSON.

    Optional fields that are None, and boolean fields that are falsy, are
    omitted from the output.

    Returns:
      The ASCII-encoded JSON document.

    Raises:
      ValueError: if a value in ``flags`` is not a bool, int or float.
    """
    # We format the JSON ourselves, because json.dumps seems to be overly slow.
    # (Named `buf` so it does not shadow the module-level `config` import.)
    buf = io.BytesIO()
    buf.write(b'{"custom_call_config": {"body": "')
    buf.write(base64.b64encode(self.lowered_module_asm))
    buf.write(b'"')
    if self.has_communication:
      buf.write(b', "has_communication": true')
    if self.collective_id is not None:
      buf.write(b', "collective_id": ')
      buf.write(str(self.collective_id).encode("ascii"))
    if self.cost_estimate is not None:
      buf.write(b', "cost_estimate": ')
      buf.write(self.cost_estimate.to_json())
    if self.needs_hlo_passes:
      buf.write(b', "needs_hlo_passes": true')
    if self.serialization_format is not None:
      buf.write(b', "serialization_format": ')
      buf.write(str(self.serialization_format).encode("ascii"))
    if self.needs_layout_passes:
      buf.write(b', "needs_layout_passes": true')
    if self.allow_input_fusion is not None:
      buf.write(b', "allow_input_fusion": [')
      buf.write(
          b",".join(
              b"true" if value else b"false"
              for value in self.allow_input_fusion
          )
      )
      buf.write(b"]")
    buf.write(b"}")  # End of custom_call_config.
    if self.device_type is not None:
      buf.write(b', "device_type": "DEVICE_TYPE_')
      buf.write(self.device_type.upper().encode("ascii"))
      buf.write(b'"')
    if self.vmem_limit_bytes is not None:
      buf.write(
          b', "scoped_memory_configs": [{"memory_space":1, "offset": 0,'
          b' "size": '
      )
      buf.write(str(self.vmem_limit_bytes).encode("ascii"))
      buf.write(b"}]")
    if self.flags is not None:
      buf.write(b', "flag_configs": [')
      buf.write(
          b",".join(
              self._flag_config_json(flag, value)
              for flag, value in self.flags.items()
          )
      )
      buf.write(b"]")
    # Prevent the compiler from sharding the custom call beyond what Mosaic
    # does based on user annotations.
    buf.write(b', "implicit_sharding": {"type": "MANUAL"}')
    buf.write(b"}")
    return buf.getvalue()
@tpu_custom_call_p.def_abstract_eval
def _tpu_custom_call_abstract_eval(*args, out_avals, **params):
  """Abstract evaluation: the output avals are carried as a primitive param."""
  del args, params  # The result does not depend on the operands or params.
  return out_avals
def _aval_to_layout(aval):
  """Returns an index-typed layout attribute [ndim-1, ..., 1, 0] for ``aval``.

  NOTE(review): presumably XLA's default (row-major) minor-to-major layout —
  confirm against the custom call's operand/result layout semantics.
  """
  descending_dims = np.arange(aval.ndim - 1, -1, -1, dtype=np.int64)
  return ir.DenseIntElementsAttr.get(descending_dims, type=ir.IndexType.get())
def _avals_to_layouts(avals):
  """Returns an ArrayAttr holding one layout attribute per aval, in order."""
  layouts = list(map(_aval_to_layout, avals))
  return ir.ArrayAttr.get(layouts)
def _check_mosaic_partitioning(axis_context, has_communication: bool) -> None:
  """Raises if a Mosaic kernel would need partitioning it cannot support.

  Raises:
    NotImplementedError: if the SPMD context has non-manual mesh axes, the
      sharding context spans more than one device, or communication is
      requested in any other (replica) context.
  """
  if isinstance(axis_context, sharding_impls.SPMDAxisContext):
    needs_auto_partitioning = axis_context.manual_axes != frozenset(
        axis_context.mesh.axis_names
    )
  elif isinstance(axis_context, sharding_impls.ShardingContext):
    needs_auto_partitioning = axis_context.num_devices != 1
  else:
    if has_communication:
      raise NotImplementedError(
          "Replica lowering for Mosaic kernels not implemented."
      )
    needs_auto_partitioning = False
  if needs_auto_partitioning:
    raise NotImplementedError(
        "Mosaic kernels cannot be automatically partitioned. Please wrap the"
        " call in a shard_map or xmap."
    )


def _tpu_custom_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    config: CustomCallBackendConfig,
    kernel_name: str | None,
    kernel_regeneration_metadata: bytes | None,
    out_avals: Any,
    input_output_aliases: tuple[tuple[int, int], ...],
) -> ...:
  """Lowers ``tpu_custom_call`` to a ``stablehlo.custom_call`` op.

  Args:
    ctx: The lowering rule context.
    *in_nodes: IR values of the kernel operands.
    config: Backend config, serialized into the op's ``backend_config``.
    kernel_name: If not None, attached as the ``kernel_name`` attribute.
    kernel_regeneration_metadata: If not None, base64-encoded and attached as
      the ``kernel_regeneration_metadata`` attribute.
    out_avals: Abstract values of the kernel outputs.
    input_output_aliases: ``(input_index, output_index)`` pairs declaring
      which operands alias which results.

  Returns:
    The IR values of the kernel results.

  Raises:
    NotImplementedError: if the kernel would require automatic partitioning,
      or communication is requested outside an SPMD/sharding context.
  """
  i32_type = ir.IntegerType.get_signless(32)
  multiple_results = len(out_avals) > 1
  if multiple_results:
    result_type = ir.TupleType.get_tuple(
        [mlir.aval_to_ir_type(aval) for aval in out_avals]
    )
  else:
    result_type = mlir.aval_to_ir_type(out_avals[0])
  _check_mosaic_partitioning(
      ctx.module_context.axis_context, config.has_communication
  )
  call = stablehlo.CustomCallOp(
      [result_type],
      in_nodes,
      call_target_name=ir.StringAttr.get(b"tpu_custom_call"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(config.to_json()),
      api_version=ir.IntegerAttr.get(i32_type, 1),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=_avals_to_layouts(ctx.avals_in),
      result_layouts=_avals_to_layouts(ctx.avals_out),
      output_operand_aliases=ir.ArrayAttr.get([
          hlo.OutputOperandAlias.get(
              # With a single (non-tuple) result, the alias refers implicitly
              # to that result, so no tuple index is given.
              output_tuple_indices=[output_idx] if multiple_results else [],
              operand_index=input_idx,
              operand_tuple_indices=[],
          )
          for input_idx, output_idx in input_output_aliases
      ]),
  )
  # Add kernel_name and kernel_regeneration_metadata as attributes to the
  # custom call op. This is because we do not want to pollute the backend_config
  # with this information.
  if kernel_name is not None:
    call.attributes["kernel_name"] = ir.StringAttr.get(kernel_name)
  if kernel_regeneration_metadata is not None:
    call.attributes["kernel_regeneration_metadata"] = ir.StringAttr.get(
        base64.b64encode(kernel_regeneration_metadata)
    )
  if multiple_results:
    # The op returns a single tuple; unpack it into per-output values.
    return [
        stablehlo.get_tuple_element(call, mlir.i32_attr(i))
        for i in range(len(out_avals))
    ]
  return call.results
# The custom call lowering is only registered for the TPU platform.
mlir.register_lowering(
    tpu_custom_call_p, _tpu_custom_call_lowering, platform="tpu"
)
def _run_passes(module: ir.Module, passes: list[str], dump_name: str) -> None:
  """Runs textual pass specs on ``module`` in place, then dumps it.

  The pass manager is parsed in the currently active MLIR context;
  ``_lower_tpu_kernel`` calls this inside ``with module.context``.

  Args:
    module: The module to transform in place.
    passes: Pass specs, nested inside ``builtin.module(...)``.
    dump_name: Label passed to ``dump_mlir`` after the passes run.
  """
  pass_manager = PassManager.parse(f"builtin.module({','.join(passes)})")
  pass_manager.run(module.operation)
  dump_mlir(module, dump_name)


def _lower_tpu_kernel(
    module: ir.Module,
    hardware_generation: int,
) -> ir.Module:
  """Runs the initial Mosaic MLIR lowering passes on ``module`` from Python.

  The passes run in place, in this order:
    1. verification;
    2. hlo conversion (only when ``jax_mosaic_allow_hlo`` is set);
    3. memref layout inference;
    4. canonicalize/cse;
    5. bounds-check insertion (only if requested via the
       ``xla_mosaic_on_device_checks`` flag);
    6. vector layout inference and application;
    7. a final canonicalize.
  Intermediate modules are handed to ``dump_mlir``.

  Args:
    module: The MLIR module to lower. It is modified in place.
    hardware_generation: The TPU hardware generation to target.

  Returns:
    The lowered module (the same object as ``module``).

  Raises:
    ValueError: if ``module`` fails MLIR verification, or the
      ``xla_mosaic_on_device_checks`` flag names an unsupported check.
  """
  try:
    module.operation.verify()
  except ir.MLIRError as e:
    raise ValueError("The compiled module fails MLIR verification") from e

  with module.context as ctx, module.operation.location as _:
    ctx.append_dialect_registry(mlir.upstream_dialects)
    ctx.load_all_available_dialects()
    tpu.register_dialect(ctx)
    mhlo.register_mhlo_dialect(ctx)
    mhlo.register_mhlo_passes()

    dump_mlir(module, "original")

    if _MOSAIC_ALLOW_HLO.value:
      # Run hlo dialect conversion: hlo -> linalg -> vector.
      _run_passes(
          module,
          [
              "hlo-legalize-to-arithmetic",
              "func.func(hlo-legalize-to-linalg)",
              "func.func(linalg-vectorization)",
          ],
          "post-hlo-conversion",
      )

    _run_passes(
        module,
        [
            f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
        ],
        "post-infer-memref-layout",
    )
    _run_passes(module, ["canonicalize", "cse"], "post-simplify")

    try:
      on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value
    except KeyError:
      on_device_checks = False
    if on_device_checks:
      checks = set(on_device_checks.split(","))
      if checks != {"bounds"}:  # We only support one kind of checks now.
        checks.discard("bounds")
        raise ValueError(
            f"Unrecognized on-device check categories: {', '.join(checks)}"
        )
      _run_passes(
          module, ["func.func(debug-assert-insertion)"], "post-assert-insertion"
      )

    _run_passes(
        module,
        ["func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})"],
        "post-infer-vector-layout",
    )
    mxu_size = 128 if hardware_generation < 6 else 256
    _run_passes(
        module,
        [
            "func.func(tpu-apply-vector-layout{sublane-count=8 lane-count=128"
            f" hardware-generation={hardware_generation}"
            f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}"
            "})"
        ],
        "post-apply-vector-layout",
    )
    _run_passes(module, ["canonicalize"], "pre-lower-to-llo")
  return module
def as_tpu_kernel(
    module: ir.Module,
    out_type: Any,
    *,
    cost_estimate: CostEstimate | None = None,
    backend: str | xla_client.Client = "tpu",
    device_type: str | None = None,
    kernel_name: str | None = None,
    kernel_regeneration_metadata: bytes | None = None,
    vmem_limit_bytes: int | None = None,
    flags: dict[str, bool | int | float] | None = None,
    allow_input_fusion: list[bool] | None = None,
    input_output_aliases: tuple[tuple[int, int], ...] = (),
) -> Callable[..., Any]:
  """Turns an MLIR Mosaic kernel into a JAX-compatible function.

  ``module`` is cloned, so the caller's module is not mutated. The clone is
  optionally lowered with the Python pass pipeline (when
  ``mosaic_use_python_pipeline`` is set and ``device_type`` is unset), then
  serialized to bytecode and wrapped by ``_lowered_as_tpu_kernel``.

  Args:
    module: The Mosaic MLIR module implementing the kernel.
    out_type: A single output type or an iterable of them.
    cost_estimate: Optional cost hints for the compiler.
    backend: Backend whose first device determines the TPU generation.
    device_type: Optional device type; when set, layout passes are skipped.
    kernel_name: Optional name attached to the custom call.
    kernel_regeneration_metadata: Optional metadata attached to the custom
      call.
    vmem_limit_bytes: Optional scoped memory limit; must be an int.
    flags: Optional flag configs forwarded in the backend config.
    allow_input_fusion: Optional per-input fusion permissions.
    input_output_aliases: ``(input_index, output_index)`` alias pairs.

  Returns:
    The callable produced by ``_lowered_as_tpu_kernel``.

  Raises:
    ValueError: if the first device of ``backend`` is not a recognized TPU,
      or ``vmem_limit_bytes`` is not an int.
  """
  device_kind = jax.devices(backend)[0].device_kind
  # Take the full run of digits after "TPU v" (not just one character) so a
  # multi-digit generation is not truncated.
  generation_match = re.match(r"TPU v([0-9]+)", device_kind)
  if generation_match is None:
    raise ValueError(f"Unrecognized TPU device kind: {device_kind}.")
  hardware_generation = int(generation_match.group(1))
  if vmem_limit_bytes is not None and not isinstance(vmem_limit_bytes, int):
    raise ValueError(
        "vmem_limit_bytes must be an int: provided with a"
        f" {type(vmem_limit_bytes)}."
    )
  has_communication, has_custom_barrier = tpu.private_has_communication(
      module.operation
  )
  needs_hlo_passes = _MOSAIC_ALLOW_HLO.value
  needs_layout_passes = not device_type

  # We'll mutate the module, so clone it.
  with module.context as ctx, module.operation.location as _:
    module = ir.Module.parse(
        module.operation.get_asm(binary=True, enable_debug_info=True)
    )
    if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value:
      module = _lower_tpu_kernel(module, hardware_generation)
      # These passes have now run here, so they are not requested again.
      needs_hlo_passes = False
      needs_layout_passes = False
    # Temporarily allow unregistered dialects while running the serde pass,
    # restoring the previous setting afterwards.
    prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects
    ctx.allow_unregistered_dialects = True
    try:
      pipeline = PassManager.parse(
          "builtin.module(mosaic-serde{serialize=true})"
      )
      pipeline.run(module.operation)
    finally:
      ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects
    bytecode_buffer = io.BytesIO()
    module.operation.write_bytecode(bytecode_buffer, desired_version=0)
    asm = bytecode_buffer.getvalue()

  # kernel_name and kernel_regeneration_metadata travel as primitive params;
  # the lowering rule attaches them as custom call op attributes rather than
  # embedding them in the backend_config.
  return _lowered_as_tpu_kernel(
      asm,
      out_type,
      needs_hlo_passes=needs_hlo_passes,
      needs_layout_passes=needs_layout_passes,
      device_type=device_type,
      has_communication=has_communication,
      has_custom_barrier=has_custom_barrier,
      kernel_name=kernel_name,
      kernel_regeneration_metadata=kernel_regeneration_metadata,
      cost_estimate=cost_estimate,
      vmem_limit_bytes=vmem_limit_bytes,
      flags=flags,
      allow_input_fusion=allow_input_fusion,
      input_output_aliases=input_output_aliases,
  )
def _lowered_as_tpu_kernel(
    lowered_module_asm: bytes,
    out_type: Any,
    *,
    cost_estimate: CostEstimate | None = None,
    needs_hlo_passes: bool = False,
    needs_layout_passes: bool = False,
    device_type: str | None = None,
    has_communication: bool = False,
    has_custom_barrier: bool = False,
    kernel_name: str | None = None,
    kernel_regeneration_metadata: bytes | None = None,
    vmem_limit_bytes: int | None = None,
    flags: dict[str, bool | int | float] | None = None,
    allow_input_fusion: list[bool] | None = None,
    input_output_aliases: tuple[tuple[int, int], ...] = (),
    serialization_format: int | None = 1,
):
  """Wraps an already-serialized Mosaic kernel as a jitted JAX function.

  Args:
    lowered_module_asm: Serialized kernel module, stored in the backend config.
    out_type: A single output type or an iterable of them; each needs
      ``shape`` and ``dtype``. A single non-iterable type makes the returned
      function return one array instead of a sequence.
    **remaining keyword arguments: stored in ``CustomCallBackendConfig`` or
      passed as ``tpu_custom_call_p`` params.

  Returns:
    A ``jax.jit``-wrapped function taking the kernel operands plus an
    optional static ``collective_id`` keyword argument.
  """
  single_output = not isinstance(out_type, collections.abc.Iterable)
  out_types = (out_type,) if single_output else out_type
  out_avals = tuple(core.ShapedArray(t.shape, t.dtype) for t in out_types)

  def apply_kernel(*args, collective_id: int | None = None):
    # collective_id is required exactly when the kernel uses a custom barrier.
    if has_custom_barrier and collective_id is None:
      raise ValueError(
          "collective_id has to be specified when using a custom barrier"
      )
    if not has_custom_barrier and collective_id is not None:
      raise ValueError(
          "collective_id has to be unspecified or None when not using a"
          " custom barrier"
      )
    backend_config = CustomCallBackendConfig(
        lowered_module_asm=lowered_module_asm,
        has_communication=has_communication,
        collective_id=collective_id,
        device_type=device_type,
        cost_estimate=cost_estimate,
        needs_hlo_passes=needs_hlo_passes,
        needs_layout_passes=needs_layout_passes,
        vmem_limit_bytes=vmem_limit_bytes,
        flags=flags,
        allow_input_fusion=allow_input_fusion,
        serialization_format=serialization_format,
    )
    outputs = tpu_custom_call_p.bind(
        *args,
        config=backend_config,
        kernel_name=kernel_name,
        kernel_regeneration_metadata=kernel_regeneration_metadata,
        out_avals=out_avals,
        input_output_aliases=input_output_aliases,
    )
    if single_output:
      return outputs[0]
    return outputs

  # Wrap in jax.jit so repeated calls hit the fast compilation cache.
  # collective_id is static because it is baked into the backend config.
  return jax.jit(apply_kernel, static_argnames=["collective_id"])
def dump_mlir(module: ir.Module, name: str) -> None:
  """Writes ``module`` as text to a timestamped file, if dumping is enabled.

  Dumping is controlled by the ``xla_mosaic_dump_to`` flag; only the value
  ``"sponge"`` is handled. In that case the module is written to
  ``$TEST_UNDECLARED_OUTPUTS_DIR/<time_ns>-mosaic-dump-<name>.txt``, provided
  that environment variable is set and non-empty. Otherwise this is a no-op.

  Args:
    module: The MLIR module to dump.
    name: A label for the dump, included in the file name.
  """
  try:
    dump_to = FLAGS["xla_mosaic_dump_to"].value
  except KeyError:
    return  # Flag not defined (e.g. absl unavailable, so FLAGS is a dict).
  if dump_to != "sponge":
    return
  outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
  if not outdir:
    return
  path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}.txt")
  # Explicit UTF-8 so the dump does not depend on the platform locale encoding.
  with open(path, "w", encoding="utf-8") as f:
    f.write(str(module))