mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call().
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities. Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules. This function has two benefits over just building the stablehlo directly: a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults). Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper. PiperOrigin-RevId: 561042402
This commit is contained in:
parent
ba8c4c3086
commit
d0a6813ea2
@ -1157,9 +1157,8 @@ def wrap_with_memory_kind(
|
||||
result_type = x.type
|
||||
else:
|
||||
result_type = aval_to_ir_type(aval_out)
|
||||
op = custom_call("annotate_device_placement", [result_type], [x],
|
||||
has_side_effect=False,
|
||||
api_version=1)
|
||||
op = custom_call("annotate_device_placement", result_types=[result_type],
|
||||
operands=[x], api_version=1)
|
||||
mka = get_compute_type(memory_kind)
|
||||
dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)}
|
||||
if is_input and mka == 'host':
|
||||
@ -1698,9 +1697,8 @@ def _wrap_with_spmd_op(name: str,
|
||||
else:
|
||||
result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)]
|
||||
|
||||
op = custom_call(name, [result_type], [x],
|
||||
op = custom_call(name, result_types=[result_type], operands=[x],
|
||||
backend_config=backend_config,
|
||||
has_side_effect=False,
|
||||
api_version=1,
|
||||
result_shapes=result_shapes)
|
||||
set_sharding(op, sharding_proto)
|
||||
@ -2153,27 +2151,42 @@ def build_xla_computation_helper(
|
||||
|
||||
def custom_call(
|
||||
call_target_name: str,
|
||||
out_types: Sequence[ir.Type],
|
||||
operands: Sequence[ir.Value],
|
||||
*,
|
||||
backend_config: str | dict[str, ir.Attribute] = "",
|
||||
result_types: Sequence[ir.Type],
|
||||
operands: Sequence[ir.Value],
|
||||
backend_config: str | bytes | dict[str, ir.Attribute] = "",
|
||||
has_side_effect: bool = False,
|
||||
result_shapes: Sequence[ir.Value] | None = None,
|
||||
called_computations: Sequence[str] = (),
|
||||
api_version: int = 2,
|
||||
extra_attributes: dict[str, ir.Attribute] = {},
|
||||
operand_output_aliases: dict[int, int] | None = None,
|
||||
operand_layouts: Sequence[Sequence[int]] | None = None,
|
||||
result_layouts: Sequence[Sequence[int]] | None = None,
|
||||
extra_attributes: dict[str, ir.Attribute] | None = None,
|
||||
) -> ir.Operation:
|
||||
"""Wraps a hlo.CustomCall.
|
||||
"""Helper function for building an hlo.CustomCall.
|
||||
|
||||
Args:
|
||||
call_target_name: the name of the custom call target
|
||||
result_types: the MLIR types of the results of the custom call
|
||||
operands: the MLIR IR values that are arguments to the custom call
|
||||
backend_config: an opaque string passed to the custom call kernel
|
||||
has_side_effect: if True, marks the custom call as effectful
|
||||
result_shapes: tensors that represent the result shapes, to be used when
|
||||
the results have dynamic shapes. If not-None, its length must match the
|
||||
number of the results.
|
||||
called_computations: the list of function names called by the custom call.
|
||||
api_version: the ABI contract version of the custom call
|
||||
operand_output_aliases: a dict mapping operand numbers to outputs they alias
|
||||
operand_layouts: a sequence of layouts (dimension orders) for each operand
|
||||
result_layouts: a sequence of layouts (dimension orders) for each result
|
||||
extra_attributes: additional IR attributes to apply to the custom_call.
|
||||
"""
|
||||
operands = list(operands)
|
||||
|
||||
if backend_config is None:
|
||||
backend_config_attr = ir.StringAttr.get("")
|
||||
elif isinstance(backend_config, str):
|
||||
elif isinstance(backend_config, (str, bytes)):
|
||||
backend_config_attr = ir.StringAttr.get(backend_config)
|
||||
elif isinstance(backend_config, dict):
|
||||
# TODO(necula): it seems that the CustomCallOp constructor requires that
|
||||
@ -2193,9 +2206,23 @@ def custom_call(
|
||||
has_side_effect=ir.BoolAttr.get(has_side_effect),
|
||||
backend_config=backend_config_attr,
|
||||
api_version=i32_attr(api_version),
|
||||
called_computations=ir.ArrayAttr.get([
|
||||
ir.FlatSymbolRefAttr.get(name) for name in called_computations]),
|
||||
called_computations=ir.ArrayAttr.get(
|
||||
[ir.FlatSymbolRefAttr.get(name) for name in called_computations]
|
||||
),
|
||||
)
|
||||
if operand_output_aliases is not None:
|
||||
attributes["output_operand_aliases"] = ir.ArrayAttr.get([
|
||||
hlo.OutputOperandAlias.get(
|
||||
# if len(result_types) == 1 then the aliasing refers implicitly to
|
||||
# the only output.
|
||||
output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
|
||||
operand_index=input_idx,
|
||||
operand_tuple_indices=[],
|
||||
)
|
||||
for input_idx, output_idx in (operand_output_aliases.items() or ())
|
||||
])
|
||||
|
||||
if extra_attributes is not None:
|
||||
attributes.update(extra_attributes)
|
||||
|
||||
if result_shapes is not None:
|
||||
@ -2205,9 +2232,29 @@ def custom_call(
|
||||
attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
|
||||
np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
|
||||
dtype=np.int64))
|
||||
if operand_layouts is not None:
|
||||
assert len(operand_layouts) == len(operands), (operand_layouts, operands)
|
||||
operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes)
|
||||
operands = list(operands) + list(result_shapes)
|
||||
|
||||
op = hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
|
||||
if operand_layouts is not None:
|
||||
attributes["operand_layouts"] = ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(
|
||||
np.atleast_1d(np.asarray(l, dtype=np.int64)),
|
||||
type=ir.IndexType.get()) for l in operand_layouts
|
||||
])
|
||||
if result_layouts is not None:
|
||||
assert result_layouts is not None
|
||||
assert len(result_layouts) == len(result_types), (
|
||||
result_layouts, result_types)
|
||||
attributes["result_layouts"] = ir.ArrayAttr.get([
|
||||
ir.DenseIntElementsAttr.get(
|
||||
np.atleast_1d(np.asarray(l, dtype=np.int64)),
|
||||
type=ir.IndexType.get()) for l in result_layouts
|
||||
])
|
||||
|
||||
op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
|
||||
attributes=attributes)
|
||||
if isinstance(backend_config, dict):
|
||||
backend_config_attr = ir.DictAttr.get(backend_config)
|
||||
op.operation.attributes["mhlo.backend_config"] = backend_config_attr
|
||||
@ -2251,8 +2298,8 @@ def reduce_window(
|
||||
|
||||
rw = custom_call(
|
||||
"stablehlo.dynamic_reduce_window",
|
||||
list(map(aval_to_ir_type, out_avals)),
|
||||
[
|
||||
result_types=list(map(aval_to_ir_type, out_avals)),
|
||||
operands=[
|
||||
*operands, *init_values,
|
||||
eval_dynamic_shape_as_tensor(ctx, window_dimensions),
|
||||
eval_dynamic_shape_as_tensor(ctx, window_strides),
|
||||
|
@ -347,8 +347,8 @@ def _approx_top_k_lowering(ctx, operand, *, k,
|
||||
|
||||
out = mlir.custom_call(
|
||||
"ApproxTopK",
|
||||
[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
[operand, iota, init_val, init_arg],
|
||||
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
operands=[operand, iota, init_val, init_arg],
|
||||
called_computations=[comparator.name.value],
|
||||
backend_config=backend_config,
|
||||
result_shapes=result_shapes)
|
||||
|
@ -4233,9 +4233,9 @@ def _top_k_lower(ctx, operand, k):
|
||||
out_values_aval, out_indices_aval, = ctx.avals_out
|
||||
return mlir.custom_call(
|
||||
"stablehlo.dynamic_top_k",
|
||||
[mlir.aval_to_ir_type(out_values_aval),
|
||||
result_types=[mlir.aval_to_ir_type(out_values_aval),
|
||||
mlir.aval_to_ir_type(out_indices_aval)],
|
||||
[operand, k_value]).results
|
||||
operands=[operand, k_value]).results
|
||||
|
||||
mlir.register_lowering(top_k_p, _top_k_lower)
|
||||
ad.primitive_jvps[top_k_p] = _top_k_jvp
|
||||
@ -4499,9 +4499,9 @@ def _rng_bit_generator_lowering(
|
||||
mlir.eval_dynamic_shape(ctx, out_vals_aval.shape))
|
||||
out_key, out_vals = mlir.custom_call(
|
||||
"stablehlo.dynamic_rng_bit_generator",
|
||||
[key.type,
|
||||
result_types=[key.type,
|
||||
mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))],
|
||||
[key, output_shape],
|
||||
operands=[key, output_shape],
|
||||
extra_attributes=dict(rng_algorithm=algorithm_attr)).results
|
||||
else:
|
||||
out_key, out_vals = hlo.RngBitGeneratorOp(
|
||||
|
@ -644,8 +644,8 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
|
||||
result_shapes = None
|
||||
op = mlir.custom_call(
|
||||
"Eigh",
|
||||
result_types,
|
||||
[operand],
|
||||
result_types=result_types,
|
||||
operands=[operand],
|
||||
backend_config=backend_config,
|
||||
api_version=1,
|
||||
result_shapes=result_shapes,
|
||||
@ -1301,8 +1301,8 @@ def _lu_tpu_lowering_rule(ctx, operand):
|
||||
result_shapes = None
|
||||
op = mlir.custom_call(
|
||||
"LuDecomposition",
|
||||
result_types,
|
||||
[operand],
|
||||
result_types=result_types,
|
||||
operands=[operand],
|
||||
result_shapes=result_shapes)
|
||||
return op.results
|
||||
|
||||
@ -1436,8 +1436,8 @@ def _geqrf_lowering_rule(ctx, operand):
|
||||
result_shapes = None
|
||||
op = mlir.custom_call(
|
||||
"Qr",
|
||||
result_types,
|
||||
[operand],
|
||||
result_types=result_types,
|
||||
operands=[operand],
|
||||
api_version=1,
|
||||
result_shapes=result_shapes
|
||||
)
|
||||
@ -1561,8 +1561,8 @@ def _householder_product_lowering_rule(ctx, a, taus):
|
||||
result_shapes = None
|
||||
op = mlir.custom_call(
|
||||
"ProductOfElementaryHouseholderReflectors",
|
||||
[mlir.aval_to_ir_type(aval_out)],
|
||||
[a, taus],
|
||||
result_types=[mlir.aval_to_ir_type(aval_out)],
|
||||
operands=[a, taus],
|
||||
api_version=1,
|
||||
result_shapes=result_shapes)
|
||||
return [op.result]
|
||||
|
@ -830,8 +830,8 @@ def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext,
|
||||
error_message: str):
|
||||
op = mlir.custom_call(
|
||||
"shape_assertion",
|
||||
[], # No results
|
||||
[assert_what, *error_message_inputs],
|
||||
result_types=[], # No results
|
||||
operands=[assert_what, *error_message_inputs],
|
||||
has_side_effect=True,
|
||||
extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message))
|
||||
)
|
||||
|
@ -31,6 +31,7 @@ from jax._src.interpreters.mlir import (
|
||||
aval_to_ir_type as aval_to_ir_type,
|
||||
aval_to_ir_types as aval_to_ir_types,
|
||||
core_call_lowering as core_call_lowering,
|
||||
custom_call as custom_call,
|
||||
dense_bool_elements as dense_bool_elements,
|
||||
dense_int_elements as dense_int_elements,
|
||||
dtype_to_ir_type as dtype_to_ir_type,
|
||||
|
83
tests/filecheck/custom_call.filecheck.py
Normal file
83
tests/filecheck/custom_call.filecheck.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Tests for mlir.custom_call().
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from absl import app
|
||||
|
||||
import jax
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
import numpy as np
|
||||
|
||||
ShapedArray = jax.core.ShapedArray
|
||||
|
||||
def print_custom_call(name, arg_avals, result_avals, **kw):
|
||||
print(f"TEST: {name}")
|
||||
ctx = mlir.make_ir_context()
|
||||
loc = ir.Location.unknown(context=ctx)
|
||||
with ctx, loc:
|
||||
module = ir.Module.create(loc=ir.Location.unknown())
|
||||
ip = ir.InsertionPoint(module.body)
|
||||
arg_types = [mlir.aval_to_ir_type(aval) for aval in arg_avals]
|
||||
result_types = [mlir.aval_to_ir_type(aval) for aval in result_avals]
|
||||
ftype = ir.FunctionType.get(arg_types, result_types)
|
||||
func = func_dialect.FuncOp("func", ftype, ip=ip)
|
||||
entry_block = func.add_entry_block()
|
||||
with ir.InsertionPoint(entry_block):
|
||||
outs = mlir.custom_call(
|
||||
name, result_types=result_types, operands=entry_block.arguments, **kw
|
||||
)
|
||||
func_dialect.ReturnOp(outs)
|
||||
module.operation.verify()
|
||||
print(str(module))
|
||||
|
||||
def main(_):
|
||||
aval1 = ShapedArray((2, 3), np.dtype(np.float32))
|
||||
aval2 = ShapedArray((3, 4), np.dtype(np.int64))
|
||||
# CHECK-LABEL: TEST: simple
|
||||
# CHECK: stablehlo.custom_call @simple(%arg0) {api_version = 2 : i32} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
print_custom_call("simple", [aval1], [aval2])
|
||||
|
||||
# CHECK-LABEL: TEST: sideeffect
|
||||
# CHECK: stablehlo.custom_call @sideeffect(%arg0) {has_side_effect = true} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
print_custom_call("sideeffect", [aval1], [aval2], api_version=1,
|
||||
has_side_effect=True)
|
||||
|
||||
# CHECK-LABEL: TEST: backendconfig
|
||||
# CHECK: stablehlo.custom_call @backendconfig(%arg0) {backend_config = "hello"} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
print_custom_call("backendconfig", [aval1], [aval2], api_version=1,
|
||||
backend_config=b"hello")
|
||||
|
||||
# CHECK-LABEL: TEST: calledcomputations
|
||||
# CHECK: stablehlo.custom_call @calledcomputations(%arg0) {called_computations = [@a, @b]} : (tensor<2x3xf32>) -> tensor<3x4xi64>
|
||||
print_custom_call("calledcomputations", [aval1], [aval2], api_version=1,
|
||||
called_computations=["a", "b"])
|
||||
|
||||
# CHECK-LABEL: TEST: aliases
|
||||
# CHECK: stablehlo.custom_call @aliases(%arg0, %arg1) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 1, operand_tuple_indices = []>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
|
||||
print_custom_call("aliases", [aval1, aval2], [aval2, aval1], api_version=1,
|
||||
operand_output_aliases={1: 0})
|
||||
|
||||
# CHECK-LABEL: TEST: layouts
|
||||
# CHECK: stablehlo.custom_call @layouts(%arg0, %arg1) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>]} : (tensor<2x3xf32>, tensor<3x4xi64>) -> (tensor<3x4xi64>, tensor<2x3xf32>)
|
||||
print_custom_call("layouts", [aval1, aval2], [aval2, aval1], api_version=1,
|
||||
operand_layouts=[[0, 1], [1, 0]],
|
||||
result_layouts=[[1, 0], [0, 1]])
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
Loading…
x
Reference in New Issue
Block a user