mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

We are migrating some attrs on some StableHLO ops to use DenseI64ArrayAttr instead of DenseIntElementsAttr. Using DenseI64ArrayAttr enforces that the attr values are 1-dimensional and provides nicer APIs (see https://github.com/openxla/stablehlo/issues/1578 for additional context).

Unfortunately, we have to duplicate the `dense_int_array` function because we migrated the ops in batches: the existing `dense_int_array` function would produce arrays for ops that had not yet been migrated. This PR makes the final batch of changes, so no additional methods should be added going forward. We also have to introduce a new `dense_bool_array` function with a similar version check.

When the minimum supported jaxlib version uses a recent enough version of StableHLO (v6 or above), it will be possible to remove the version checks and the duplicated `dense_int_array_v6` function.

PiperOrigin-RevId: 601271749
258 lines
11 KiB
Python
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A small library of helpers for use in jaxlib to build MLIR operations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
from functools import partial
|
|
from typing import Callable, Union
|
|
|
|
import jaxlib.mlir.ir as ir
|
|
import jaxlib.mlir.dialects.stablehlo as hlo
|
|
import numpy as np
|
|
|
|
|
|

_dtype_to_ir_type_factory: dict[np.dtype, Callable[[], ir.Type]] = {
    np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
    np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
    np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
    np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
    np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
    np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
    np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
    np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
    np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
    np.dtype(np.float16): ir.F16Type.get,
    np.dtype(np.float32): ir.F32Type.get,
    np.dtype(np.float64): ir.F64Type.get,
    np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
    np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}

def dtype_to_ir_type(dtype) -> ir.Type:
  return _dtype_to_ir_type_factory[np.dtype(dtype)]()
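
# Editor's example (not part of the original file): with an active ir.Context
# and ir.Location, dtype_to_ir_type(np.float32) returns the F32Type singleton
# and dtype_to_ir_type(np.int32) a signless 32-bit IntegerType.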

def shape_dtype_to_ir_type(shape: Sequence[int], dtype) -> ir.Type:
  return ir.RankedTensorType.get(shape, dtype_to_ir_type(dtype))

# When we generate custom calls with dynamic shapes we have to pass both the
# result_types, with ir.ShapedType.get_dynamic_size in place of the dynamic
# dimensions, and also result_shapes, which are ir.Values representing 1D
# int32 tensors. If all the shapes are static we can use result_shapes=None.
# We first construct, for each result, a pair of the shape and the element
# type, where the shape contains either integers or ir.Values.
DimensionSize = Union[int, ir.Value]  # an ir.Value if not a static dimension
ShapeTypePair = tuple[Sequence[DimensionSize], ir.Type]

def mk_result_types_and_shapes(
    shape_type_pairs: Sequence[ShapeTypePair]
) -> tuple[list[ir.Type], list[ir.Value] | None]:
  result_types: list[ir.Type] = []
  result_shapes: list[ir.Value] = []
  has_dynamic_shapes = any(
      any(not isinstance(d, int) for d in rshape)
      for rshape, _ in shape_type_pairs)
  for (rshape, rtype) in shape_type_pairs:
    if has_dynamic_shapes:
      result_shapes.append(shape_tensor(rshape))
    result_types.append(
        ir.RankedTensorType.get(
            [d if isinstance(d, int) else ir.ShapedType.get_dynamic_size()
             for d in rshape],
            rtype))
  return (result_types,
          result_shapes if has_dynamic_shapes else None)
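
# Illustrative usage (an editor's sketch, not part of the original file).
# Assuming `n` is an i32 scalar ir.Value computed earlier, one static and one
# dynamic result could be described as:
#
#   f32 = dtype_to_ir_type(np.dtype(np.float32))
#   result_types, result_shapes = mk_result_types_and_shapes(
#       [((3, 4), f32), ((n, 4), f32)])
#
# result_shapes is None whenever every dimension is a Python int.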

# TODO(necula): share this with mlir.shape_tensor
def shape_tensor(sizes: Sequence[int | ir.Value]) -> ir.Value:
  int1d = shape_dtype_to_ir_type((1,), np.int32)
  i32_type = shape_dtype_to_ir_type((), np.int32)
  def dim_to_i32x1(d):
    if type(d) is int:
      return hlo_const(np.array([d], dtype=np.int32))
    else:
      if d.type != i32_type:
        d = hlo.convert(i32_type, d)
      return hlo.reshape(int1d, d)
  ds = [dim_to_i32x1(sz) for sz in sizes]
  if not ds:
    return hlo_const(np.array([], np.int32))
  elif len(ds) == 1:
    return ds[0]
  else:
    return hlo.concatenate(
        ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0))
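
# Illustrative behavior (editor's note, not part of the original file):
# shape_tensor((2, 3)) emits two 1-element i32 constants and concatenates them
# into a tensor<2xi32> value, while shape_tensor((d,)) with `d` a dynamic i32
# scalar ir.Value just reshapes d into tensor<1xi32>.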

def hlo_const(x: np.ndarray) -> ir.Value:
  assert isinstance(x, np.ndarray)
  return hlo.constant(
      ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype)))

def hlo_u8(x: int):
  return hlo_const(np.array(x, dtype=np.uint8))

def hlo_s32(x: int):
  return hlo_const(np.array(x, dtype=np.int32))

def ensure_hlo_s32(x: DimensionSize):
  return hlo_s32(x) if isinstance(x, int) else x

def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
  # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5
  # or higher.
  if hlo.get_api_version() < 5:
    return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))

# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher.
def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
  if hlo.get_api_version() < 6:
    return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
  return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
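
# The commit message above also introduces a dense_bool_array helper with a
# similar version check; it is not shown in this excerpt. A minimal sketch of
# the described pattern (an editor's assumption, the actual jaxlib code may
# differ):
#
#   def dense_bool_array(xs) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr:
#     if hlo.get_api_version() < 6:
#       # Older StableHLO: bools are bit-packed into an i1 elements attr.
#       return ir.DenseElementsAttr.get(
#           np.packbits(np.array(xs, np.bool_), bitorder='little'),
#           type=ir.IntegerType.get_signless(1), shape=[len(xs)])
#     return ir.DenseBoolArrayAttr.get(xs)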

def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  if type(x) is int:
    if type(y) is int:
      return min(x, y)
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.minimum(x, y)


def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
  if type(x) is int:
    if type(y) is int:
      return x + y
    x = hlo_s32(x)
  if type(y) is int:
    y = hlo_s32(y)
  return hlo.add(x, y)
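
# Editor's note (not part of the original file): these helpers fold fully
# static operands in Python, so hlo_add(2, 3) returns the int 5 and emits no
# IR, whereas hlo_add(2, d) with `d` an i32 scalar ir.Value lifts 2 via
# hlo_s32 and returns the ir.Value produced by hlo.add.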


# TODO(necula): this is identical to mlir.custom_call, but is meant for use
# in jaxlib. Find a way to share these implementations.
def custom_call(
    call_target_name: str,
    *,
    result_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    backend_config: str | bytes | dict[str, ir.Attribute] = "",
    has_side_effect: bool = False,
    result_shapes: Sequence[ir.Value] | None = None,
    called_computations: Sequence[str] = (),
    api_version: int = 2,
    operand_output_aliases: dict[int, int] | None = None,
    operand_layouts: Sequence[Sequence[int]] | None = None,
    result_layouts: Sequence[Sequence[int]] | None = None,
    extra_attributes: dict[str, ir.Attribute] | None = None,
) -> ir.Operation:
  """Helper function for building an hlo.CustomCall.

  Args:
    call_target_name: the name of the custom call target
    result_types: the MLIR types of the results of the custom call
    operands: the MLIR IR values that are arguments to the custom call
    backend_config: an opaque string passed to the custom call kernel
    has_side_effect: if True, marks the custom call as effectful
    result_shapes: tensors that represent the result shapes, to be used when
      the results have dynamic shapes. If not None, its length must match the
      number of results.
    called_computations: the list of function names called by the custom call.
    api_version: the ABI contract version of the custom call
    operand_output_aliases: a dict mapping operand numbers to the outputs they
      alias
    operand_layouts: a sequence of layouts (dimension orders), one per operand
    result_layouts: a sequence of layouts (dimension orders), one per result
    extra_attributes: additional IR attributes to apply to the custom_call.
  """
  operands = list(operands)

  if backend_config is None:
    backend_config_attr = ir.StringAttr.get("")
  elif isinstance(backend_config, (str, bytes)):
    backend_config_attr = ir.StringAttr.get(backend_config)
  elif isinstance(backend_config, dict):
    # TODO(necula): it seems that the CustomCallOp constructor requires that
    # backend_config_attr be a string attribute, even though in some cases we
    # need it to be a DictAttr, e.g., for ApproxTopK on TPU:
    # "Verification failed: 'stablehlo.custom_call' op attribute
    # 'backend_config' failed to satisfy constraint: string attribute".
    # To work around this limitation we first set it to the empty string and
    # use the unregistered attribute mhlo.backend_config to hold the DictAttr.
    # We must also use api_version=1 to ensure that mhlo.backend_config is
    # handled properly.
    backend_config_attr = ir.StringAttr.get("")
    api_version = 1
  else:
    raise ValueError(
        "custom_call backend_config unexpected type: " + str(backend_config))
  attributes = dict(
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      backend_config=backend_config_attr,
      api_version=ir.IntegerAttr.get(
          ir.IntegerType.get_signless(32), api_version),
      called_computations=ir.ArrayAttr.get(
          [ir.FlatSymbolRefAttr.get(name) for name in called_computations]
      ),
  )
  if operand_output_aliases is not None:
    attributes["output_operand_aliases"] = ir.ArrayAttr.get([
        hlo.OutputOperandAlias.get(
            # If len(result_types) == 1 then the aliasing refers implicitly to
            # the only output.
            output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
            operand_index=input_idx,
            operand_tuple_indices=[],
        )
        for input_idx, output_idx in (operand_output_aliases.items() or ())
    ])

  if extra_attributes is not None:
    attributes.update(extra_attributes)

  if result_shapes is not None:
    # We add the result_shapes at the end of the operands, and must pass
    # the indices_of_shape_operands attribute. This attribute is not yet
    # accepted by the CustomCallOp constructor, so we use build_generic.
    attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
        np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
                   dtype=np.int64))
    if operand_layouts is not None:
      assert len(operand_layouts) == len(operands), (operand_layouts, operands)
      operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes)
    operands = list(operands) + list(result_shapes)

  if operand_layouts is not None:
    attributes["operand_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in operand_layouts
    ])
  if result_layouts is not None:
    assert len(result_layouts) == len(result_types), (
        result_layouts, result_types)
    attributes["result_layouts"] = ir.ArrayAttr.get([
        ir.DenseIntElementsAttr.get(
            np.atleast_1d(np.asarray(l, dtype=np.int64)),
            type=ir.IndexType.get()) for l in result_layouts
    ])

  op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
                                      attributes=attributes)
  if isinstance(backend_config, dict):
    backend_config_attr = ir.DictAttr.get(backend_config)
    op.operation.attributes["mhlo.backend_config"] = backend_config_attr
  return op
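
# Illustrative usage (an editor's sketch; the target name and `x` are
# hypothetical, not from the original file):
#
#   out_type = shape_dtype_to_ir_type((128,), np.float32)
#   op = custom_call(
#       "my_kernel",                  # hypothetical custom call target
#       result_types=[out_type],
#       operands=[x],                 # x: an ir.Value built elsewhere
#       backend_config=b"",           # opaque bytes for the kernel
#       operand_layouts=[(0,)],
#       result_layouts=[(0,)],
#   )
#   result = op.results[0]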