rocm_jax/jaxlib/hlo_helpers.py
Eugene Burmako a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.
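
Item (c) can be tried with a short sketch like the one below. It assumes a JAX release that includes this change; the function `f` and its input are illustrative only. Only the default and `dialect="stablehlo"` paths are run, since `dialect="mhlo"` is described above as the non-guaranteed path.

```python
import jax
import jax.numpy as jnp


def f(x):
  # Illustrative function; any jittable computation works here.
  return jnp.sin(x) + 1.0


# Lower without compiling, so the emitted module can be inspected as text.
lowered = jax.jit(f).lower(jnp.float32(0.5))

# With StableHLO as the default dialect, both calls below are expected to
# emit `stablehlo.*` ops.
default_text = lowered.as_text()
stablehlo_text = lowered.as_text(dialect="stablehlo")

# Per item (c), `lowered.as_text(dialect="mhlo")` was still accepted at the
# time of this commit; it is left out here as a path without compatibility
# guarantees.
print(stablehlo_text)
```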

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Helpers for building MLIR operators

from typing import Dict, Optional, Sequence, Union

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np


def custom_call(
    call_target_name: str,
    out_types: Sequence[ir.Type],
    operands: Sequence[ir.Value],
    operand_layouts: Sequence[Sequence[int]],
    result_layouts: Sequence[Sequence[int]],
    backend_config: Optional[str] = None,
    has_side_effect: bool = False,
    api_version: int = 2,
    operand_output_aliases: Dict[int, int] = {},
) -> Union[ir.Value, Sequence[ir.Value]]:
  """Less-verbose helper for building a CustomCallOp.

  Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper
  may be able to go away.

  Args:
    ...
    operand_output_aliases: a dictionary mapping input numbers -> output
      numbers that must alias.
  """
  i32_type = ir.IntegerType.get_signless(32)
  out = hlo.CustomCallOp(
      # A single result keeps its own type; multiple results are wrapped in
      # one tuple type and unpacked again below.
      (out_types
       if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]),
      operands,
      call_target_name=ir.StringAttr.get(call_target_name),
      has_side_effect=ir.BoolAttr.get(has_side_effect),
      backend_config=ir.StringAttr.get(
          "" if backend_config is None else backend_config),
      api_version=ir.IntegerAttr.get(i32_type, api_version),
      called_computations=ir.ArrayAttr.get([]),
      # Each layout becomes a rank-1 dense attribute of index type.
      operand_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(
              np.atleast_1d(np.asarray(l, dtype=np.int64)),
              type=ir.IndexType.get()) for l in operand_layouts
      ]),
      result_layouts=ir.ArrayAttr.get([
          ir.DenseIntElementsAttr.get(
              np.atleast_1d(np.asarray(l, dtype=np.int64)),
              type=ir.IndexType.get()) for l in result_layouts
      ]),
      output_operand_aliases=ir.ArrayAttr.get([
          hlo.OutputOperandAlias.get(
              output_tuple_indices=[] if len(out_types) == 1 else [output],
              operand_index=input,
              operand_tuple_indices=[])
          for input, output in operand_output_aliases.items()
      ]))
  if len(out_types) == 1:
    return out.result
  else:
    # Extract each element of the tuple-typed result as a separate value.
    return [
        hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
        for i in range(len(out_types))
    ]
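
The attribute preparation inside `custom_call` can be explored without an MLIR context. The sketch below mirrors only the NumPy layout normalization and the alias expansion; `normalize_layout` and `expand_aliases` are hypothetical names introduced for illustration and are not part of jaxlib.

```python
from typing import Dict, List, Sequence, Tuple, Union

import numpy as np


def normalize_layout(layout: Union[int, Sequence[int]]) -> np.ndarray:
  """Hypothetical helper: the NumPy expression custom_call applies per layout."""
  # np.asarray turns a bare scalar into a 0-d array; np.atleast_1d promotes
  # it to 1-D, so every layout ends up as a rank-1 int64 array.
  return np.atleast_1d(np.asarray(layout, dtype=np.int64))


def expand_aliases(
    num_outputs: int, operand_output_aliases: Dict[int, int]
) -> List[Tuple[List[int], int]]:
  """Hypothetical helper: (output_tuple_indices, operand_index) per alias."""
  # With a single result the op is not tuple-typed, so the output index list
  # is empty; otherwise it names the aliased tuple element.
  return [
      ([] if num_outputs == 1 else [output], operand)
      for operand, output in operand_output_aliases.items()
  ]


print(normalize_layout([1, 0]))    # layout of a rank-2 operand
print(normalize_layout(()).shape)  # empty layout of a scalar operand
print(expand_aliases(2, {0: 1}))   # operand 0 aliases tuple element 1
```

Keeping these two steps as plain data makes it easy to check layouts and alias maps in a unit test before they are wrapped in `ir.DenseIntElementsAttr` and `hlo.OutputOperandAlias`.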