# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A small library of helpers for use in jaxlib to build MLIR operations."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
from typing import Union
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np

_dtype_to_ir_type_factory: dict[np.dtype, Callable[[], ir.Type]] = {
    np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
    np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
    np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
    np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
    np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
    np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
    np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
    np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
    np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
    np.dtype(np.float16): ir.F16Type.get,
    np.dtype(np.float32): ir.F32Type.get,
    np.dtype(np.float64): ir.F64Type.get,
    np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
    np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}

def dtype_to_ir_type(dtype) -> ir.Type:
  """Returns the MLIR element type corresponding to a NumPy dtype."""
  return _dtype_to_ir_type_factory[np.dtype(dtype)]()

def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type:
  """Returns a ranked tensor type with the given shape and element dtype."""
  return ir.RankedTensorType.get(shape, dtype_to_ir_type(dtype))

# When we generate custom calls with dynamic shapes we have to pass both the
# result_types, with ir.ShapedType.get_dynamic_size() in place of the dynamic
# dimensions, and also result_shapes, which are ir.Values representing 1-D
# int32 shape tensors. If all the shapes are static we can use
# result_shapes=None. We first construct, for each result, a pair of the shape
# and the element type, where the shape contains either ints or ir.Values.
DimensionSize = Union[int, ir.Value]  # an ir.Value when the dimension is not static
ShapeTypePair = tuple[Sequence[DimensionSize], ir.Type]

def mk_result_types_and_shapes(
    shape_type_pairs: Sequence[ShapeTypePair]
) -> tuple[list[ir.Type], list[ir.Value] | None]:
  """Creates the result types and (if needed) result shape tensors."""
  result_types: list[ir.Type] = []
  result_shapes: list[ir.Value] = []
  has_dynamic_shapes = any(
      any(not isinstance(d, int) for d in rshape)
      for rshape, _ in shape_type_pairs)
  for (rshape, rtype) in shape_type_pairs:
    if has_dynamic_shapes:
      result_shapes.append(shape_tensor(rshape))
    result_types.append(
        ir.RankedTensorType.get(
            [d if isinstance(d, int) else ir.ShapedType.get_dynamic_size()
             for d in rshape],
            rtype))
  return (result_types,
          result_shapes if has_dynamic_shapes else None)
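
# Illustrative usage sketch (not part of the library): build the result types
# and shape tensors for one f32 result whose leading dimension is dynamic.
# Assumes an active MLIR context and insertion point; `batch` is a
# hypothetical ir.Value holding the dynamic size.
#
#   shape_type_pairs = [((batch, 4), dtype_to_ir_type(np.float32))]
#   result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
#   # result_types[0] is tensor<?x4xf32> (the dynamic dimension is encoded
#   # with ir.ShapedType.get_dynamic_size()); result_shapes[0] is the 1-D i32
#   # shape tensor built by shape_tensor((batch, 4)).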

# TODO(necula): share this with mlir.shape_tensor
def shape_tensor(sizes: Sequence[int | ir.Value]) -> ir.Value:
  """Builds a 1-D i32 shape tensor from a mix of ints and ir.Values."""
  int1d = shape_dtype_to_ir_type((1,), np.int32)
  i32_type = shape_dtype_to_ir_type((), np.int32)
  def dim_to_i32x1(d):
    if type(d) is int:
      return hlo_const(np.array([d], dtype=np.int32))
    else:
      if d.type != i32_type:
        d = hlo.convert(i32_type, d)
      return hlo.reshape(int1d, d)
  ds = [dim_to_i32x1(sz) for sz in sizes]
  if not ds:
    return hlo_const(np.array([], np.int32))
  elif len(ds) == 1:
    return ds[0]
  else:
    return hlo.concatenate(
        ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0))
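
# Illustrative sketch (assumes an active MLIR context/insertion point; `n` is
# a hypothetical rank-0 integer tensor ir.Value):
#
#   shape_tensor((2, 3))  # 1-D i32 tensor with values [2, 3], built by
#                         # concatenating per-dimension constants
#   shape_tensor((2, n))  # converts `n` to i32 if needed, reshapes it to
#                         # tensor<1xi32>, and concatenates it after hlo_const([2])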

def hlo_const(x: np.ndarray) -> ir.Value:
  assert isinstance(x, np.ndarray)
  return hlo.constant(
      ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype)))

def hlo_u8(x: int):
  return hlo_const(np.array(x, dtype=np.uint8))

def hlo_s32(x: int):
  return hlo_const(np.array(x, dtype=np.int32))

def ensure_hlo_s32(x: DimensionSize):
  return hlo_s32(x) if isinstance(x, int) else x

def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))

def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  if type(x) is int:
    if type(y) is int:
      return min(x, y)
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.minimum(x, y)

def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  if type(x) is int:
    if type(y) is int:
      return x + y
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.add(x, y)
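
# Note: hlo_min and hlo_add fold to plain Python ints when both operands are
# static, e.g. hlo_add(3, 4) == 7; with an ir.Value operand `d`,
# hlo_add(3, d) emits hlo.add(hlo_s32(3), d) and returns an ir.Value.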

# TODO(necula): this is identical with mlir.custom_call, but meant for use
# in jaxlib. Find a way to share these implementations.
def custom_call(
    call_target_name: str,
    *,
    result_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    backend_config: str | bytes | dict[str, ir.Attribute] = "",
    has_side_effect: bool = False,
    result_shapes: Sequence[ir.Value] | None = None,
    called_computations: Sequence[str] = (),
    api_version: int = 2,
    operand_output_aliases: dict[int, int] | None = None,
    operand_layouts: Sequence[Sequence[int]] | None = None,
    result_layouts: Sequence[Sequence[int]] | None = None,
    extra_attributes: dict[str, ir.Attribute] | None = None,
) -> ir.Operation:
"""Helper function for building an hlo.CustomCall.
Args:
call_target_name: the name of the custom call target
result_types: the MLIR types of the results of the custom call
operands: the MLIR IR values that are arguments to the custom call
backend_config: an opaque string passed to the custom call kernel
has_side_effect: if True, marks the custom call as effectful
result_shapes: tensors that represent the result shapes, to be used when
the results have dynamic shapes. If not-None, its length must match the
number of the results.
called_computations: the list of function names called by the custom call.
api_version: the ABI contract version of the custom call
operand_output_aliases: a dict mapping operand numbers to outputs they alias
operand_layouts: a sequence of layouts (dimension orders) for each operand
result_layouts: a sequence of layouts (dimension orders) for each result
extra_attributes: additional IR attributes to apply to the custom_call.
"""
  operands = list(operands)
  if backend_config is None:
    backend_config_attr = ir.StringAttr.get("")
  elif isinstance(backend_config, (str, bytes)):
    backend_config_attr = ir.StringAttr.get(backend_config)
  elif isinstance(backend_config, dict):
    # TODO(necula): it seems that the CustomCallOp constructor requires that
    # backend_config_attr be a string attribute, even though in some cases we
    # need it to be a DictAttr, e.g., for ApproxTopK on TPU.
    # "Verification failed: 'stablehlo.custom_call' op attribute
    # 'backend_config' failed to satisfy constraint: string attribute"
    # To work around this limitation we first set it to the empty string and
    # use an unregistered attribute mhlo.backend_config to hold the DictAttr.
    # We must also use api_version=1 to ensure that mhlo.backend_config is
    # handled properly.
    backend_config_attr = ir.StringAttr.get("")
    api_version = 1
  else:
    raise ValueError(
        "custom_call backend_config unexpected type: " + str(backend_config))
  attributes = dict(
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      backend_config=backend_config_attr,
      api_version=ir.IntegerAttr.get(
          ir.IntegerType.get_signless(32), api_version),
      called_computations=ir.ArrayAttr.get(
          [ir.FlatSymbolRefAttr.get(name) for name in called_computations]
      ),
  )
  if operand_output_aliases is not None:
    attributes["output_operand_aliases"] = ir.ArrayAttr.get([
        hlo.OutputOperandAlias.get(
            # if len(result_types) == 1 then the aliasing refers implicitly to
            # the only output.
            output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
            operand_index=input_idx,
            operand_tuple_indices=[],
        )
        for input_idx, output_idx in (operand_output_aliases.items() or ())
    ])

  if extra_attributes is not None:
    attributes.update(extra_attributes)

  if result_shapes is not None:
    # We add the result_shapes at the end of the operands, and must pass
    # the indices_of_shape_operands attribute. This attribute is not yet
    # accepted by the CustomCall constructor, so we use build_generic.
    attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
        np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
                   dtype=np.int64))
    if operand_layouts is not None:
      assert len(operand_layouts) == len(operands), (operand_layouts, operands)
      operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes)
    operands = list(operands) + list(result_shapes)

  if operand_layouts is not None:
    attributes["operand_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in operand_layouts
    ])
  if result_layouts is not None:
    assert len(result_layouts) == len(result_types), (
        result_layouts, result_types)
    attributes["result_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in result_layouts
    ])

  op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
                                      attributes=attributes)
  if isinstance(backend_config, dict):
    backend_config_attr = ir.DictAttr.get(backend_config)
    op.operation.attributes["mhlo.backend_config"] = backend_config_attr
  return op
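
# Illustrative call-site sketch (not part of the library). The target name
# "example_kernel" and the operand `x` are hypothetical; assumes an active
# MLIR context and insertion point, with `x` an ir.Value of type
# tensor<128xf32>.
#
#   out_type = shape_dtype_to_ir_type((128,), np.float32)
#   op = custom_call(
#       "example_kernel",
#       result_types=[out_type],
#       operands=[x],
#       backend_config=b"",     # opaque bytes forwarded to the kernel
#       operand_layouts=[(0,)],
#       result_layouts=[(0,)],
#   )
#   results = op.results        # sequence of ir.Value results
#
# Passing backend_config as a dict of ir.Attributes stores it in the
# unregistered mhlo.backend_config attribute and forces api_version=1 (see
# above). For dynamically shaped results, also pass result_shapes (e.g. from
# mk_result_types_and_shapes); the shape tensors are appended to the operands
# and recorded in the indices_of_shape_operands attribute.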