# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A small library of helpers for use in jaxlib to build MLIR operations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Sequence
|
|
from functools import partial
|
|
from typing import Union
|
|
|
|
import jaxlib.mlir.ir as ir
|
|
import jaxlib.mlir.dialects.stablehlo as hlo
|
|
import numpy as np
|
|
|
|
|
|

_dtype_to_ir_type_factory: dict[np.dtype, Callable[[], ir.Type]] = {
    np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
    np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
    np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
    np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
    np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
    np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
    np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
    np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
    np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
    np.dtype(np.float16): ir.F16Type.get,
    np.dtype(np.float32): ir.F32Type.get,
    np.dtype(np.float64): ir.F64Type.get,
    np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
    np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}

def dtype_to_ir_type(dtype) -> ir.Type:
  return _dtype_to_ir_type_factory[np.dtype(dtype)]()


def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type:
  return ir.RankedTensorType.get(shape, dtype_to_ir_type(dtype))
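
# Illustrative sketch (not part of the library): these helpers must run with
# an active MLIR context, under which the IR types are constructed, e.g.:
#
#   with ir.Context():
#     f32 = dtype_to_ir_type(np.float32)               # ir.F32Type
#     t = shape_dtype_to_ir_type((2, 3), np.float32)   # tensor<2x3xf32>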


# When we generate custom calls with dynamic shapes we must pass both the
# result_types, with ir.ShapedType.get_dynamic_size() in place of the dynamic
# dimensions, and also result_shapes, a list of ir.Value each representing a
# 1-D int32 tensor. If all the shapes are static we can use result_shapes=None.
# To that end, we first construct for each result a pair of its shape and its
# element type, where the shape contains either Python ints (static
# dimensions) or ir.Value (dynamic dimensions); see the illustrative sketch
# after mk_result_types_and_shapes below.
DimensionSize = Union[int, ir.Value]  # an ir.Value if the dimension is not static
ShapeTypePair = tuple[Sequence[DimensionSize], ir.Type]

def mk_result_types_and_shapes(
    shape_type_pairs: Sequence[ShapeTypePair]
) -> tuple[list[ir.Type], list[ir.Value] | None]:
  result_types: list[ir.Type] = []
  result_shapes: list[ir.Value] = []
  has_dynamic_shapes = any(
      any(not isinstance(d, int) for d in rshape)
      for rshape, _ in shape_type_pairs)
  for (rshape, rtype) in shape_type_pairs:
    if has_dynamic_shapes:
      result_shapes.append(shape_tensor(rshape))
    result_types.append(
        ir.RankedTensorType.get(
            [d if isinstance(d, int) else ir.ShapedType.get_dynamic_size()
             for d in rshape],
            rtype))
  return (result_types,
          result_shapes if has_dynamic_shapes else None)
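
# Illustrative sketch (caller-side code, not part of the library): one result
# with a static shape and one whose leading dimension is an ir.Value `n`:
#
#   shape_type_pairs = [
#       ((3, 4), dtype_to_ir_type(np.float32)),
#       ((n, 4), dtype_to_ir_type(np.float32)),
#   ]
#   result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
#   # result_types[1] is tensor<?x4xf32>; result_shapes is non-None because a
#   # dynamic dimension is present, and can be forwarded to custom_call below.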


# TODO(necula): share this with mlir.shape_tensor
def shape_tensor(sizes: Sequence[int | ir.Value]) -> ir.Value:
  int1d = shape_dtype_to_ir_type((1,), np.int32)
  i32_type = shape_dtype_to_ir_type((), np.int32)
  def dim_to_i32x1(d):
    if type(d) is int:
      return hlo_const(np.array([d], dtype=np.int32))
    else:
      if d.type != i32_type:
        d = hlo.convert(i32_type, d)
      return hlo.reshape(int1d, d)
  ds = [dim_to_i32x1(sz) for sz in sizes]
  if not ds:
    return hlo_const(np.array([], np.int32))
  elif len(ds) == 1:
    return ds[0]
  else:
    return hlo.concatenate(
        ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0))
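
# Illustrative sketch (not part of the library): mixing static and dynamic
# sizes, where `d` is an ir.Value holding a scalar integer dimension:
#
#   shp = shape_tensor((2, d, 8))   # an ir.Value of type tensor<3xi32>
#
# Static entries become i32 constants; dynamic entries are converted to i32 if
# needed, reshaped to tensor<1xi32>, and everything is concatenated along
# dimension 0.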


def hlo_const(x: np.ndarray) -> ir.Value:
  assert isinstance(x, np.ndarray)
  return hlo.constant(
      ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype)))


def hlo_u8(x: int):
  return hlo_const(np.array(x, dtype=np.uint8))

def hlo_s32(x: int):
  return hlo_const(np.array(x, dtype=np.int32))


def ensure_hlo_s32(x: DimensionSize):
  return hlo_s32(x) if isinstance(x, int) else x


def dense_int_array(xs) -> ir.DenseI64ArrayAttr:
  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))


def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  if type(x) is int:
    if type(y) is int:
      return min(x, y)
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.minimum(x, y)


def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  if type(x) is int:
    if type(y) is int:
      return x + y
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.add(x, y)
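
# Illustrative sketch (not part of the library): both helpers fold fully
# static operands in Python and otherwise emit stablehlo ops on i32 scalars:
#
#   hlo_add(3, 4)    # the Python int 7, no IR emitted
#   hlo_add(3, d)    # an ir.Value: stablehlo.add(hlo_s32(3), d)
#   hlo_min(d, 16)   # an ir.Value: stablehlo.minimum(d, hlo_s32(16))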


# TODO(necula): this is identical to mlir.custom_call, but meant for use
# in jaxlib. Find a way to share these implementations.
def custom_call(
    call_target_name: str,
    *,
    result_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    backend_config: str | bytes | dict[str, ir.Attribute] = "",
    has_side_effect: bool = False,
    result_shapes: Sequence[ir.Value] | None = None,
    called_computations: Sequence[str] = (),
    api_version: int = 2,
    operand_output_aliases: dict[int, int] | None = None,
    operand_layouts: Sequence[Sequence[int]] | None = None,
    result_layouts: Sequence[Sequence[int]] | None = None,
    extra_attributes: dict[str, ir.Attribute] | None = None,
) -> ir.Operation:
  """Helper function for building an hlo.CustomCall.

  Args:
    call_target_name: the name of the custom call target
    result_types: the MLIR types of the results of the custom call
    operands: the MLIR IR values that are arguments to the custom call
    backend_config: an opaque string passed to the custom call kernel
    has_side_effect: if True, marks the custom call as effectful
    result_shapes: tensors that represent the result shapes, to be used when
      the results have dynamic shapes. If not None, its length must match the
      number of results.
    called_computations: the list of function names called by the custom call.
    api_version: the ABI contract version of the custom call
    operand_output_aliases: a dict mapping operand numbers to the outputs they
      alias
    operand_layouts: a sequence of layouts (dimension orders) for each operand
    result_layouts: a sequence of layouts (dimension orders) for each result
    extra_attributes: additional IR attributes to apply to the custom_call.
  """
  operands = list(operands)

  if backend_config is None:
    backend_config_attr = ir.StringAttr.get("")
  elif isinstance(backend_config, (str, bytes)):
    backend_config_attr = ir.StringAttr.get(backend_config)
  elif isinstance(backend_config, dict):
    # TODO(necula): it seems that the CustomCallOp constructor requires that
    # backend_config_attr be a string attribute, even though in some cases we
    # need it to be a DictAttr, e.g., for ApproxTopK on TPU:
    # "Verification failed: 'stablehlo.custom_call' op attribute
    # 'backend_config' failed to satisfy constraint: string attribute".
    # To work around this limitation we first set it to the empty string and
    # use the unregistered attribute mhlo.backend_config to hold the DictAttr.
    # We must also use api_version=1 to ensure that mhlo.backend_config is
    # handled properly.
    backend_config_attr = ir.StringAttr.get("")
    api_version = 1
  else:
    raise ValueError(
        "custom_call backend_config unexpected type: " + str(backend_config))
  attributes = dict(
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      backend_config=backend_config_attr,
      api_version=ir.IntegerAttr.get(
          ir.IntegerType.get_signless(32), api_version),
      called_computations=ir.ArrayAttr.get(
          [ir.FlatSymbolRefAttr.get(name) for name in called_computations]
      ),
  )
  if operand_output_aliases is not None:
    attributes["output_operand_aliases"] = ir.ArrayAttr.get([
        hlo.OutputOperandAlias.get(
            # If len(result_types) == 1 then the aliasing refers implicitly to
            # the only output.
            output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
            operand_index=input_idx,
            operand_tuple_indices=[],
        )
        for input_idx, output_idx in operand_output_aliases.items()
    ])

  if extra_attributes is not None:
    attributes.update(extra_attributes)

  if result_shapes is not None:
    # We append the result_shapes to the operands and must pass the
    # indices_of_shape_operands attribute. This attribute is not yet accepted
    # by the CustomCallOp constructor, so we use build_generic below.
    attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
        np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
                   dtype=np.int64))
    if operand_layouts is not None:
      assert len(operand_layouts) == len(operands), (operand_layouts, operands)
      operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes)
    operands = list(operands) + list(result_shapes)

  if operand_layouts is not None:
    attributes["operand_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in operand_layouts
    ])
  if result_layouts is not None:
    assert len(result_layouts) == len(result_types), (
        result_layouts, result_types)
    attributes["result_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in result_layouts
    ])

  op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
                                      attributes=attributes)
  if isinstance(backend_config, dict):
    backend_config_attr = ir.DictAttr.get(backend_config)
    op.operation.attributes["mhlo.backend_config"] = backend_config_attr
  return op
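
# Illustrative sketch (caller-side lowering code, not part of the library): a
# custom call with one f32 result of shape (3, 4), where `x` is an ir.Value
# and "my_kernel" is a hypothetical registered custom-call target:
#
#   op = custom_call(
#       "my_kernel",
#       result_types=[shape_dtype_to_ir_type((3, 4), np.float32)],
#       operands=[x],
#       operand_layouts=[(1, 0)],
#       result_layouts=[(1, 0)],
#   )
#   result = op.results[0]
#
# For dynamically shaped results, pass result_shapes as produced by
# mk_result_types_and_shapes; a dict backend_config is attached as the
# unregistered attribute mhlo.backend_config and forces api_version=1.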