mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
2779 lines
115 KiB
Python
2779 lines
115 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Lowering and execution path that converts jaxprs into MLIR.
|
|
from __future__ import annotations
|
|
|
|
import collections
|
|
from collections.abc import Iterator, Sequence
|
|
import dataclasses
|
|
import functools
|
|
from functools import partial
|
|
import io
|
|
import itertools
|
|
import operator
|
|
import os
|
|
import re
|
|
import types
|
|
import typing
|
|
from typing import Any, Callable, NamedTuple, Protocol, Union
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
from jax._src import ad_util
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import dtypes
|
|
from jax._src import effects as effects_lib
|
|
from jax._src import linear_util as lu
|
|
from jax._src import path
|
|
from jax._src import pickle_util
|
|
from jax._src import sharding_impls
|
|
from jax._src import source_info_util
|
|
from jax._src import util
|
|
from jax._src import xla_bridge as xb
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.interpreters import xla
|
|
from jax._src.layout import AutoLayout, SpecifiedLayout
|
|
from jax._src.lib import xla_client as xc
|
|
from jax._src.lib import xla_extension
|
|
from jax._src.lib import xla_extension_version
|
|
from jax._src.lib.mlir import dialects
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import func as func_dialect
|
|
from jax._src.lib.mlir.dialects import hlo
|
|
from jax._src.lib.mlir import register_jax_dialects
|
|
from jax._src.sharding_impls import XLACompatibleSharding
|
|
from jax._src.state.types import AbstractRef
|
|
|
|
map, unsafe_map = util.safe_map, map
|
|
zip, unsafe_zip = util.safe_zip, zip
|
|
|
|
T = typing.TypeVar("T")
|
|
|
|
Value = Any # = ir.Value
|
|
|
|
# mypy implicitly sets this variable to true when type checking.
|
|
MYPY = False
|
|
|
|
_JAX_DUMP_IR_TO = config.DEFINE_string(
|
|
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
|
|
help="Path to which the IR that is emitted by JAX should be dumped as "
|
|
"text files. If omitted, JAX will not dump IR. "
|
|
"Supports the special value 'sponge' to pick the path from the "
|
|
"environment variable TEST_UNDECLARED_OUTPUTS_DIR.")
|
|
|
|
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
|
|
|
|
|
|
# IR Helpers
|
|
|
|
def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
|
|
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
|
|
|
|
def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
|
|
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v5 or higher
|
|
if hlo.get_api_version() < 5:
|
|
return dense_int_elements(xs)
|
|
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
|
|
|
|
# TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
|
|
def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
|
|
if hlo.get_api_version() < 6 or xc.mlir_api_version < 55:
|
|
return dense_int_elements(xs)
|
|
return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
|
|
|
|
def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
|
|
a = np.packbits(np.array(xs, np.bool_), bitorder='little')
|
|
# TODO(b/209005197): Work around for MLIR crash for non-splat single element
|
|
# buffers.
|
|
if len(xs) == 1:
|
|
a = np.array(0 if a.item() == 0 else 0xff, np.uint8)
|
|
return ir.DenseElementsAttr.get(
|
|
a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])
|
|
|
|
def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr:
|
|
# TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
|
|
if hlo.get_api_version() < 6 or xc.mlir_api_version < 55:
|
|
return dense_bool_elements(xs)
|
|
return ir.DenseBoolArrayAttr.get(xs)
|
|
|
|
def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
|
|
def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
|
|
|
|
def shape_tensor(sizes: Sequence[int | ir.RankedTensorType]
|
|
) -> ir.RankedTensorType:
|
|
int1d = aval_to_ir_type(core.ShapedArray((1,), np.int32))
|
|
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
|
|
def lower_dim(d):
|
|
if type(d) is int:
|
|
return ir_constant(np.array([d], np.int32))
|
|
else:
|
|
if d.type != i32_type:
|
|
d = hlo.convert(i32_type, d)
|
|
return hlo.reshape(int1d, d)
|
|
ds = map(lower_dim, sizes)
|
|
if not ds:
|
|
return ir_constant(np.array([], np.int32))
|
|
elif len(ds) == 1:
|
|
return ds[0]
|
|
else:
|
|
return hlo.concatenate(ds, i64_attr(0))
|
|
|
|
|
|
def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
|
|
"""Side-effects on `ctx`"""
|
|
ctx_new = ctx.replace(**ctx_override_kwargs)
|
|
out = lowering_fun(ctx_new, *args)
|
|
ctx.set_tokens_out(ctx_new.tokens_out)
|
|
return out
|
|
|
|
|
|
# IR Types
|
|
|
|
# Non-canonicalized dtype to IR type mapping.
|
|
_dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
|
|
np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
|
|
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
|
|
np.dtype(dtypes.int4): partial(ir.IntegerType.get_signless, 4),
|
|
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
|
|
np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
|
|
np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
|
|
np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
|
|
np.dtype(dtypes.uint4): partial(ir.IntegerType.get_unsigned, 4),
|
|
np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
|
|
np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
|
|
np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
|
|
np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
|
|
np.dtype(dtypes.float8_e4m3b11fnuz): ir.Float8E4M3B11FNUZType.get,
|
|
np.dtype(dtypes.float8_e4m3fn): ir.Float8E4M3FNType.get,
|
|
np.dtype(dtypes.float8_e4m3fnuz): ir.Float8E4M3FNUZType.get,
|
|
np.dtype(dtypes.float8_e5m2): ir.Float8E5M2Type.get,
|
|
np.dtype(dtypes.float8_e5m2fnuz): ir.Float8E5M2FNUZType.get,
|
|
np.dtype(dtypes.bfloat16): ir.BF16Type.get,
|
|
np.dtype(np.float16): ir.F16Type.get,
|
|
np.dtype(np.float32): ir.F32Type.get,
|
|
np.dtype(np.float64): ir.F64Type.get,
|
|
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
|
|
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
|
|
}
|
|
|
|
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
|
|
if isinstance(dtype, core.bint):
|
|
# TODO Support different-size underlying dtypes to take advantage of the
|
|
# bound for packing?
|
|
dtype = np.dtype(np.int32)
|
|
assert isinstance(dtype, (np.dtype, np.generic)), type(dtype)
|
|
dtype = np.dtype(dtype)
|
|
try:
|
|
ir_type_factory = _dtype_to_ir_type[dtype]
|
|
except KeyError as err:
|
|
raise TypeError(
|
|
f"No dtype_to_ir_type handler for dtype: {dtype}") from err
|
|
return ir_type_factory()
|
|
|
|
def _array_ir_types(aval: core.ShapedArray | core.DShapedArray
|
|
) -> Sequence[ir.Type]:
|
|
aval = core.physical_aval(aval) # type: ignore
|
|
if not core.is_constant_shape(aval.shape):
|
|
return _dynamic_array_ir_types(aval) # type: ignore
|
|
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
|
|
|
|
def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
|
|
dyn_size = ir.ShapedType.get_dynamic_size()
|
|
shape = [d if type(d) is int else dyn_size for d in aval.shape]
|
|
return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)
|
|
|
|
ir_type_handlers: dict[type[core.AbstractValue],
|
|
Callable[[Any], Sequence[ir.Type]]] = {}
|
|
|
|
def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
|
|
"""Converts a JAX aval to zero or more MLIR IR types.
|
|
|
|
In general, a JAX value may be represented by multiple IR values, so this
|
|
function returns multiple types."""
|
|
try:
|
|
return ir_type_handlers[type(aval)](aval)
|
|
except KeyError as err:
|
|
raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
|
|
|
|
ir_type_handlers[core.ShapedArray] = _array_ir_types
|
|
ir_type_handlers[core.ConcreteArray] = _array_ir_types
|
|
ir_type_handlers[core.AbstractToken] = lambda _: [hlo.TokenType.get()]
|
|
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types
|
|
|
|
def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
|
|
"""Convenience wrapper around aval_to_ir_types for single types.
|
|
|
|
For some common cases, e.g. dense arrays, we know JAX values are represented
|
|
by a single IR value."""
|
|
types = aval_to_ir_types(aval)
|
|
if len(types) != 1:
|
|
raise TypeError(f"aval_to_ir_type called on {aval} which corresponds to "
|
|
f"multiple IR types {types}")
|
|
return types[0]
|
|
|
|
|
|
# Constants
|
|
|
|
class ConstantHandler(Protocol):
|
|
def __call__(self, val: Any) -> Sequence[ir.Value]:
|
|
"""Builds an IR representation for a constant `val`.
|
|
|
|
A JAX value is represented by zero or more IR values."""
|
|
|
|
_constant_handlers : dict[type, ConstantHandler] = {}
|
|
|
|
def register_constant_handler(type_: type, handler_fun: ConstantHandler):
|
|
_constant_handlers[type_] = handler_fun
|
|
|
|
def get_constant_handler(type_: type) -> ConstantHandler:
|
|
return _constant_handlers[type_]
|
|
|
|
def ir_constants(val: Any) -> Sequence[ir.Value]:
|
|
"""Translate a Python `val` to an IR constant, canonicalizing its dtype.
|
|
|
|
Args:
|
|
val: a Python value to be translated to a constant.
|
|
|
|
Returns:
|
|
A representation of the constant as a list of IR values.
|
|
"""
|
|
for t in type(val).__mro__:
|
|
handler = _constant_handlers.get(t)
|
|
if handler:
|
|
out = handler(val)
|
|
assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
|
|
return out
|
|
if hasattr(val, '__jax_array__'):
|
|
return ir_constants(val.__jax_array__())
|
|
raise TypeError(f"No constant handler for type: {type(val)}")
|
|
|
|
def ir_constant(val: Any) -> ir.Value:
|
|
"""Convenience wrapper around ir_constants for singleton values."""
|
|
values = ir_constants(val)
|
|
if len(values) != 1:
|
|
raise TypeError(f"ir_constant called on {val} which corresponds to "
|
|
f"multiple IR values {values}")
|
|
return values[0]
|
|
|
|
|
|
def _numpy_array_constant(x: np.ndarray | np.generic) -> Sequence[ir.Value]:
|
|
element_type = dtype_to_ir_type(x.dtype)
|
|
shape = x.shape
|
|
if x.dtype == np.bool_:
|
|
x = np.packbits(x, bitorder='little') # type: ignore
|
|
x = np.ascontiguousarray(x)
|
|
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
|
|
return (hlo.constant(attr),)
|
|
|
|
|
|
def _masked_array_constant_handler(*args, **kwargs):
|
|
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
|
|
"Use arr.filled() to convert the value to a standard numpy array.")
|
|
|
|
register_constant_handler(np.ma.MaskedArray, _masked_array_constant_handler)
|
|
|
|
def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value]:
|
|
"""Constant handler for ndarray literals, handling zero-size strides.
|
|
|
|
In most cases this function calls _numpy_array_constant(val) except it has
|
|
special handling of arrays with any strides of size zero: for those, it
|
|
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
|
|
to avoid staging in large literals that might arise from np.zeros or np.ones
|
|
or the output of lax.broadcast (which uses np.broadcast_to which in turn
|
|
uses size-zero strides).
|
|
|
|
Args:
|
|
val: an ndarray.
|
|
|
|
Returns:
|
|
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
|
|
staged into the XLA Computation.
|
|
"""
|
|
if val.dtype == dtypes.float0:
|
|
return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_))
|
|
elif np.any(np.equal(0, val.strides)) and val.size > 0:
|
|
zero_stride_axes, = np.where(np.equal(0, val.strides))
|
|
other_axes, = np.where(np.not_equal(0, val.strides))
|
|
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore
|
|
for ax in range(val.ndim))] # type: ignore
|
|
out = hlo.broadcast_in_dim(
|
|
ir.RankedTensorType.get(
|
|
val.shape, dtype_to_ir_type(collapsed_val.dtype)),
|
|
_numpy_array_constant(collapsed_val)[0],
|
|
dense_int_array_v6(other_axes))
|
|
return (out,)
|
|
else:
|
|
return _numpy_array_constant(val)
|
|
|
|
register_constant_handler(np.ndarray, _ndarray_constant_handler)
|
|
|
|
for _scalar_type in [np.int8, np.int16, np.int32, np.int64,
|
|
np.uint8, np.uint16, np.uint32, np.uint64,
|
|
np.float16, np.float32, np.float64,
|
|
np.complex64, np.complex128,
|
|
np.bool_, np.longlong, dtypes.bfloat16]:
|
|
register_constant_handler(_scalar_type, _ndarray_constant_handler) # type: ignore
|
|
|
|
def _python_scalar_handler(dtype, val):
|
|
return _numpy_array_constant(np.array(val, dtype))
|
|
|
|
for ptype, dtype in dtypes.python_scalar_dtypes.items():
|
|
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
|
|
|
|
def _token_constant_handler(val):
|
|
return [hlo.create_token()]
|
|
register_constant_handler(core.Token, _token_constant_handler)
|
|
|
|
# Source locations
|
|
|
|
def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
|
|
canonical_file_name = caches.canonical_name_cache.get(file_name, None)
|
|
if canonical_file_name is not None:
|
|
return canonical_file_name
|
|
|
|
pattern = config.hlo_source_file_canonicalization_regex.value
|
|
if pattern:
|
|
file_name = re.sub(pattern, '', file_name)
|
|
caches.canonical_name_cache[file_name] = file_name
|
|
return file_name
|
|
|
|
def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
|
|
is_user = ctx.traceback_caches.is_user_file_cache.get(file_name, None)
|
|
if is_user is not None:
|
|
return is_user
|
|
out = source_info_util.is_user_filename(file_name)
|
|
ctx.traceback_caches.is_user_file_cache[file_name] = out
|
|
return out
|
|
|
|
def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
|
|
"""Converts a full traceback to a callsite() MLIR location."""
|
|
loc = ctx.traceback_caches.traceback_cache.get(tb, None)
|
|
if loc is not None:
|
|
return loc
|
|
|
|
frame_locs = []
|
|
frames_limit = config.traceback_in_locations_limit.value
|
|
frames_limit = frames_limit if frames_limit >= 0 else 1000
|
|
|
|
codes, lastis = tb.raw_frames()
|
|
for i, code in enumerate(codes):
|
|
if not _is_user_file(ctx, code.co_filename):
|
|
continue
|
|
|
|
lasti = lastis[i]
|
|
code_lasti = code, lasti
|
|
loc = ctx.traceback_caches.location_cache.get(code_lasti, None)
|
|
if loc is None:
|
|
frame = source_info_util.raw_frame_to_frame(code, lasti)
|
|
file_loc = ir.Location.file(
|
|
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
|
|
frame.start_line,
|
|
frame.start_column,
|
|
)
|
|
loc = ir.Location.name(frame.function_name, childLoc=file_loc)
|
|
ctx.traceback_caches.location_cache[code_lasti] = loc
|
|
frame_locs.append(loc)
|
|
if len(frame_locs) >= frames_limit:
|
|
break
|
|
|
|
n = len(frame_locs)
|
|
if n == 0:
|
|
loc = ir.Location.unknown()
|
|
elif n == 1:
|
|
loc = frame_locs[0]
|
|
else:
|
|
loc = ir.Location.callsite(frame_locs[0], frame_locs[1:])
|
|
ctx.traceback_caches.traceback_cache[tb] = loc
|
|
return loc
|
|
|
|
def _source_info_to_location(
|
|
ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any],
|
|
source_info: source_info_util.SourceInfo) -> ir.Location:
|
|
eqn_str = (f'{source_info.name_stack}/'
|
|
f'{core.str_eqn_compact(primitive.name, params)}')
|
|
if config.include_full_tracebacks_in_locations.value:
|
|
if source_info.traceback is None:
|
|
loc = ir.Location.unknown()
|
|
else:
|
|
loc = _traceback_to_location(ctx, source_info.traceback)
|
|
else:
|
|
frame = source_info_util.user_frame(source_info)
|
|
if frame is None:
|
|
loc = ir.Location.unknown()
|
|
else:
|
|
loc = ir.Location.file(get_canonical_source_file(frame.file_name,
|
|
ctx.traceback_caches),
|
|
frame.start_line, frame.start_column)
|
|
loc = ir.Location.name(eqn_str, childLoc=loc)
|
|
# TODO(phawkins): also include primitive.name as the operator type.
|
|
return loc
|
|
|
|
upstream_dialects = ir.DialectRegistry()
|
|
if register_jax_dialects:
|
|
register_jax_dialects.register_dialects(upstream_dialects)
|
|
|
|
# Dumping MLIR modules
|
|
_ir_dump_counter = itertools.count()
|
|
|
|
def dump_module_to_file(module: ir.Module, stage_name: str) -> str | None:
|
|
"""Dumps the `module` IR to a file.
|
|
|
|
Dumps the module if JAX_DUMP_IR_TO is defined.
|
|
|
|
Args:
|
|
module: The module to dump
|
|
stage_name: A name to distinguish different stages of a module, will be
|
|
appended to the `module.name`.
|
|
|
|
Returns:
|
|
The name of the file containing the dump if JAX_DUMP_IR_TO is defined and
|
|
the module was dumped, `None` otherwise.
|
|
"""
|
|
out_dir_name = _JAX_DUMP_IR_TO.value
|
|
if not out_dir_name:
|
|
return None
|
|
if out_dir_name == "sponge":
|
|
out_dir_name = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "")
|
|
if not out_dir_name:
|
|
raise ValueError("JAX_DUMP_IR_TO='sponge' but "
|
|
"TEST_UNDECLARED_OUTPUTS_DIR is not defined")
|
|
|
|
id = next(_ir_dump_counter)
|
|
sym_name = module.operation.attributes['sym_name']
|
|
module_name = ir.StringAttr(sym_name).value
|
|
|
|
name = f"jax_ir{id}_{_make_string_safe_for_filename(module_name)}_{stage_name}.mlir"
|
|
|
|
out_dir = path.Path(out_dir_name)
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
full_path = out_dir / name
|
|
full_path.write_text(module_to_string(module))
|
|
return name
|
|
|
|
def dump_module_message(module: ir.Module, stage_name: str) -> str:
|
|
dumped_to = dump_module_to_file(module, stage_name)
|
|
if dumped_to:
|
|
return f"The module was dumped to {dumped_to}."
|
|
else:
|
|
return "Define JAX_DUMP_IR_TO to dump the module."
|
|
|
|
def _make_string_safe_for_filename(s: str) -> str:
|
|
return re.sub(r'[^\w.)( -]', '', s)
|
|
|
|
def module_to_string(module: ir.Module) -> str:
|
|
output = io.StringIO()
|
|
module.operation.print(file=output, enable_debug_info=True)
|
|
return output.getvalue()
|
|
|
|
def module_to_bytecode(module: ir.Module) -> bytes:
|
|
output = io.BytesIO()
|
|
module.operation.write_bytecode(file=output)
|
|
return output.getvalue()
|
|
|
|
# Translation rules
|
|
def make_ir_context() -> ir.Context:
|
|
"""Creates an MLIR context suitable for JAX IR."""
|
|
context = ir.Context()
|
|
context.append_dialect_registry(upstream_dialects)
|
|
context.load_all_available_dialects()
|
|
|
|
# If threading is enabled, each MLIR context will keep alive a thread pool.
|
|
# Since we cache MLIR modules (and hence contexts), this means we might keep
|
|
# several threads alive for each cache entry. This is a terrible idea. However
|
|
# we don't do any heavy computation on MLIR modules from Python anyway, so we
|
|
# just disable threading.
|
|
context.enable_multithreading(False)
|
|
|
|
dialects.mhlo.register_mhlo_dialect(context)
|
|
dialects.chlo.register_dialect(context)
|
|
dialects.hlo.register_dialect(context)
|
|
return context
|
|
|
|
|
|
AxisContext = Union[
|
|
sharding_impls.SPMDAxisContext,
|
|
sharding_impls.ReplicaAxisContext,
|
|
sharding_impls.ShardingContext,
|
|
]
|
|
|
|
class ShapePolyLoweringState:
|
|
# The names of the dimension variables, sorted by name. This is the order in
|
|
# which they are passed to the IR functions that need them. This is only
|
|
# used for native serialization with polymorphic shapes when
|
|
# --jax_dynamic_shapes is off.
|
|
# TODO: for multi-platform lowering we prepend to the regular dimension
|
|
# variables a fake dimension variable "platform_index_". This is a
|
|
# temporary abuse, taking advantage that for platform index we need the
|
|
# same lowering strategy as for dimension variables: add it as argument to
|
|
# inner functions, and pass the values along at the call sites.
|
|
dim_vars: tuple[str, ...]
|
|
# Whether the module uses dimension variables, either in its inputs or
|
|
# from an inner call to Exported modules that uses dimension variables.
|
|
# This includes the case when the called Exported module uses a platform
|
|
# index argument.
|
|
uses_dim_vars: bool
|
|
|
|
# If the first dimension variable is a platform index argument
|
|
has_platform_index_argument: bool
|
|
|
|
def __init__(self,
|
|
dim_vars: tuple[str, ...],
|
|
lowering_platforms: tuple[str, ...] | None):
|
|
if lowering_platforms is not None and len(lowering_platforms) > 1:
|
|
dim_vars = ("_platform_index",) + tuple(dim_vars)
|
|
self.has_platform_index_argument = True
|
|
else:
|
|
self.has_platform_index_argument = False
|
|
self.uses_dim_vars = (len(dim_vars) > 0)
|
|
self.dim_vars = dim_vars
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class LoweringParameters:
|
|
# A mapping between primitives and user-defined LoweringRules.
|
|
# When lowering a primitive, give priorioty to the rule in this map over
|
|
# existing Jax rules.
|
|
override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None
|
|
|
|
# The current lowering platforms, a non-empty tuple containing some of
|
|
# 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
|
|
# doing multi-platform lowering, otherwise it can specify cross-platform
|
|
# lowering. The value None specifies the default lowering platform.
|
|
# This is used only in export and jax2tf.
|
|
platforms: tuple[str, ...] | None = None
|
|
|
|
# Signals that the entire computation being lowered operates on global
|
|
# constants. This will result in adding jax.global_constant attributes
|
|
# to the arguments of all functions that are created, e.g., floor_divide.
|
|
# This is used only in export and jax2tf in presence of shape polymorphism
|
|
# or multi-platform lowering.
|
|
global_constant_computation: bool = False
|
|
|
|
# TODO(b/302258959): in JAX native execution we cannot lower the tokens
|
|
# to stablehlo.token for the top-level function, due to runtime limitations.
|
|
# Instead, we use dummy bool[0] arrays. This is controlled by setting
|
|
# replace_tokens_with_dummy to True (default). However, when exporting StableHLO
|
|
# we can use real tokens, because the resulting StableHLO will not be
|
|
# executed directly, but will be embedded as an inner function in a larger
|
|
# JAX or TensorFlow program. In these cases, replace_tokens_with_dummy must
|
|
# be set to False (for serialization versions >= 9).
|
|
# Once the PJRT is extended to use tokens, we can use tokens even in the
|
|
# native execution (and we can remove this parameter).
|
|
replace_tokens_with_dummy: bool = True
|
|
|
|
@dataclasses.dataclass
|
|
class TracebackCaches:
|
|
traceback_cache: dict[xc.Traceback, ir.Location]
|
|
location_cache: dict[tuple[types.CodeType, int], ir.Location]
|
|
canonical_name_cache: dict[str, str]
|
|
is_user_file_cache: dict[str, bool]
|
|
|
|
def __init__(self):
|
|
self.traceback_cache = {}
|
|
self.location_cache = {}
|
|
self.canonical_name_cache = {}
|
|
self.is_user_file_cache = {}
|
|
|
|
@dataclasses.dataclass
|
|
class ModuleContext:
|
|
"""Module-wide context information for MLIR lowering."""
|
|
context: ir.Context
|
|
module: ir.Module
|
|
ip: ir.InsertionPoint
|
|
symbol_table: ir.SymbolTable
|
|
backend_or_name: str | xb.XlaBackend | None
|
|
platforms: Sequence[str]
|
|
axis_context: AxisContext
|
|
keepalives: list[Any]
|
|
channel_iterator: Iterator[int]
|
|
host_callbacks: list[Any]
|
|
# Keep state for the lowering of shape polymorphism
|
|
shape_poly_state: ShapePolyLoweringState
|
|
|
|
# Cached primitive lowerings.
|
|
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
|
|
|
|
# Cached traceback infromation.
|
|
traceback_caches: TracebackCaches
|
|
|
|
lowering_parameters: LoweringParameters
|
|
|
|
@property
|
|
def axis_env(self) -> sharding_impls.AxisEnv:
|
|
return self.axis_context.axis_env
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
backend_or_name: str | xb.XlaBackend | None,
|
|
platforms: Sequence[str],
|
|
axis_context: AxisContext,
|
|
keepalives: list[Any],
|
|
channel_iterator: Iterator[int],
|
|
host_callbacks: list[Any],
|
|
lowering_parameters: LoweringParameters,
|
|
context: ir.Context | None = None,
|
|
module: ir.Module | None = None,
|
|
ip: ir.InsertionPoint | None = None,
|
|
symbol_table: ir.SymbolTable | None = None,
|
|
cached_primitive_lowerings: None | (dict[Any,
|
|
func_dialect.FuncOp]) = None,
|
|
traceback_caches: None | TracebackCaches = None,
|
|
shape_poly_state = None):
|
|
|
|
self.context = context or make_ir_context()
|
|
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
|
|
self.ip = ip or ir.InsertionPoint(self.module.body)
|
|
self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
|
|
self.backend_or_name = backend_or_name
|
|
self.platforms = platforms
|
|
self.axis_context = axis_context
|
|
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
|
|
else cached_primitive_lowerings)
|
|
self.traceback_caches = (TracebackCaches() if traceback_caches is None
|
|
else traceback_caches)
|
|
self.channel_iterator = channel_iterator
|
|
self.keepalives = keepalives
|
|
self.host_callbacks = host_callbacks
|
|
self.shape_poly_state = (
|
|
shape_poly_state or ShapePolyLoweringState((), tuple(platforms)))
|
|
self.lowering_parameters = lowering_parameters
|
|
|
|
@property
|
|
def backend(self) -> xb.XlaBackend:
|
|
# TODO(necula): clean the use of backend and backend_or_name vs. platforms
|
|
if len(self.platforms) > 1:
|
|
raise NotImplementedError(
|
|
"accessing .backend in multi-lowering setting. This can occur when "
|
|
"lowering a primitive that has not been adapted to multi-platform "
|
|
"lowering")
|
|
if self.backend_or_name is None or isinstance(self.backend_or_name, str):
|
|
return xb.get_backend(self.backend_or_name)
|
|
return self.backend_or_name
|
|
|
|
def new_channel(self) -> int:
|
|
return next(self.channel_iterator)
|
|
|
|
# Adds an IFRT host callback object to the context. A reference to these
|
|
# callbacks will be provided to IFRT during compilation so it can do things
|
|
# like serialize them and keep them alive.
|
|
def add_host_callback(self, host_callback: Any) -> None:
|
|
self.host_callbacks.append(host_callback)
|
|
|
|
# Keeps a value alive as long as the Python executable is alive.
|
|
# TODO(phawkins): this feature is problematic, because you almost certainly
|
|
# want to keep alive values as long as the underlying runtime executable is
|
|
# still alive/executing. The Python executable object may have a shorter
|
|
# lifetime, so it's highly likely any caller of this method is buggy.
|
|
def add_keepalive(self, keepalive: Any) -> None:
|
|
self.keepalives.append(keepalive)
|
|
|
|
def replace(self, **kw): return dataclasses.replace(self, **kw)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class LoweringRuleContext:
|
|
"""Per-rule context information for MLIR lowering."""
|
|
module_context: ModuleContext
|
|
name_stack: source_info_util.NameStack
|
|
primitive: core.Primitive | None
|
|
avals_in: Sequence[core.AbstractValue]
|
|
avals_out: Any # Usually Sequence[core.AbstractValue], but sometimes None.
|
|
tokens_in: TokenSet
|
|
tokens_out: TokenSet | None # Mutable store for output containers
|
|
axis_size_env: dict[core.Var, ir.Value] | None = None # Dynamic axis sizes
|
|
dim_var_values: Sequence[ir.Value] = () # The values for the dimension variables
|
|
# in same order as module_context.shape_poly_state.dim_vars
|
|
|
|
def set_tokens_out(self, tokens_out: TokenSet):
|
|
assert self.tokens_out is None, 'Should only set `tokens_out` once.'
|
|
self.tokens_out = tokens_out
|
|
|
|
def replace(self, **kw): return dataclasses.replace(self, **kw) # pytype: disable=wrong-arg-types # dataclasses-replace-types
|
|
|
|
|
|
if not MYPY:
|
|
class LoweringRule(Protocol):
|
|
def __call__(self, ctx: LoweringRuleContext,
|
|
*args: ir.Value | Sequence[ir.Value],
|
|
**kw) -> Sequence[ir.Value | Sequence[ir.Value]]:
|
|
"""Converts a JAX primitive invocation into MLIR."""
|
|
else:
|
|
LoweringRule = Any
|
|
|
|
_lowerings: dict[core.Primitive, LoweringRule] = {}
|
|
_platform_specific_lowerings: dict[str, dict[core.Primitive, LoweringRule]]
|
|
_platform_specific_lowerings = collections.defaultdict(dict)
|
|
|
|
def register_lowering(prim: core.Primitive, rule: LoweringRule,
|
|
platform: str | None = None):
|
|
if platform is None:
|
|
_lowerings[prim] = rule
|
|
else:
|
|
# For backward compatibility reasons, we allow rules to be registered
|
|
# under "gpu" even though the platforms are now called "cuda" and "rocm".
|
|
# TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
|
|
# this expansion.
|
|
for p in xb.expand_platform_alias(platform):
|
|
_platform_specific_lowerings[p][prim] = rule
|
|
return rule
|
|
|
|
|
|
def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
|
|
def wrap_singleton_ir_values(x: ir.Value | Sequence[ir.Value]
|
|
) -> Sequence[ir.Value]:
|
|
"""Adds a consistent tuples to a mixture of tupled and untuple values."""
|
|
return (x,) if isinstance(x, ir.Value) else tuple(x)
|
|
|
|
def flatten_lowering_ir_args(
|
|
xs: Sequence[ir.Value | Sequence[ir.Value]]
|
|
) -> Sequence[Sequence[ir.Value]]:
|
|
return util.flatten(map(wrap_singleton_ir_values, xs))
|
|
|
|
_module_name_regex = re.compile(r"[^\w.-]")
|
|
|
|
def sharded_aval(aval: core.AbstractValue,
|
|
sharding: XLACompatibleSharding | None) -> core.AbstractValue:
|
|
"""Returns the new aval sharded based on sharding proto."""
|
|
if sharding is None:
|
|
return aval
|
|
if isinstance(aval, core.AbstractToken):
|
|
return aval
|
|
if not isinstance(aval, (core.ShapedArray, core.DShapedArray)):
|
|
raise NotImplementedError
|
|
return aval.update(sharding.shard_shape(aval.shape)) # type: ignore
|
|
|
|
|
|
def eval_dynamic_shape(ctx: LoweringRuleContext,
|
|
shape: core.Shape) -> tuple[int | Value, ...]:
|
|
if config.dynamic_shapes.value:
|
|
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
|
|
else:
|
|
ctx = ctx.replace(
|
|
primitive="eval_dynamic_shape",
|
|
avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars),
|
|
tokens_out=None)
|
|
|
|
res = lower_fun(
|
|
partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
|
|
multiple_results=True)(ctx, *ctx.dim_var_values)
|
|
return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir
|
|
for d, d_ir in zip(shape, util.flatten(res))) # type: ignore
|
|
|
|
# TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
|
|
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
|
|
shape: core.Shape) -> tuple[Value, ...]:
|
|
"""Evaluates the dynamic shapes as int32 values."""
|
|
def convert_dim(d: int | Value):
|
|
if type(d) is int:
|
|
return ir_constant(np.array(d, dtype=np.int32))
|
|
else:
|
|
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
|
|
if d.type != i32_type: # type: ignore
|
|
return hlo.convert(i32_type, d)
|
|
else:
|
|
return d
|
|
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
|
|
|
|
|
|
def eval_dynamic_shape_as_ivals(
|
|
ctx: LoweringRuleContext, shape: core.Shape
|
|
) -> tuple[int | Value, ...]:
|
|
"""Evaluates the dynamic shapes as int or ir.int32 values."""
|
|
def convert_dim(d: int | Value) -> int | ir.Value:
|
|
if type(d) is int:
|
|
return d
|
|
else:
|
|
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
|
|
if d.type != i32_type: # type: ignore
|
|
return hlo.convert(i32_type, d)
|
|
else:
|
|
return d
|
|
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
|
|
|
|
def eval_dynamic_shape_as_tensor(ctx: LoweringRuleContext,
|
|
shape: core.Shape) -> Value:
|
|
"""Evaluates the dynamic shapes as one 1d int32 tensor."""
|
|
return shape_tensor(eval_dynamic_shape(ctx, shape))
|
|
|
|
class LoweringResult(NamedTuple):
|
|
module: ir.Module
|
|
keepalive: Any | None
|
|
host_callbacks: list[Any]
|
|
shape_poly_state: ShapePolyLoweringState
|
|
|
|
|
|
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
|
|
|
|
|
|
def _to_physical_op_sharding(
|
|
aval: core.AbstractValue, sharding: XLACompatibleSharding | None,
|
|
) -> xc.OpSharding | None:
|
|
if sharding is None:
|
|
return None
|
|
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
|
|
if isinstance(aval, AbstractRef):
|
|
return _to_physical_op_sharding(aval.inner_aval, sharding)
|
|
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
|
|
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
|
sharding = aval.dtype._rules.physical_sharding(aval, sharding)
|
|
aval = core.physical_aval(aval)
|
|
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore
|
|
|
|
|
|
def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None:
|
|
if layout is None:
|
|
return "default"
|
|
if isinstance(layout, AutoLayout):
|
|
return "auto"
|
|
return layout._to_xla_layout()
|
|
|
|
|
|
def _get_mem_kind(s: XLACompatibleSharding | None) -> str | None:
|
|
if s is None:
|
|
return None
|
|
assert isinstance(s, sharding_impls.XLACompatibleSharding)
|
|
return s.memory_kind
|
|
|
|
|
|
def lower_jaxpr_to_module(
|
|
module_name: str,
|
|
jaxpr: core.ClosedJaxpr,
|
|
*,
|
|
ordered_effects: list[core.Effect],
|
|
backend_or_name: str | xb.XlaBackend | None,
|
|
platforms: Sequence[str],
|
|
axis_context: AxisContext,
|
|
name_stack: source_info_util.NameStack,
|
|
donated_args: Sequence[bool],
|
|
replicated_args: Sequence[bool] | None = None,
|
|
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
|
|
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
|
|
in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
|
|
out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
|
|
arg_names: Sequence[str | None] | None = None,
|
|
result_names: Sequence[str | None] | None = None,
|
|
num_replicas: int = 1,
|
|
num_partitions: int = 1,
|
|
all_default_mem_kind: bool = True,
|
|
input_output_aliases: None | tuple[int | None, ...] = None,
|
|
lowering_parameters: LoweringParameters,
|
|
) -> LoweringResult:
|
|
"""Lowers a top-level jaxpr to an MLIR module.
|
|
|
|
Handles the quirks of the argument/return value passing conventions of the
|
|
runtime.
|
|
"""
|
|
platforms = tuple(map(xb.canonicalize_platform, platforms))
|
|
|
|
in_avals = (jaxpr.in_avals if arg_shardings is None else
|
|
map(sharded_aval, jaxpr.in_avals, arg_shardings))
|
|
out_avals = (jaxpr.out_avals if result_shardings is None else
|
|
map(sharded_aval, jaxpr.out_avals, result_shardings))
|
|
if all_default_mem_kind:
|
|
arg_memory_kinds = None
|
|
result_memory_kinds = None
|
|
else:
|
|
arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
|
|
if arg_shardings is not None else None)
|
|
result_memory_kinds = (map(_get_mem_kind, result_shardings)
|
|
if result_shardings is not None else None)
|
|
|
|
xla_donated_args = None
|
|
platforms_with_donation = [p for p in platforms
|
|
if p in _platforms_with_donation]
|
|
if platforms_with_donation:
|
|
if len(platforms_with_donation) != len(platforms):
|
|
raise NotImplementedError(
|
|
"In multi-platform lowering either all or no lowering platforms "
|
|
f"should support donation. Lowering for {platforms} of which "
|
|
f"only {platforms_with_donation} support donation")
|
|
if num_partitions > 1 and xla_extension_version >= 220 and (
|
|
result_shardings is None or all(s is None for s in result_shardings)):
|
|
xla_donated_args = donated_args
|
|
if xla_donated_args is None:
|
|
input_output_aliases, donated_args = _set_up_aliases(
|
|
input_output_aliases, in_avals, out_avals, donated_args,
|
|
arg_memory_kinds, result_memory_kinds)
|
|
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
|
|
if unlowerable_effects:
|
|
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
|
|
if xla_donated_args is None and any(donated_args):
|
|
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
|
|
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
|
|
if not platforms_with_donation:
|
|
msg = f"Donation is not implemented for {platforms}.\n{msg}"
|
|
if unused_donations:
|
|
warnings.warn("Some donated buffers were not usable:"
|
|
f" {', '.join(unused_donations)}.\n{msg}")
|
|
|
|
if xla_donated_args is not None:
|
|
assert input_output_aliases is None
|
|
if input_output_aliases is not None:
|
|
assert xla_donated_args is None
|
|
|
|
# Delete donated_args by default here, since it's not needed beyond this point
|
|
del donated_args
|
|
|
|
# HLO channels need to start at 1
|
|
channel_iter = itertools.count(1)
|
|
# Create a keepalives list that will be mutated during the lowering.
|
|
keepalives: list[Any] = []
|
|
host_callbacks: list[Any] = []
|
|
|
|
dim_vars: Sequence[str]
|
|
if not config.dynamic_shapes.value:
|
|
# Find the dimension variables
|
|
all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
|
|
for d in aval.shape if not core.is_constant_dim(d)]
|
|
dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new._get_vars()),
|
|
all_dim_poly, set())))
|
|
else:
|
|
dim_vars = ()
|
|
|
|
arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
|
|
else in_layouts)
|
|
result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
|
|
else out_layouts)
|
|
|
|
ctx = ModuleContext(backend_or_name=backend_or_name,
|
|
platforms=platforms, axis_context=axis_context,
|
|
keepalives=keepalives,
|
|
channel_iterator=channel_iter,
|
|
host_callbacks=host_callbacks,
|
|
lowering_parameters=lowering_parameters,
|
|
shape_poly_state=ShapePolyLoweringState(
|
|
dim_vars, lowering_parameters.platforms))
|
|
with ctx.context, ir.Location.unknown(ctx.context):
|
|
# Remove module name characters that XLA would alter. This ensures that
|
|
# XLA computation preserves the module name.
|
|
attrs = ctx.module.operation.attributes
|
|
module_name = _module_name_regex.sub("_", module_name)
|
|
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
|
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
|
|
attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
|
|
replace_tokens_with_dummy = lowering_parameters.replace_tokens_with_dummy
|
|
lower_jaxpr_to_fun(
|
|
ctx, "main", jaxpr, ordered_effects,
|
|
name_stack=name_stack,
|
|
public=True,
|
|
create_tokens=replace_tokens_with_dummy,
|
|
replace_tokens_with_dummy=replace_tokens_with_dummy,
|
|
num_output_tokens=0,
|
|
replicated_args=replicated_args,
|
|
arg_shardings=arg_shardings,
|
|
result_shardings=result_shardings,
|
|
input_output_aliases=input_output_aliases,
|
|
xla_donated_args=xla_donated_args,
|
|
arg_names=arg_names,
|
|
result_names=result_names,
|
|
arg_memory_kinds=arg_memory_kinds,
|
|
result_memory_kinds=result_memory_kinds,
|
|
arg_layouts=arg_layouts,
|
|
result_layouts=result_layouts)
|
|
|
|
try:
|
|
if not ctx.module.operation.verify():
|
|
raise ValueError(
|
|
"Cannot lower jaxpr with verifier errors." +
|
|
dump_module_message(ctx.module, "verification"))
|
|
except ir.MLIRError as e:
|
|
msg_lines = ["Cannot lower jaxpr with verifier errors:"]
|
|
def emit_diagnostic_info(d):
|
|
msg_lines.append(f"\t{d.message}")
|
|
msg_lines.append(f"\t\tat {d.location}")
|
|
for n in d.notes:
|
|
emit_diagnostic_info(n)
|
|
for d in e.error_diagnostics:
|
|
emit_diagnostic_info(d)
|
|
raise ValueError("\n".join(msg_lines) +
|
|
dump_module_message(ctx.module, "verification")) from e
|
|
|
|
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
|
|
ctx.shape_poly_state)
|
|
|
|
def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args,
|
|
arg_memory_kinds, result_memory_kinds):
|
|
if input_output_aliases is None:
|
|
input_output_aliases = [None] * len(avals_in)
|
|
else:
|
|
input_output_aliases = list(input_output_aliases)
|
|
# To match-up in-avals to out-avals we only care about the number of
|
|
# bytes, so we strip off unrelated aval metadata (eg. the named shape)
|
|
strip_metadata = lambda a: a.strip_named_shape().strip_weak_type()
|
|
avals_in = map(strip_metadata, avals_in)
|
|
avals_out = map(strip_metadata, avals_out)
|
|
|
|
# Both arg and result memory kinds need to be specified to donate based on
|
|
# the memory kind. For jit's where out_shardings is not specified, we don't
|
|
# know the memory kind so don't condition the logic based on the memory kind.
|
|
# TODO(yashkatariya): Note that this logic should be in C++ where we make
|
|
# donation decisions are made after SPMD propagation passes and memory
|
|
# placement passes so that we have all the information.
|
|
if (arg_memory_kinds is None or result_memory_kinds is None or
|
|
any(a is None for a in arg_memory_kinds) or
|
|
any(r is None for r in result_memory_kinds)):
|
|
arg_memory_kinds = [None] * len(avals_in)
|
|
result_memory_kinds = [None] * len(avals_out)
|
|
|
|
donations = collections.defaultdict(collections.deque)
|
|
for i, (aval, am, donated, aliased) in enumerate(
|
|
zip(avals_in, arg_memory_kinds, donated_args, input_output_aliases)):
|
|
if donated and aliased is None:
|
|
donations[(aval, am)].append(i)
|
|
|
|
out_donated_args = list(donated_args)
|
|
for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)):
|
|
# Only donate if memory kinds match. Relax this when the compiler can
|
|
# donate across memories.
|
|
key = (aval, rm)
|
|
if donations.get(key, ()):
|
|
input_id = donations[key].popleft()
|
|
input_output_aliases[input_id] = i
|
|
out_donated_args[input_id] = False
|
|
|
|
return input_output_aliases, out_donated_args
|
|
|
|
Token = Sequence[ir.Value]
|
|
|
|
def token_type() -> Sequence[ir.Type]:
|
|
return [hlo.TokenType.get()]
|
|
|
|
def create_token() -> Token:
|
|
return wrap_singleton_ir_values(hlo.create_token())
|
|
|
|
class TokenSet:
|
|
"""An immutable container of tokens to be used to lower effectful jaxprs. When lowering
|
|
effectful jaxprs, we need to thread HLO tokens to sequence them. Each effect
|
|
will need its own token that will be threaded in and out of the effectful
|
|
primitives. A `TokenSet` encapsulates a set of HLO tokens that will be
|
|
used by the lowering rules.
|
|
"""
|
|
_tokens: collections.OrderedDict[core.Effect, Token]
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
self._tokens = collections.OrderedDict(*args, **kwargs)
|
|
|
|
def __len__(self):
|
|
return len(self._tokens)
|
|
|
|
def get(self, effect: core.Effect) -> Token:
|
|
return self._tokens[effect]
|
|
|
|
@classmethod
|
|
def create(cls, effects: Sequence[core.Effect]) -> TokenSet:
|
|
"""Creates a `TokenSet` corresponding to a list of `core.Effect`s."""
|
|
tokens = [create_token() for _ in effects]
|
|
return TokenSet(zip(effects, tokens))
|
|
|
|
def items(self) -> Sequence[tuple[core.Effect, Token]]:
|
|
return tuple(self._tokens.items())
|
|
|
|
def effects(self) -> set[core.Effect]:
|
|
return set(self._tokens.keys())
|
|
|
|
def subset(self, effects: Sequence[core.Effect]) -> TokenSet:
|
|
"""Return a subset of the `TokenSet` restricted to a set of `core.Effect`s."""
|
|
return TokenSet((eff, self._tokens[eff]) for eff in effects)
|
|
|
|
def update_tokens(self, tokens: TokenSet) -> TokenSet:
|
|
"""Returns a new `TokenSet` with tokens replaced with ones from the input `TokenSet`."""
|
|
new_tokens = []
|
|
for eff in self.effects():
|
|
if eff in tokens._tokens:
|
|
new_tokens.append((eff, tokens._tokens[eff]))
|
|
else:
|
|
new_tokens.append((eff, self._tokens[eff]))
|
|
return TokenSet(new_tokens)
|
|
|
|
def dummy_token_type() -> Sequence[ir.Type]:
|
|
# TODO(b/302258959): For now HLO does not allow hlo.TokenType among
|
|
# arguments and results, so we use bool[0] to pass tokens to the
|
|
# top-level function only.
|
|
return aval_to_ir_types(core.ShapedArray((0,), np.bool_))
|
|
|
|
def dummy_token() -> Sequence[ir.Value]:
|
|
return ir_constants(np.zeros(0, np.bool_))
|
|
|
|
def lower_jaxpr_to_fun(
|
|
ctx: ModuleContext,
|
|
name: str,
|
|
jaxpr: core.ClosedJaxpr,
|
|
effects: Sequence[core.Effect],
|
|
name_stack: source_info_util.NameStack,
|
|
*,
|
|
create_tokens: bool = False,
|
|
public: bool = False,
|
|
replace_tokens_with_dummy: bool = False,
|
|
replicated_args: Sequence[bool] | None = None,
|
|
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
|
|
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
|
|
use_sharding_annotations: bool = True,
|
|
input_output_aliases: Sequence[int | None] | None = None,
|
|
xla_donated_args: Sequence[bool] | None = None,
|
|
num_output_tokens: int = 0,
|
|
api_name: str = "jit",
|
|
arg_names: Sequence[str | None] | None = None,
|
|
result_names: Sequence[str | None] | None = None,
|
|
arg_memory_kinds: Sequence[str | None] | None = None,
|
|
result_memory_kinds: Sequence[str | None] | None = None,
|
|
arg_layouts: Sequence[str | None] | None = None,
|
|
result_layouts: Sequence[str | None] | None = None,
|
|
) -> func_dialect.FuncOp:
|
|
"""Lowers jaxpr and its callees to an IR function.
|
|
|
|
Assumes that an MLIR context, location, and insertion point are set.
|
|
|
|
Args:
|
|
ctx: the lowering context.
|
|
name: the function name. The name will be uniquified by the symbol table,
|
|
so it is ok to use the same name multiple times.
|
|
jaxpr: the jaxpr to lower.
|
|
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
|
|
that will be created in or used by the lowered function.
|
|
create_tokens: if true, the HLO will create tokens and ignore dummy input
|
|
tokens. See b/302258959.
|
|
public: if true, the function's visibility is set to "public".
|
|
replace_tokens_with_dummy: if true, token arguments/return values are
|
|
replaced with bool arrays of size [0]. See b/302258959.
|
|
replicated_args: if present, annotates arguments as replicated.
|
|
arg_shardings: sharding annotations for each argument (optional).
|
|
result_shardings: sharding annotations for each result (optional).
|
|
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
|
|
parameters and return values to express sharding. If False, use
|
|
hlo.custom_call operators with sharding annotations.
|
|
TODO(b/228598865): remove this option when "mhlo.sharding" annotations are
|
|
propagated on non-entry functions during MLIR->HLO conversion.
|
|
input_output_aliases: optional sequence that maps argument numbers to the
|
|
corresponding output that should alias them.
|
|
xla_donated_args: optional sequence of args to set donation annotations.
|
|
api_name: The name of the higher level primitive which should show up in the
|
|
name stack.
|
|
Returns:
|
|
MLIR func op
|
|
"""
|
|
def aval_to_types(aval):
|
|
if replace_tokens_with_dummy and aval is core.abstract_token:
|
|
aval = core.ShapedArray((), np.dtype(np.bool_))
|
|
return aval_to_ir_types(aval)
|
|
|
|
# The first dimension variable may be the platform index
|
|
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
|
|
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
|
|
dim_var_types = map(aval_to_types, dim_var_avals)
|
|
|
|
# Function inputs: *dim_var_values, *tokens, *actual_inputs
|
|
input_types = map(aval_to_types, jaxpr.in_avals)
|
|
output_types = map(aval_to_types, jaxpr.out_avals)
|
|
num_tokens = len(effects)
|
|
|
|
if create_tokens:
|
|
# TODO(b/302258959): Use actual tokens
|
|
token_types = [dummy_token_type() for _ in effects]
|
|
output_token_types = [dummy_token_type() for _ in range(num_output_tokens)]
|
|
else:
|
|
# If we aren't creating tokens they will be the initial inputs to the
|
|
# MLIR function.
|
|
output_token_types = []
|
|
token_types = [token_type() for _ in effects]
|
|
token_avals = [core.abstract_token] * num_tokens
|
|
# Order of arguments: dim vars, tokens, array inputs
|
|
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
|
|
input_types = [*dim_var_types, *token_types, *input_types]
|
|
output_avals = [core.abstract_token] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
|
|
output_types = [*output_token_types, *token_types, *output_types]
|
|
|
|
if input_output_aliases is not None:
|
|
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
|
|
input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
|
|
# Update the existing aliases to account for the new output values
|
|
input_output_aliases = [None if a is None
|
|
else a + num_output_tokens + num_tokens
|
|
for a in input_output_aliases] # type: ignore
|
|
|
|
if arg_shardings is not None:
|
|
token_shardings = [None] * (num_dim_vars + num_tokens)
|
|
arg_shardings = [*token_shardings, *arg_shardings]
|
|
if result_shardings is not None:
|
|
token_shardings = [None] * (num_tokens + num_output_tokens)
|
|
result_shardings = [*token_shardings, *result_shardings]
|
|
if replicated_args is not None:
|
|
token_replicated_args = [False] * (num_dim_vars + num_tokens)
|
|
replicated_args = [*token_replicated_args, *replicated_args]
|
|
if arg_memory_kinds is not None:
|
|
token_memory_kinds = [None] * (num_dim_vars + num_tokens)
|
|
arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds]
|
|
if result_memory_kinds is not None:
|
|
token_memory_kinds = [None] * (num_tokens + num_output_tokens)
|
|
result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
|
|
if arg_layouts is not None:
|
|
token_layouts = [None] * (num_dim_vars + num_tokens)
|
|
arg_layouts = [*token_layouts, *arg_layouts]
|
|
if result_layouts is not None:
|
|
token_layouts = [None] * (num_tokens + num_output_tokens)
|
|
result_layouts = [*token_layouts, *result_layouts]
|
|
if xla_donated_args is not None:
|
|
xla_donated_args = [*([False] * (num_dim_vars + num_tokens)), *xla_donated_args]
|
|
|
|
flat_input_types = util.flatten(input_types)
|
|
flat_output_types = util.flatten(output_types)
|
|
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
|
|
func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
|
|
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
|
|
"public" if public else "private")
|
|
ctx.symbol_table.insert(func_op)
|
|
|
|
ir_arg_shardings = None
|
|
if arg_shardings is not None:
|
|
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
|
|
ir_arg_shardings = util.flatten(
|
|
[[_to_physical_op_sharding(a, s)] * len(types)
|
|
for a, s, types in zip(in_avals, arg_shardings, input_types)])
|
|
del in_avals
|
|
|
|
ir_arg_memory_kinds = None
|
|
if arg_memory_kinds is not None:
|
|
ir_arg_memory_kinds = util.flatten(
|
|
[[mk] * len(types) for mk, types in zip(arg_memory_kinds, input_types)])
|
|
|
|
ir_arg_layouts = None
|
|
if arg_layouts is not None:
|
|
ir_arg_layouts = util.flatten(
|
|
[[l] * len(types) for l, types in zip(arg_layouts, input_types)])
|
|
|
|
ir_donated_args = None
|
|
if xla_donated_args is not None:
|
|
ir_donated_args = util.flatten(
|
|
[[is_donated] * len(types) for is_donated, types in zip(xla_donated_args, input_types)])
|
|
|
|
ir_result_shardings = None
|
|
if result_shardings is not None:
|
|
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
|
|
ir_result_shardings = util.flatten(
|
|
[[_to_physical_op_sharding(a, s)] * len(types)
|
|
for a, s, types in zip(out_avals, result_shardings, output_types)])
|
|
del out_avals
|
|
|
|
ir_result_memory_kinds = None
|
|
if result_memory_kinds is not None:
|
|
ir_result_memory_kinds = util.flatten(
|
|
[[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)])
|
|
|
|
ir_result_layouts = None
|
|
if result_layouts is not None:
|
|
ir_result_layouts = util.flatten(
|
|
[[l] * len(types) for l, types in zip(result_layouts, output_types)])
|
|
|
|
if (
|
|
replicated_args is not None
|
|
or ir_arg_shardings is not None
|
|
or ir_arg_memory_kinds is not None
|
|
or ir_arg_layouts is not None
|
|
or input_output_aliases is not None
|
|
or ir_donated_args is not None
|
|
or arg_names is not None
|
|
or num_tokens > 0
|
|
or num_dim_vars > 0
|
|
):
|
|
arg_attrs: list[dict[str, ir.Attribute]] = [
|
|
{} for _ in range(len(flat_input_types))]
|
|
|
|
if replicated_args is not None:
|
|
replicated_ir_args = [[replicated] * len(types) for replicated, types
|
|
in zip(replicated_args, input_types)]
|
|
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
|
|
if replicated:
|
|
attrs["mhlo.is_same_data_across_replicas"] = ir.BoolAttr.get(True)
|
|
|
|
if use_sharding_annotations and ir_arg_shardings is not None:
|
|
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
|
|
if sharding is not None:
|
|
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
|
|
|
|
if ir_arg_memory_kinds is not None:
|
|
for attrs, memory_kind in zip(arg_attrs, ir_arg_memory_kinds):
|
|
if memory_kind is not None:
|
|
attrs["mhlo.memory_kind"] = ir.StringAttr.get(memory_kind)
|
|
|
|
if ir_arg_layouts is not None:
|
|
for attrs, layout in zip(arg_attrs, ir_arg_layouts):
|
|
if layout is not None:
|
|
attrs["mhlo.layout_mode"] = ir.StringAttr.get(layout)
|
|
|
|
if ir_donated_args is not None:
|
|
for attrs, is_donated in zip(arg_attrs, ir_donated_args):
|
|
if is_donated:
|
|
attrs["jax.buffer_donor"] = ir.BoolAttr.get(True)
|
|
|
|
if input_output_aliases is not None:
|
|
output_ids = util.unflatten(list(range(len(flat_output_types))),
|
|
map(len, output_types))
|
|
aliases: list[int | None] = []
|
|
for itypes, alias in zip(input_types, input_output_aliases):
|
|
if alias is None:
|
|
aliases.extend([None] * len(itypes))
|
|
else:
|
|
aliases.extend(output_ids[alias])
|
|
|
|
for attrs, alias in zip(arg_attrs, aliases):
|
|
if alias is not None:
|
|
attrs["tf.aliasing_output"] = i32_attr(alias)
|
|
|
|
if num_dim_vars > 0:
|
|
for var_name, attrs in zip(ctx.shape_poly_state.dim_vars,
|
|
arg_attrs[:num_dim_vars]):
|
|
attrs["jax.global_constant"] = ir.StringAttr.get(var_name)
|
|
elif ctx.lowering_parameters.global_constant_computation:
|
|
for attrs in arg_attrs:
|
|
attrs["jax.global_constant"] = ir.StringAttr.get("")
|
|
|
|
if num_tokens > 0:
|
|
token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens]
|
|
for attrs in token_arg_attrs:
|
|
attrs["jax.token"] = ir.BoolAttr.get(True)
|
|
|
|
func_op.arg_attrs = ir.ArrayAttr.get(
|
|
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
|
|
|
|
result_attrs: list[dict[str, ir.Attribute]] = [
|
|
{} for _ in range(len(flat_output_types))]
|
|
|
|
if num_tokens > 0:
|
|
token_result_attrs = result_attrs[:num_tokens]
|
|
for attrs in token_result_attrs:
|
|
attrs["jax.token"] = ir.BoolAttr.get(True)
|
|
|
|
if result_names:
|
|
named_result_attrs = result_attrs[num_tokens:]
|
|
if len(named_result_attrs) == len(result_names):
|
|
for attrs, name_ in zip(named_result_attrs, result_names):
|
|
attrs['jax.result_info'] = ir.StringAttr.get(name_)
|
|
|
|
if use_sharding_annotations and ir_result_shardings is not None:
|
|
for attrs, sharding in zip(result_attrs, ir_result_shardings):
|
|
if sharding is not None:
|
|
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
|
|
|
|
if ir_result_memory_kinds is not None:
|
|
for attrs, mem_kind in zip(result_attrs, ir_result_memory_kinds):
|
|
if mem_kind is not None:
|
|
attrs['mhlo.memory_kind'] = ir.StringAttr.get(mem_kind)
|
|
|
|
if ir_result_layouts is not None:
|
|
for attrs, layout in zip(result_attrs, ir_result_layouts):
|
|
if layout is not None:
|
|
attrs['mhlo.layout_mode'] = ir.StringAttr.get(layout)
|
|
|
|
func_op.result_attrs = ir.ArrayAttr.get(
|
|
[ir.DictAttr.get(attrs) for attrs in result_attrs])
|
|
|
|
if arg_names:
|
|
arg_locs = [ir.Location.unknown()] * (num_dim_vars + num_tokens)
|
|
for n in arg_names:
|
|
arg_locs.append(ir.Location.name(n) if n else ir.Location.unknown())
|
|
entry_block = func_op.add_entry_block(arg_locs)
|
|
else:
|
|
entry_block = func_op.add_entry_block()
|
|
|
|
with ir.InsertionPoint(entry_block):
|
|
flat_args = entry_block.arguments
|
|
# We separate out the dimension variable inputs, the token inputs and
|
|
# the regular inputs. The dimension variables and token inputs
|
|
# will be passed to `jaxpr_subcomp` separately from the `args`.
|
|
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
|
|
# A lowering context just for function body entry/exit code.
|
|
entry_lowering_ctx = LoweringRuleContext(
|
|
module_context=ctx, name_stack=name_stack, primitive=None,
|
|
avals_in=[], avals_out=None,
|
|
tokens_in=TokenSet.create([]), tokens_out=None,
|
|
axis_size_env=None, dim_var_values=dim_var_values)
|
|
if not use_sharding_annotations and ir_arg_shardings is not None:
|
|
flat_args = [
|
|
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
|
|
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
|
|
|
|
if ir_arg_shardings is not None and name == "main":
|
|
flat_args = [
|
|
a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # type: ignore
|
|
if (a is not core.abstract_token and
|
|
dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # type: ignore
|
|
for o, s, a in zip(flat_args, ir_arg_shardings, input_avals)
|
|
]
|
|
|
|
_, token_args, unflattened_args = util.split_list(
|
|
util.unflatten(flat_args, map(len, input_types)),
|
|
[num_dim_vars, num_tokens])
|
|
if create_tokens:
|
|
tokens_in = TokenSet.create(effects)
|
|
else:
|
|
tokens_in = TokenSet(zip(effects, token_args))
|
|
args: list[list[ir.Value]] = []
|
|
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
|
|
if replace_tokens_with_dummy and aval is core.abstract_token:
|
|
args.append([hlo.create_token()])
|
|
else:
|
|
args.append(arg)
|
|
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
|
|
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
|
|
out_vals, tokens_out = jaxpr_subcomp(
|
|
ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
|
|
consts, *args, dim_var_values=dim_var_values)
|
|
outs = []
|
|
if create_tokens:
|
|
for _ in range(num_output_tokens):
|
|
outs.append(dummy_token())
|
|
for _ in effects:
|
|
outs.append(dummy_token())
|
|
else:
|
|
for eff in effects:
|
|
outs.append(tokens_out.get(eff))
|
|
for aval, out in zip(jaxpr.out_avals, out_vals):
|
|
if replace_tokens_with_dummy and aval is core.abstract_token:
|
|
outs.append(ir_constants(np.zeros((), np.bool_)))
|
|
else:
|
|
outs.append(out)
|
|
|
|
flat_outputs = util.flatten(outs)
|
|
|
|
if not use_sharding_annotations and ir_result_shardings is not None:
|
|
flat_outputs = [
|
|
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
|
|
for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]
|
|
|
|
# Insert a custom call if output is on host because XLA needs that to do the
|
|
# transfer.
|
|
if ir_result_memory_kinds is not None:
|
|
flat_outputs = [
|
|
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
|
|
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
|
|
|
|
if ir_result_shardings is not None and name == "main":
|
|
flat_outputs = [
|
|
a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # type: ignore
|
|
if (a is not core.abstract_token and
|
|
dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # type: ignore
|
|
for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals)
|
|
]
|
|
|
|
func_dialect.return_(flat_outputs)
|
|
|
|
return func_op
|
|
|
|
|
|
def wrap_with_memory_kind(
    x: ir.Value, memory_kind: str, aval_out: core.AbstractValue) -> ir.Value:
  if aval_out is None:
    result_type = x.type
  else:
    result_type = aval_to_ir_type(aval_out)
  op = custom_call("annotate_device_placement", result_types=[result_type],
                   operands=[x], has_side_effect=True, api_version=1)
  dict_attr = {"_xla_buffer_placement": ir.StringAttr.get(memory_kind)}
  op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
  return op.result
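
# Illustrative sketch (not part of the lowering path): the annotation above is a
# no-op custom call that only carries a frontend attribute. For a value placed
# on host memory the emitted StableHLO looks roughly like
#   %0 = stablehlo.custom_call @annotate_device_placement(%arg)
#          {mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}}
# The memory-kind string ("pinned_host" here) is an assumed example value.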
|
|
|
|
|
|
def _emit_lowering_rule_as_fun(lowering_rule,
|
|
ctx: LoweringRuleContext) -> func_dialect.FuncOp:
|
|
"""Emits the contents of a lowering rule as a private function."""
|
|
num_dim_vars = len(ctx.module_context.shape_poly_state.dim_vars)
|
|
# TODO(necula) maybe only pass the dim_vars if they are needed?
|
|
dim_var_types = map(aval_to_ir_types, [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars)
|
|
|
|
input_types = map(aval_to_ir_types, ctx.avals_in)
|
|
output_types = map(aval_to_ir_types, ctx.avals_out)
|
|
effs = list(ctx.tokens_in.effects())
|
|
token_types = [token_type() for _ in effs]
|
|
input_types = [*dim_var_types, *token_types, *input_types]
|
|
output_types = [*token_types, *output_types]
|
|
|
|
flat_input_types = util.flatten(input_types)
|
|
flat_output_types = util.flatten(output_types)
|
|
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
|
|
assert ctx.primitive is not None
|
|
func_op = func_dialect.FuncOp(ctx.primitive.name, ftype,
|
|
ip=ctx.module_context.ip)
|
|
func_op.attributes["sym_visibility"] = ir.StringAttr.get("private")
|
|
ctx.module_context.symbol_table.insert(func_op)
|
|
entry_block = func_op.add_entry_block()
|
|
with ir.InsertionPoint(entry_block):
|
|
unflattened_args = util.unflatten(entry_block.arguments,
|
|
map(len, input_types))
|
|
dim_var_values, token_args, unflattened_args = util.split_list(unflattened_args, [num_dim_vars, len(ctx.tokens_in)])
|
|
sub_ctx = ctx.replace(tokens_in=TokenSet(zip(effs, token_args)),
|
|
dim_var_values=dim_var_values)
|
|
outs = lowering_rule(sub_ctx, *_unwrap_singleton_ir_values(unflattened_args))
|
|
if sub_ctx.tokens_out:
|
|
outs = [*[sub_ctx.tokens_out.get(eff) for eff in effs], outs]
|
|
func_dialect.return_(util.flatten(map(wrap_singleton_ir_values, outs)))
|
|
return func_op
|
|
|
|
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
|
name_stack: source_info_util.NameStack,
|
|
tokens: TokenSet,
|
|
consts: Sequence[Sequence[ir.Value]],
|
|
*args: Sequence[ir.Value],
|
|
dim_var_values: Sequence[ir.Value]
|
|
) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]:
|
|
"""Lowers a jaxpr into MLIR, inlined into an existing function.
|
|
|
|
Assumes that an MLIR context, location, and insertion point are set.
|
|
|
|
dim_var_values: the list of dimension variables values in the current
|
|
IR function, in the order of ctx.shape_poly_state.dim_vars.
|
|
"""
|
|
assert "gpu" not in ctx.platforms
|
|
def read(v: core.Atom) -> Sequence[ir.Value]:
|
|
if type(v) is core.Literal:
|
|
return ir_constants(xla.canonicalize_dtype(v.val))
|
|
else:
|
|
assert isinstance(v, core.Var)
|
|
return env[v]
|
|
|
|
def aval(v: core.Atom) -> core.AbstractValue:
|
|
if type(v) is core.Literal:
|
|
return xla.abstractify(v.val)
|
|
else:
|
|
return v.aval
|
|
|
|
def write(v: core.Var, node: Sequence[ir.Value]):
|
|
assert node is not None
|
|
env[v] = tuple(node)
|
|
|
|
def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None:
|
|
if ctx.lowering_parameters.override_lowering_rules is None:
|
|
return None
|
|
for p, rule in ctx.lowering_parameters.override_lowering_rules:
|
|
if primitive is p:
|
|
return rule
|
|
return None
|
|
|
|
env: dict[core.Var, tuple[ir.Value, ...]] = {}
|
|
|
|
assert isinstance(name_stack, source_info_util.NameStack), type(name_stack)
|
|
assert len(args) == len(jaxpr.invars), (jaxpr, args)
|
|
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
|
|
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
|
|
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
|
|
map(write, jaxpr.constvars, consts)
|
|
map(write, jaxpr.invars, args)
|
|
last_used = core.last_used(jaxpr)
|
|
for eqn in jaxpr.eqns:
|
|
in_nodes = map(read, eqn.invars)
|
|
source_info = eqn.source_info.replace(
|
|
name_stack=name_stack + eqn.source_info.name_stack)
|
|
loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
|
|
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
|
override_rule = get_override_lowering_rule(eqn.primitive)
|
|
platform_rules: dict[str, LoweringRule] = {}
|
|
default_rule: LoweringRule | None = None
|
|
# See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule`
|
|
if override_rule is not None:
|
|
default_rule = override_rule
|
|
else:
|
|
# First the platform-specific rules
|
|
for p in ctx.platforms:
|
|
if eqn.primitive in _platform_specific_lowerings[p]:
|
|
platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive]
|
|
elif eqn.primitive in xla._backend_specific_translations[p]:
|
|
platform_rules[p] = xla_fallback_lowering(eqn.primitive)
|
|
# Now the default rule
|
|
if eqn.primitive in _lowerings:
|
|
default_rule = _lowerings[eqn.primitive]
|
|
elif eqn.primitive in xla._translations:
|
|
default_rule = xla_fallback_lowering(eqn.primitive)
|
|
|
|
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
|
|
tokens_in = tokens.subset(effects)
|
|
avals_in = map(aval, eqn.invars)
|
|
rule_ctx = LoweringRuleContext(
|
|
module_context=ctx, primitive=eqn.primitive,
|
|
name_stack=source_info.name_stack,
|
|
avals_in=avals_in,
|
|
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
|
|
tokens_out=None, dim_var_values=dim_var_values)
|
|
if config.dynamic_shapes.value:
|
|
axis_size_env = {d: read(d)[0]
|
|
for a in avals_in if type(a) is core.DShapedArray
|
|
for d in a.shape if type(d) is core.Var}
|
|
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
|
|
|
|
rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
|
|
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
|
|
platform_rules, default_rule,
|
|
eqn.effects,
|
|
*rule_inputs, **eqn.params)
|
|
|
|
if effects:
|
|
# If there were ordered effects in the primitive, there should be output
|
|
# tokens we need for subsequent ordered effects.
|
|
tokens_out = rule_ctx.tokens_out
|
|
if tokens_out is None:
|
|
raise ValueError(
|
|
f'Lowering rule for `{eqn.primitive}` needs to set `tokens_out` '
|
|
f'because it has effects: {eqn.effects}.')
|
|
if tokens_out.effects() != tokens_in.effects():
|
|
raise ValueError(
|
|
f'Lowering rule for `{eqn.primitive}` '
|
|
'returns incorrect set of output tokens. '
|
|
f'Expected: {tuple(tokens_in.effects())} vs. Actual: {tuple(tokens_out.effects())}')
|
|
tokens = tokens.update_tokens(tokens_out)
|
|
|
|
try:
|
|
out_nodes = tuple(map(wrap_singleton_ir_values, ans))
|
|
except TypeError as e:
|
|
raise ValueError("Output of translation rule must be iterable: "
|
|
f"{eqn}, got output {ans}") from e
|
|
|
|
assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
|
|
assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
|
|
ans, "lowering function returned a bad output", eqn)
|
|
assert len(ans) == len(eqn.outvars), (ans, eqn)
|
|
map(write, eqn.outvars, out_nodes)
|
|
core.clean_up_dead_vars(eqn, env, last_used)
|
|
return map(read, jaxpr.outvars), tokens
|
|
|
|
|
|
def lower_per_platform(ctx: LoweringRuleContext,
|
|
description: str,
|
|
platform_rules: dict[str, LoweringRule],
|
|
default_rule: LoweringRule | None,
|
|
effects: effects_lib.Effects,
|
|
*rule_args: ir.Value,
|
|
**rule_kwargs) -> ir.Value:
|
|
"""Emits code for a primitive for the current lowering platform(s).
|
|
|
|
For example, given
|
|
platform_rules = dict(tpu=rule0, cpu=rule0)
|
|
default_rule = rule1
|
|
|
|
and
|
|
ctx.module_context.lowering_parameters.platforms = ("cpu",)
|
|
|
|
emits:
|
|
rule0(ctx, *rule_args, **rule_kwargs)
|
|
|
|
In case of multi-platform lowering, e.g., if
|
|
ctx.module_context.lowering_parameters.platforms = ("cpu", "cuda", "tpu")
|
|
|
|
emits:
|
|
rule_idx = case current_platform_idx:
|
|
0: return 0 # cpu rule index
|
|
1: return 1 # cuda rule index
|
|
2: return 0 # tpu rule index
|
|
output = case rule_idx
|
|
0: return rule0(*rule_args, **rule_kwargs)
|
|
1: return rule1(*rule_args, **rule_kwargs)
|
|
|
|
Args:
|
|
ctx: lowering context.
|
|
description: a string to include in error messages.
|
|
platform_rules: map platform names, e.g., "cpu", "cuda", to
|
|
`LoweringRule`s, for the platforms that have non-default lowering.
|
|
default_rule: an optional rule to use for platforms not in `platform_rules`.
|
|
effects: the set of effects for the current primitive.
|
|
rule_args: the args of the lowering rules.
|
|
rule_kwargs: the kwargs of the lowering rules.
|
|
"""
|
|
  platforms: Sequence[str] = ctx.module_context.platforms
  # Special-case the common case (single-platform lowering).
  if len(platforms) == 1:
    rule = platform_rules.get(platforms[0], default_rule)
    if rule is None:
      raise NotImplementedError(
          f"MLIR translation rule for primitive '{description}' not "
          f"found for platform {platforms[0]}")
    return rule(ctx, *rule_args, **rule_kwargs)
|
|
|
|
# Multi-platform lowering
|
|
kept_rules: list[LoweringRule] = [] # Only the rules for the platforms of interest
|
|
platform_to_kept_rules_idx: dict[str, int] = {}
|
|
for p, prule in platform_rules.items():
|
|
if p not in platforms:
|
|
continue
|
|
platform_to_kept_rules_idx[p] = len(kept_rules)
|
|
kept_rules.append(prule)
|
|
|
|
platforms_without_specific_rule = [p for p in platforms
|
|
if p not in platform_to_kept_rules_idx]
|
|
if platforms_without_specific_rule:
|
|
if default_rule is None:
|
|
raise NotImplementedError(
|
|
f"MLIR translation rule for primitive '{description}' not "
|
|
f"found for platforms {platforms_without_specific_rule}")
|
|
for p in platforms_without_specific_rule:
|
|
platform_to_kept_rules_idx[p] = len(kept_rules)
|
|
kept_rules.append(default_rule)
|
|
|
|
assert kept_rules
|
|
# If there is a single rule left just apply the rule, without conditionals.
|
|
if len(kept_rules) == 1:
|
|
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
|
|
|
|
assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
|
|
assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable"
|
|
|
|
# The first dim_var_values is the platform index
|
|
current_platform_idx = ctx.dim_var_values[0]
|
|
# Compute the rule index based on the current platform
|
|
i32_type = aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
|
|
if current_platform_idx.type != i32_type:
|
|
current_platform_idx = hlo.convert(i32_type, current_platform_idx)
|
|
rule_idx_op = hlo.CaseOp([i32_type],
|
|
index=current_platform_idx,
|
|
num_branches=len(platforms))
|
|
for i, p in enumerate(platforms):
|
|
branch = rule_idx_op.regions[i].blocks.append()
|
|
with ir.InsertionPoint(branch):
|
|
hlo.return_(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
|
|
ordered_effects = effects_lib.ordered_effects.filter_in(effects)
|
|
rule_out_avals = [core.abstract_token] * len(ordered_effects) + ctx.avals_out
|
|
output_types = map(aval_to_ir_types, rule_out_avals)
|
|
case_op = hlo.CaseOp(util.flatten(output_types),
|
|
index=rule_idx_op,
|
|
num_branches=len(kept_rules))
|
|
for i, rule in enumerate(kept_rules):
|
|
inner_ctx = ctx.replace()
|
|
branch = case_op.regions[i].blocks.append()
|
|
with ir.InsertionPoint(branch):
|
|
output = rule(inner_ctx, *rule_args, **rule_kwargs)
|
|
try:
|
|
out_nodes = map(wrap_singleton_ir_values, output)
|
|
except TypeError as e:
|
|
raise ValueError("Output of translation rule must be iterable: "
|
|
f"{description}, got output {output}") from e
|
|
if inner_ctx.tokens_out is not None:
|
|
assert len(ordered_effects) == len(inner_ctx.tokens_out)
|
|
out_nodes = [inner_ctx.tokens_out.get(eff)
|
|
for eff in ordered_effects] + out_nodes
|
|
hlo.return_(util.flatten(map(wrap_singleton_ir_values, out_nodes)))
|
|
|
|
results = case_op.results
|
|
if ordered_effects:
|
|
tokens, results = util.split_list(
|
|
util.unflatten(results, map(len, output_types)),
|
|
[len(ordered_effects)])
|
|
tokens_out = ctx.tokens_in.update_tokens(TokenSet(zip(ordered_effects,
|
|
tokens)))
|
|
ctx.set_tokens_out(tokens_out)
|
|
return results
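
# Illustrative sketch: how the rules consumed by `lower_per_platform` are
# typically registered. `my_p`, `_my_cpu_rule`, and `_my_default_rule` are
# hypothetical names, not part of this module.
#
#   my_p = core.Primitive("my_p")
#   register_lowering(my_p, _my_default_rule)               # default rule
#   register_lowering(my_p, _my_cpu_rule, platform="cpu")   # platform rule
#
# With platforms=("cpu",) the CPU rule is emitted directly; with
# platforms=("cpu", "tpu") a pair of CaseOps selects between the two rules at
# run time, as described in the docstring above.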
|
|
|
|
def _ir_consts(consts):
  unique_consts = {id(const): const for const in consts}
  ir_consts = {
      id_: ir_constants(xla.canonicalize_dtype(const))
      for id_, const in unique_consts.items()
  }
  return [ir_consts[id(const)] for const in consts]
|
|
|
|
|
|
def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
|
|
"""Converts a traceable JAX function `fun` into a lowering rule.
|
|
|
|
The returned function does not use `avals_out`, so callers may pass any value
|
|
as `avals_out`."""
|
|
def f_lowered(ctx, *args, **params):
|
|
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
|
wrapped_fun = lu.wrap_init(f, params)
|
|
|
|
if config.dynamic_shapes.value:
|
|
# We might be applying this function to arguments with dynamic shapes,
|
|
# i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
|
|
# case, we need to form a jaxpr with leading binders for those axis size
|
|
# arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
|
|
# and we need to call jaxpr_subcomp with these arguments made explicit.
|
|
args = (*ctx.axis_size_env.values(), *args)
|
|
idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
|
|
i32_aval = core.ShapedArray((), np.dtype('int32'))
|
|
implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
|
|
explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))
|
|
if type(a) is core.DShapedArray else a, True)
|
|
for a in ctx.avals_in]
|
|
wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
|
|
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
|
|
else:
|
|
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
|
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
|
|
|
|
out, tokens = jaxpr_subcomp(
|
|
ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in,
|
|
_ir_consts(consts), *map(wrap_singleton_ir_values, args),
|
|
dim_var_values=ctx.dim_var_values)
|
|
ctx.set_tokens_out(tokens)
|
|
return out
|
|
|
|
return f_lowered
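
# Illustrative sketch: `lower_fun` lets a lowering be written in terms of other
# JAX operations instead of raw StableHLO. `my_prim_p` and `_my_prim_impl` are
# hypothetical names.
#
#   def _my_prim_impl(x):
#     # Any traceable JAX code; returns a list because multiple_results=True.
#     return [x * 2.0 + 1.0]
#
#   register_lowering(my_prim_p, lower_fun(_my_prim_impl, multiple_results=True))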
|
|
|
|
|
|
def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack,
|
|
arg_names=None, result_names=None):
|
|
if not call_jaxpr.consts and arg_names is result_names is None:
|
|
# Cacheable.
|
|
key = (fn_name, call_jaxpr.jaxpr, tuple(effects))
|
|
try:
|
|
func_op = ctx.cached_primitive_lowerings[key]
|
|
except KeyError:
|
|
func_op = lower_jaxpr_to_fun(
|
|
ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
|
|
result_names=result_names)
|
|
ctx.cached_primitive_lowerings[key] = func_op
|
|
else:
|
|
func_op = lower_jaxpr_to_fun(
|
|
ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
|
|
result_names=result_names)
|
|
return func_op
|
|
|
|
|
|
def check_backend_matches(inner_backend: str | None,
                          lowering_platforms: Sequence[str]):
  # For nested calls, the outermost call sets the backend for all inner calls;
  # it's an error if the inner call has a conflicting explicit backend spec.
  if inner_backend is None:
    return
  outer_backend, *more_lowering_platforms = lowering_platforms
  if more_lowering_platforms:
    raise NotImplementedError(
        "Multi-platform lowering when a backend= parameter is specified")
  if (inner_backend != outer_backend and
      outer_backend not in xb.expand_platform_alias(inner_backend)):
    raise ValueError(
        f"Outer-jit backend specification {outer_backend} must match explicit "
        f"inner-jit backend specification {inner_backend}.")
|
|
|
|
|
|
def call_lowering(fn_name, name_stack, call_jaxpr, backend,
|
|
ctx: ModuleContext, avals_in,
|
|
avals_out, tokens_in, *args,
|
|
dim_var_values: Sequence[ir.Value],
|
|
arg_names=None, result_names=None):
|
|
del avals_in
|
|
if isinstance(call_jaxpr, core.Jaxpr):
|
|
call_jaxpr = pe.close_jaxpr(call_jaxpr)
|
|
check_backend_matches(backend, ctx.platforms)
|
|
effects = list(tokens_in.effects())
|
|
output_types = map(aval_to_ir_types, avals_out)
|
|
output_types = [token_type()] * len(effects) + output_types
|
|
flat_output_types = util.flatten(output_types)
|
|
symbol_name = _lower_jaxpr_to_fun_cached(
|
|
ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
|
|
result_names=result_names).name.value
|
|
tokens = [tokens_in.get(eff) for eff in effects]
|
|
args = (*dim_var_values, *tokens, *args)
|
|
call = func_dialect.CallOp(flat_output_types,
|
|
ir.FlatSymbolRefAttr.get(symbol_name),
|
|
flatten_lowering_ir_args(args))
|
|
out_nodes = util.unflatten(call.results, map(len, output_types))
|
|
tokens, out_nodes = util.split_list(out_nodes, [len(effects)])
|
|
tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
|
|
return out_nodes, tokens_out
|
|
|
|
def core_call_lowering(ctx: LoweringRuleContext,
|
|
*args, name, backend=None, call_jaxpr):
|
|
out_nodes, tokens = call_lowering(
|
|
name, ctx.name_stack, call_jaxpr, backend, ctx.module_context,
|
|
ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
|
|
dim_var_values=ctx.dim_var_values)
|
|
ctx.set_tokens_out(tokens)
|
|
return out_nodes
|
|
|
|
register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
|
|
register_lowering(core.closed_call_p,
|
|
partial(core_call_lowering, name="core_closed_call"))
|
|
|
|
def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
|
|
broadcast_dimensions) -> ir.Value:
|
|
# broadcast_dimension[i] is the axis of the result where the axis i of
|
|
# op is broadcast.
|
|
# Lower a possibly-dynamic broadcast_in_dim
|
|
if dtypes.issubdtype(aval_out.dtype, dtypes.extended): # type: ignore
|
|
elt_shape = aval_out.dtype._rules.physical_element_aval( # type: ignore
|
|
aval_out.dtype).shape # type: ignore
|
|
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] # type: ignore
|
|
broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
|
|
physical_aval_out = core.physical_aval(aval_out)
|
|
return broadcast_in_dim(
|
|
ctx, op, physical_aval_out, broadcast_dimensions=broadcast_dimensions)
|
|
else:
|
|
if not core.is_constant_shape(aval_out.shape): # type: ignore
|
|
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore
|
|
return hlo.dynamic_broadcast_in_dim(
|
|
aval_to_ir_type(aval_out), op,
|
|
shape,
|
|
dense_int_array_v6(broadcast_dimensions),
|
|
)
|
|
else:
|
|
assert all(d != ir.ShapedType.get_dynamic_size()
|
|
for d in aval_out.shape), aval_out # type: ignore
|
|
return hlo.broadcast_in_dim(
|
|
aval_to_ir_type(aval_out), op,
|
|
dense_int_array_v6(broadcast_dimensions))
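
# For example (illustrative): broadcasting an f32[3] operand to an f32[2, 3]
# result uses broadcast_dimensions=(1,), because axis 0 of the operand maps to
# axis 1 of the result:
#   broadcast_in_dim(ctx, x, core.ShapedArray((2, 3), np.float32),
#                    broadcast_dimensions=(1,))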
|
|
|
|
def multi_broadcast_in_dim(ctx: LoweringRuleContext,
|
|
ops: Sequence[ir.Value],
|
|
ops_avals: Sequence[core.AbstractValue],
|
|
out_shape: core.Shape) -> Sequence[ir.Value]:
|
|
"""Broadcasts multiple ops to the out_shape."""
|
|
out = []
|
|
for op, op_aval in zip(ops, ops_avals):
|
|
op_aval_shape = op_aval.shape # type: ignore
|
|
if core.definitely_equal_shape(op_aval_shape, out_shape): # type: ignore
|
|
out.append(op)
|
|
else:
|
|
assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape)
|
|
broadcast_dimensions = list(range(len(out_shape) - len(op_aval_shape), len(out_shape)))
|
|
out.append(broadcast_in_dim(ctx, op,
|
|
core.ShapedArray(out_shape, op_aval.dtype), # type: ignore
|
|
broadcast_dimensions=broadcast_dimensions))
|
|
return out
|
|
|
|
def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Value:
|
|
aval_out = core.physical_aval(aval_out)
|
|
if not core.is_constant_shape(aval_out.shape): # type: ignore
|
|
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore
|
|
return hlo.dynamic_reshape(
|
|
aval_to_ir_type(aval_out), op, shape,
|
|
)
|
|
else:
|
|
return hlo.reshape(aval_to_ir_type(aval_out), op)
|
|
|
|
def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
|
|
start_indices, limit_indices, strides) -> ir.Value:
|
|
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
|
|
elt_shape = aval_out.dtype._rules.physical_element_aval(
|
|
aval_out.dtype).shape
|
|
trailing_zeros = [0] * len(elt_shape)
|
|
trailing_ones = [1] * len(elt_shape)
|
|
start_indices = (*start_indices, *trailing_zeros)
|
|
limit_indices = (*limit_indices, *elt_shape)
|
|
strides = (*strides, *trailing_ones)
|
|
physical_aval_out = core.physical_aval(aval_out)
|
|
return slice_op(ctx, x, physical_aval_out, start_indices=start_indices,
|
|
limit_indices=limit_indices, strides=strides)
|
|
else:
|
|
if any(not core.is_constant_shape(s) for s in (start_indices, limit_indices, strides)):
|
|
start_indices = eval_dynamic_shape_as_tensor(ctx, start_indices)
|
|
limit_indices = eval_dynamic_shape_as_tensor(ctx, limit_indices)
|
|
strides = eval_dynamic_shape_as_tensor(ctx, strides)
|
|
return hlo.real_dynamic_slice(
|
|
aval_to_ir_type(aval_out),
|
|
x, start_indices, limit_indices, strides)
|
|
else:
|
|
return hlo.slice(x,
|
|
dense_int_array(start_indices),
|
|
dense_int_array(limit_indices),
|
|
dense_int_array(strides))
|
|
|
|
def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
|
start_indices) -> ir.Value:
|
|
x_aval = ctx.avals_in[0]
|
|
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
|
|
elt_shape = aval_out.dtype._rules.physical_element_aval(
|
|
aval_out.dtype).shape
|
|
index_avals = ctx.avals_in[1:]
|
|
dtype = dtypes.canonicalize_dtype(
|
|
index_avals[0].dtype if index_avals else 'int64') # type: ignore
|
|
trailing_zeros = [ir_constant(np.array(0, dtype))] * len(elt_shape)
|
|
start_indices = (*start_indices, *trailing_zeros)
|
|
aval_out = core.physical_aval(aval_out)
|
|
x_aval = core.physical_aval(x_aval)
|
|
|
|
slice_sizes = aval_out.shape
|
|
if not core.is_constant_shape(slice_sizes):
|
|
# lax.dynamic_slice clamps the start indices, but we are going to
|
|
# lower to RealDynamicSliceOp, which is a version of SliceOp, and does
|
|
# not have the clamping behavior. We clamp start ourselves.
|
|
slice_sizes = eval_dynamic_shape_as_tensor(ctx, slice_sizes)
|
|
clamped_start = hlo.clamp(
|
|
shape_tensor([0] * len(start_indices)),
|
|
shape_tensor(start_indices),
|
|
hlo.subtract(
|
|
eval_dynamic_shape_as_tensor(ctx, x_aval.shape), # type: ignore
|
|
slice_sizes))
|
|
return hlo.real_dynamic_slice(
|
|
aval_to_ir_type(aval_out), x,
|
|
clamped_start,
|
|
hlo.add(clamped_start, slice_sizes),
|
|
shape_tensor([1] * len(start_indices))
|
|
)
|
|
else:
|
|
return hlo.dynamic_slice(x, start_indices, dense_int_array(slice_sizes))
|
|
|
|
def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
|
|
start_indices) -> ir.Value:
|
|
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
|
|
elt_shape = aval_out.dtype._rules.physical_element_aval(
|
|
aval_out.dtype).shape
|
|
index_avals = ctx.avals_in[2:]
|
|
dtype = dtypes.canonicalize_dtype(
|
|
index_avals[0].dtype if index_avals else 'int64') # type: ignore
|
|
zeros = [ir_constant(np.array(0, dtype=dtype))] * len(elt_shape)
|
|
start_indices = (*start_indices, *zeros)
|
|
physical_aval_out = core.physical_aval(aval_out)
|
|
return dynamic_update_slice(ctx, physical_aval_out, x, update,
|
|
start_indices=start_indices)
|
|
else:
|
|
# TODO(necula): handle dynamic shapes
|
|
return hlo.dynamic_update_slice(x, update, start_indices)
|
|
|
|
def pad(ctx: LoweringRuleContext, aval_out,
|
|
x, padding_value,
|
|
padding_low, padding_high, padding_interior) -> ir.Value:
|
|
if all(core.is_constant_shape(s) for s in (padding_low,
|
|
padding_high, padding_interior)):
|
|
return hlo.pad(x, padding_value,
|
|
dense_int_array(padding_low),
|
|
dense_int_array(padding_high),
|
|
dense_int_array(padding_interior))
|
|
else:
|
|
padding_low = eval_dynamic_shape_as_tensor(ctx, padding_low)
|
|
padding_high = eval_dynamic_shape_as_tensor(ctx, padding_high)
|
|
padding_interior = eval_dynamic_shape_as_tensor(ctx, padding_interior)
|
|
return hlo.dynamic_pad(
|
|
aval_to_ir_type(aval_out),
|
|
x, padding_value, padding_low, padding_high, padding_interior)
|
|
|
|
def iota(ctx: LoweringRuleContext, aval_out, *, dimension: int):
|
|
if not core.is_constant_shape(aval_out.shape):
|
|
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
|
|
return hlo.dynamic_iota(
|
|
aval_to_ir_type(aval_out),
|
|
shape,
|
|
i64_attr(dimension),
|
|
)
|
|
else:
|
|
return hlo.iota(aval_to_ir_type(aval_out), i64_attr(dimension))
|
|
|
|
def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
  """Returns an IR constant of shape `aval.shape` filled with `value`."""
  scalar = ir_constant(np.array(value, dtypes.canonicalize_dtype(aval.dtype)))
  return broadcast_in_dim(ctx, scalar, aval, broadcast_dimensions=())


def add_jaxvals_lowering(ctx, x, y):
  if (isinstance(a := ctx.avals_in[0], core.ShapedArray) and
      dtypes.issubdtype(a.dtype, dtypes.extended)):
    return lower_fun(lambda x, y: [a.dtype._rules.add(a.dtype, x, y)])(ctx, x, y)  # type: ignore
  return [hlo.add(x, y)]
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)

register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])
|
|
|
|
|
|
def compare_hlo(x, y, direction: str, comparison_type: str | None = None):
|
|
"""Creates CompareOp."""
|
|
if comparison_type is None:
|
|
elem_type = ir.RankedTensorType(x.type).element_type
|
|
if ir.IntegerType.isinstance(elem_type):
|
|
comparison_type = ("UNSIGNED" if ir.IntegerType.is_unsigned(elem_type)
|
|
else "SIGNED")
|
|
else:
|
|
comparison_type = "FLOAT"
|
|
|
|
return hlo.compare(
|
|
x,
|
|
y,
|
|
hlo.ComparisonDirectionAttr.get(direction),
|
|
compare_type=hlo.ComparisonTypeAttr.get(comparison_type))
|
|
|
|
def _minmax_hlo(op, cmp, x, y):
|
|
"""Min/max that compares complex values lexicographically as pairs."""
|
|
tensor_type = ir.RankedTensorType(x.type)
|
|
if ir.ComplexType.isinstance(tensor_type.element_type):
|
|
rx = hlo.real(x)
|
|
ry = hlo.real(y)
|
|
real_eq = compare_hlo(rx, ry, "EQ", "FLOAT")
|
|
real_cmp = compare_hlo(rx, ry, cmp, "FLOAT")
|
|
imag_cmp = compare_hlo(hlo.imag(x), hlo.imag(y), cmp, "FLOAT")
|
|
which = hlo.select(real_eq, imag_cmp, real_cmp)
|
|
return hlo.select(which, x, y)
|
|
else:
|
|
return op(x, y)
|
|
|
|
min_hlo = partial(_minmax_hlo, hlo.minimum, "LT")
|
|
max_hlo = partial(_minmax_hlo, hlo.maximum, "GT")
|
|
|
|
|
|
def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
  """Variant of convert that has HLO semantics.

  In particular, treat casts to boolean as x != 0, rather than truncating
  integer values (b/209440332)."""
  if (not dtypes.issubdtype(aval_out.dtype, dtypes.extended) and
      aval_out.dtype == np.dtype(np.bool_)):
    if dtypes.issubdtype(aval_in.dtype, np.inexact):
      compare_type = "FLOAT"
    elif dtypes.issubdtype(aval_in.dtype, np.signedinteger):
      compare_type = "SIGNED"
    else:
      compare_type = "UNSIGNED"
    x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE", compare_type)
    # continue, to adjust the shape if needed
  return hlo.convert(aval_to_ir_type(aval_out), x)
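
# For example (illustrative): converting f32 -> bool lowers to roughly
#   %p = stablehlo.compare NE, %x, %zeros, FLOAT
# followed by the final convert, rather than truncating values to their low bit.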
|
|
|
|
def _wrap_with_spmd_op(name: str,
|
|
ctx: LoweringRuleContext,
|
|
x: ir.Value,
|
|
aval_out: core.AbstractValue,
|
|
sharding_proto: xc.OpSharding,
|
|
unspecified_dims: set[int] | None = None,
|
|
has_side_effect: bool = False):
|
|
# unspecified_dims indicate dimensions whose shardings are not specified and
|
|
# XLA sharding propagation can change them.
|
|
if unspecified_dims:
|
|
backend_config = "unspecified_dims=[" + ",".join(
|
|
[str(i) for i in sorted(unspecified_dims)]) + "]"
|
|
else:
|
|
backend_config = ""
|
|
result_type = aval_to_ir_type(aval_out)
|
|
out_shape = core.physical_aval(aval_out).shape # type: ignore
|
|
if core.is_constant_shape(out_shape):
|
|
result_shapes = None
|
|
else:
|
|
result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)]
|
|
|
|
op = custom_call(name, result_types=[result_type], operands=[x],
|
|
backend_config=backend_config,
|
|
api_version=1,
|
|
result_shapes=result_shapes,
|
|
has_side_effect=has_side_effect)
|
|
set_sharding(op, sharding_proto)
|
|
return op.result
|
|
|
|
|
|
wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding")
|
|
wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
|
|
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
|
|
|
|
def set_sharding(op, sharding_proto: xc.OpSharding):
  op.attributes["mhlo.sharding"] = get_sharding_attr(sharding_proto)


def get_sharding_attr(sharding_proto: xc.OpSharding):
  # If there are very large numbers of devices, use the proto representation.
  # The MHLO to HLO conversion supports both, and the proto representation is
  # more compact.
  if len(sharding_proto.tile_assignment_devices) > 100:
    return ir.StringAttr.get(sharding_proto.SerializeToString())
  else:
    return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
|
|
|
|
|
|
# MLIR lowerings for lax primitives
|
|
|
|
def cache_lowering(f):
|
|
"""Decorator that causes the contents of a lowering rule to be reused.
|
|
|
|
The lowering will be emitted out-of-line in a separate function, together with
|
|
a call to that function. If the same primitive is called with the same shapes
|
|
and parameters, a new call to the original function will be added, without
|
|
emitting a new function. We allow for different lowering for the same
|
|
primitive for different platforms in the same module.
|
|
"""
|
|
@functools.wraps(f)
|
|
def cached_lowering(ctx, *args, **params):
|
|
assert ctx.primitive is not None
|
|
key = (f, ctx.primitive,
|
|
tuple(ctx.avals_in), tuple(ctx.avals_out),
|
|
tuple(params.items()))
|
|
try:
|
|
func = ctx.module_context.cached_primitive_lowerings.get(key)
|
|
except TypeError:
|
|
# If the parameters aren't hashable, give up on caching.
|
|
# TODO(phawkins): switch to requiring hashability, when XLA fallback
|
|
# computations have been ported to MLIR.
|
|
return f(ctx, *args, **params)
|
|
if func is None:
|
|
func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
|
|
ctx.module_context.cached_primitive_lowerings[key] = func
|
|
|
|
output_types = map(aval_to_ir_types, ctx.avals_out)
|
|
args = tuple(ctx.dim_var_values) + args
|
|
flat_output_types = util.flatten(output_types)
|
|
call = func_dialect.CallOp(flat_output_types,
|
|
ir.FlatSymbolRefAttr.get(func.name.value),
|
|
flatten_lowering_ir_args(args))
|
|
return util.unflatten(call.results, map(len, output_types))
|
|
return cached_lowering
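
# Illustrative sketch: a lowering rule wrapped in `cache_lowering` is emitted
# once per distinct (primitive, avals, params) key and re-used via func.call.
# `_my_expensive_rule` and `my_p` are hypothetical names.
#
#   @cache_lowering
#   def _my_expensive_rule(ctx, x, *, some_param):
#     ...  # emits a large subgraph
#   register_lowering(my_p, _my_expensive_rule)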
|
|
|
|
|
|
def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation
|
|
) -> ir.Module:
|
|
module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
|
|
return ir.Module.parse(module_str)
|
|
|
|
def merge_mlir_modules(dst_module: ir.Module,
|
|
sym_name: str,
|
|
src_module: ir.Module) -> str:
|
|
"""
|
|
Args:
|
|
dst_module: the module into which the contents of src_module should be
|
|
moved. Nothing in dst_module will be renamed.
|
|
sym_name: the desired name for the "main" function of src_module after
|
|
merging. This is a hint: the true name may be different because of symbol
|
|
uniquification, and the true name is returned by this function.
|
|
src_module: the module whose contents are to be alpha-renamed, set to
|
|
private visibility, and merged into dst_module. src_module must contain
|
|
exactly one symbol named "main".
|
|
|
|
Functions in src_module will be renamed such that they do not collide with
|
|
functions in dst_module.
|
|
|
|
This function mutates `src_module`. On return, `src_module` is left in an
|
|
undefined state.
|
|
|
|
Returns:
|
|
the name of src_module's main() function, after renaming.
|
|
"""
|
|
assert dst_module.context == src_module.context
|
|
|
|
src_symtab = ir.SymbolTable(src_module.operation)
|
|
dst_symtab = ir.SymbolTable(dst_module.operation)
|
|
used_names = set()
|
|
|
|
# Rename all symbols in src_module that clash with names in dst_module, or
|
|
# are the "main" symbol.
|
|
renamings = {}
|
|
for op in src_module.body.operations:
|
|
name = op.name.value
|
|
should_rename = name in dst_symtab or name == "main"
|
|
if should_rename:
|
|
base_name = sym_name if name == "main" else name
|
|
new_name = base_name
|
|
i = 0
|
|
# Replacements are chosen such that the new names are present in neither
|
|
# src_module, dst_module, or the set of fresh names we've already used.
|
|
# Since we rename names one at a time, if new names were in src_module,
|
|
# they might themselves collide with a later renaming.
|
|
while (new_name in src_symtab or new_name in dst_symtab or
|
|
new_name in used_names):
|
|
new_name = f"{base_name}_{i}"
|
|
i += 1
|
|
renamings[name] = new_name
|
|
used_names.add(new_name)
|
|
|
|
# Apply the symbol renamings to symbol definitions.
|
|
private = ir.StringAttr.get("private")
|
|
for op in src_module.body.operations:
|
|
if op.name.value in renamings:
|
|
src_symtab.set_symbol_name(op, renamings[op.name.value])
|
|
op.attributes["sym_visibility"] = private
|
|
|
|
# Apply the symbol renamings to symbol uses.
|
|
for old_name, new_name in renamings.items():
|
|
for op in src_module.body.operations:
|
|
src_symtab.replace_all_symbol_uses(old_name, new_name, op)
|
|
|
|
for op in src_module.body.operations:
|
|
dst_module.body.append(op)
|
|
|
|
return renamings["main"]
|
|
|
|
|
|
def xla_fallback_lowering(prim: core.Primitive):
|
|
@cache_lowering
|
|
def fallback(ctx: LoweringRuleContext, *args, **params):
|
|
module_ctx = ctx.module_context
|
|
axis_ctx = module_ctx.axis_context
|
|
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
|
|
axis_env = axis_ctx.unsafe_axis_env
|
|
else:
|
|
axis_env = module_ctx.axis_env
|
|
|
|
if any(hasattr(a, "shape") and
|
|
not core.is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
|
raise NotImplementedError(
|
|
f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623")
|
|
|
|
if len(module_ctx.platforms) > 1:
|
|
raise NotImplementedError(
|
|
"fallback lowering not implemented for multi-platform lowering")
|
|
xla_computation = xla.primitive_subcomputation(
|
|
module_ctx.platforms[0], axis_env, prim, ctx.avals_in,
|
|
ctx.avals_out, **params)
|
|
xla_module = xla_computation_to_mlir_module(xla_computation)
|
|
callee_name = merge_mlir_modules(
|
|
module_ctx.module, f"xla_fallback_{prim.name}", xla_module)
|
|
output_types = map(aval_to_ir_types, ctx.avals_out)
|
|
flat_output_types = util.flatten(output_types)
|
|
output_type = (ir.TupleType.get_tuple(flat_output_types)
|
|
if prim.multiple_results else flat_output_types[0])
|
|
|
|
call = func_dialect.CallOp([output_type],
|
|
ir.FlatSymbolRefAttr.get(callee_name),
|
|
flatten_lowering_ir_args(args)).result
|
|
if not prim.multiple_results:
|
|
return [call]
|
|
flat_results = [hlo.get_tuple_element(call, i32_attr(i))
|
|
for i in range(len(flat_output_types))]
|
|
|
|
return util.unflatten(flat_results, map(len, output_types))
|
|
return fallback
|
|
|
|
|
|
DEVICE_TO_DEVICE_TYPE = 1
SEND_TO_HOST_TYPE = 2
RECV_FROM_HOST_TYPE = 3


def is_empty_shape(s: core.Shape) -> bool:
  return any(d == 0 for d in s)
|
|
|
|
def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
|
|
aval: core.ShapedArray, name: str, *,
|
|
sharding: xc.OpSharding | None = None) -> ir.Value:
|
|
channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
|
|
send_op = hlo.SendOp([operand], token, channel_handle,
|
|
is_host_transfer=ir.BoolAttr.get(True))
|
|
send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
|
|
dict(
|
|
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
|
|
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
|
|
if sharding is not None:
|
|
set_sharding(send_op, sharding)
|
|
return send_op.result
|
|
|
|
|
|
def receive_from_host(channel: int, token: hlo.TokenType,
                      out_aval: core.ShapedArray, name: str, *,
                      sharding: xc.OpSharding | None = None
                      ) -> tuple[ir.Value, ir.Value]:
  channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
  recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
                        hlo.TokenType.get()], token, channel_handle,
                       is_host_transfer=ir.BoolAttr.get(True))
  recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      dict(
          _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
          _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
  if sharding is not None:
    set_sharding(recv_op, sharding)
  # The token is the last result of the RecvOp; return it first so callers can
  # thread it through subsequent host transfers.
  result, token = recv_op.results
  return token, result
|
|
|
|
|
|
def _emit_tpu_python_callback(
|
|
backend: xb.XlaBackend,
|
|
ctx: LoweringRuleContext,
|
|
callback,
|
|
token: Any | None,
|
|
operands: Sequence[ir.Value],
|
|
operand_avals: Sequence[core.ShapedArray],
|
|
operand_shapes: Sequence[xc.Shape],
|
|
result_avals: Sequence[core.ShapedArray],
|
|
result_shapes: Sequence[xc.Shape],
|
|
*,
|
|
sharding: xc.OpSharding | None = None
|
|
) -> tuple[Sequence[ir.Value], Any]:
|
|
token = token or hlo.create_token()
|
|
_wrapped_callback = callback
|
|
|
|
send_channels = []
|
|
if not operand_avals:
|
|
# If there are no operands to the callback, we need to insert a dummy send
|
|
# op or the callback will never be triggered!
|
|
# TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
|
|
# MLIR builder.
|
|
callback_without_args = _wrapped_callback
|
|
def _wrapped_callback(*args): # pylint: disable=function-redefined
|
|
del args
|
|
return callback_without_args()
|
|
send_channel = ctx.module_context.new_channel()
|
|
dummy_send_aval = core.ShapedArray((1,), np.float32)
|
|
dummy_send_val = ir_constant(np.zeros(1, np.float32))
|
|
operand_shapes = [*operand_shapes,
|
|
xla.aval_to_xla_shapes(dummy_send_aval)[0]]
|
|
token = send_to_host(send_channel, token, dummy_send_val, dummy_send_aval,
|
|
callback.__name__, sharding=sharding)
|
|
send_channels.append(send_channel)
|
|
else:
|
|
for operand, operand_aval in zip(operands, operand_avals):
|
|
channel = ctx.module_context.new_channel()
|
|
token = send_to_host(channel, token, operand, operand_aval,
|
|
callback.__name__, sharding=sharding)
|
|
send_channels.append(channel)
|
|
|
|
recv_channels = []
|
|
outputs = []
|
|
for result_aval in result_avals:
|
|
channel = ctx.module_context.new_channel()
|
|
assert isinstance(result_aval, core.ShapedArray)
|
|
token, out = receive_from_host(channel, token, result_aval,
|
|
callback.__name__, sharding=sharding)
|
|
outputs.append(out)
|
|
recv_channels.append(channel)
|
|
ifrt_callback = backend.make_python_callback_from_host_send_and_recv(
|
|
_wrapped_callback, operand_shapes, result_shapes, send_channels,
|
|
recv_channels, pickle_util.dumps) # type: ignore # pylint: disable=missing-parameter
|
|
ctx.module_context.add_host_callback(ifrt_callback)
|
|
return outputs, token
|
|
|
|
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
  if minor_to_major is None:
    # Needed for token layouts
    layout = np.zeros((0,), dtype="int64")
  else:
    layout = np.array(minor_to_major, dtype="int64")
  return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())


def _aval_to_default_layouts(aval):
  avals = [core.physical_aval(aval)]
  # Row-major order is the default for NumPy.
  return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
|
|
|
|
def emit_python_callback(
|
|
ctx: LoweringRuleContext, callback, token: Any | None,
|
|
operands: Sequence[ir.Value], operand_avals: Sequence[core.ShapedArray],
|
|
result_avals: Sequence[core.ShapedArray],
|
|
has_side_effect: bool, *, sharding: xc.OpSharding | None = None,
|
|
operand_layouts: Sequence[Sequence[int] | None] | None = None,
|
|
result_layouts: Sequence[Sequence[int] | None] | None = None,
|
|
) -> tuple[Sequence[ir.Value], Any, Any]:
|
|
"""Emits MLIR that calls back to a provided Python function."""
|
|
if len(ctx.module_context.platforms) > 1:
|
|
raise NotImplementedError("multi-platform lowering for python_callback")
|
|
platform = ctx.module_context.platforms[0]
|
|
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
|
|
raise ValueError(
|
|
f"`EmitPythonCallback` not supported on {platform} backend.")
|
|
backend = ctx.module_context.backend
|
|
result_shapes = util.flatten(
|
|
[xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
|
|
operand_shapes = util.flatten(
|
|
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
|
|
# Handling layouts
|
|
if operand_layouts is None:
|
|
operand_layouts = util.concatenate(
|
|
map(_aval_to_default_layouts, operand_avals))
|
|
operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts)
|
|
if result_layouts is None:
|
|
result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals))
|
|
result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts)
|
|
|
|
# First we apply checks to ensure output shapes and dtypes match the expected
|
|
# ones.
|
|
def _wrapped_callback(*args):
|
|
out_vals = callback(*args)
|
|
if len(out_vals) != len(result_avals):
|
|
raise RuntimeError(
|
|
"Mismatched number of outputs from callback. "
|
|
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
|
|
# Handle Python literals, and custom arrays, e.g., tf.Tensor.
|
|
out_vals = tuple(np.asarray(a) for a in out_vals)
|
|
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
|
|
if out_val.shape != out_aval.shape:
|
|
raise RuntimeError(
|
|
f"Incorrect output shape for return value #{i}: "
|
|
f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
|
|
if out_val.dtype != dtypes.canonicalize_dtype(out_val.dtype):
|
|
raise RuntimeError(
|
|
"Cannot return 64-bit values when `jax_enable_x64` is disabled. "
|
|
f"Actual: {out_val.dtype}")
|
|
if out_val.dtype != out_aval.dtype:
|
|
raise RuntimeError(
|
|
f"Incorrect output dtype for return value #{i}: "
|
|
f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")
|
|
|
|
if platform == "tpu":
|
|
# On TPU we cannot receive empty arrays. So, we return from the wrapped
|
|
# callback only the non-empty results, and we will create empty constants
|
|
# in the receiving computation.
|
|
# TODO(b/238239458): fix TPU Recv to work with empty arrays.
|
|
non_empty_out_vals = tuple(
|
|
out_val
|
|
for out_val, result_aval in zip(out_vals, result_avals)
|
|
if not is_empty_shape(result_aval.shape))
|
|
return non_empty_out_vals
|
|
else:
|
|
return out_vals
|
|
|
|
if platform == "tpu":
|
|
non_empty_result_avals, non_empty_result_shapes = util.unzip2([
|
|
(aval, shape)
|
|
for aval, shape in zip(result_avals, result_shapes)
|
|
if not is_empty_shape(aval.shape)])
|
|
non_empty_outputs, token = _emit_tpu_python_callback(
|
|
backend, ctx, _wrapped_callback, token,
|
|
operands, operand_avals, operand_shapes,
|
|
non_empty_result_avals, non_empty_result_shapes,
|
|
sharding=sharding)
|
|
non_empty_outputs_iter = iter(non_empty_outputs)
|
|
outputs = [
|
|
ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype))
|
|
if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter)
|
|
for result_aval in result_avals]
|
|
return outputs, token, None
|
|
|
|
result_types = util.flatten([aval_to_ir_types(aval) for aval in result_avals])
|
|
if token:
|
|
|
|
callback_without_token = _wrapped_callback
|
|
def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined
|
|
return (token, *callback_without_token(*args))
|
|
|
|
operand_shapes = [
|
|
xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
|
|
]
|
|
result_shapes = [
|
|
xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
|
|
]
|
|
operands = [token, *operands]
|
|
result_types = [token_type()[0], *result_types]
|
|
operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
|
|
result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
|
|
callback_descriptor, ifrt_callback = (
|
|
backend.get_emit_python_callback_descriptor(_wrapped_callback,
|
|
operand_shapes,
|
|
result_shapes))
|
|
ctx.module_context.add_host_callback(ifrt_callback)
|
|
descriptor_operand = ir_constant(callback_descriptor)
|
|
callback_operands = [descriptor_operand, *operands]
|
|
if operand_mlir_layouts is not None:
|
|
operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
|
|
result_type = ir.TupleType.get_tuple(result_types)
|
|
call_target_name = ("xla_python_gpu_callback"
|
|
if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
|
|
result = hlo.CustomCallOp(
|
|
[result_type],
|
|
callback_operands,
|
|
call_target_name=ir.StringAttr.get(call_target_name),
|
|
has_side_effect=ir.BoolAttr.get(has_side_effect),
|
|
api_version=i32_attr(2),
|
|
called_computations=ir.ArrayAttr.get([]),
|
|
backend_config=ir.StringAttr.get(str(callback_descriptor)),
|
|
operand_layouts=(
|
|
None if operand_mlir_layouts is None
|
|
else ir.ArrayAttr.get(operand_mlir_layouts)),
|
|
result_layouts=(
|
|
None if result_mlir_layouts is None
|
|
else ir.ArrayAttr.get(result_mlir_layouts)))
|
|
if sharding is not None:
|
|
set_sharding(result, sharding)
|
|
results = [
|
|
hlo.get_tuple_element(result, i32_attr(i))
|
|
for i in range(len(result_types))
|
|
]
|
|
if token:
|
|
token, *results = results
|
|
return results, token, ifrt_callback
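
# Illustrative sketch of the emitted pattern on CPU/GPU (assumed shapes/names):
# the callback becomes a custom call to xla_python_cpu_callback (or
# xla_python_gpu_callback) whose first operand is the runtime descriptor, e.g.
#   %t = stablehlo.custom_call @xla_python_cpu_callback(%descriptor, %arg0)
# On TPU the same callback is instead expressed with host send/recv ops, as
# handled by _emit_tpu_python_callback above.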
|
|
|
|
def build_mlir_module_helper(
|
|
closed_jaxpr: core.ClosedJaxpr, *, name: str,
|
|
platforms: Sequence[str],
|
|
backend_or_name: str, axis_context: AxisContext) -> ir.Module:
|
|
"""Helper to generate pmap-style XLA computations for custom partitioners."""
|
|
unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects)
|
|
if unlowerable_effects:
|
|
raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}')
|
|
lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
|
|
backend_or_name=backend_or_name, ordered_effects=[],
|
|
name_stack=source_info_util.NameStack(),
|
|
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
|
|
axis_context=axis_context, platforms=platforms,
|
|
lowering_parameters=LoweringParameters())
|
|
return lowering_result.module
|
|
|
|
def custom_call(
|
|
call_target_name: str,
|
|
*,
|
|
result_types: Sequence[ir.Type],
|
|
operands: Sequence[ir.Value],
|
|
backend_config: str | bytes | dict[str, ir.Attribute] = "",
|
|
has_side_effect: bool = False,
|
|
result_shapes: Sequence[ir.Value] | None = None,
|
|
called_computations: Sequence[str] = (),
|
|
api_version: int = 2,
|
|
operand_output_aliases: dict[int, int] | None = None,
|
|
operand_layouts: Sequence[Sequence[int]] | None = None,
|
|
result_layouts: Sequence[Sequence[int]] | None = None,
|
|
extra_attributes: dict[str, ir.Attribute] | None = None,
|
|
) -> ir.Operation:
|
|
"""Helper function for building an hlo.CustomCall.
|
|
|
|
Args:
|
|
call_target_name: the name of the custom call target
|
|
result_types: the MLIR types of the results of the custom call
|
|
operands: the MLIR IR values that are arguments to the custom call
|
|
backend_config: an opaque string passed to the custom call kernel
|
|
has_side_effect: if True, marks the custom call as effectful
|
|
result_shapes: tensors that represent the result shapes, to be used when
|
|
the results have dynamic shapes. If not-None, its length must match the
|
|
number of the results.
|
|
called_computations: the list of function names called by the custom call.
|
|
api_version: the ABI contract version of the custom call
|
|
operand_output_aliases: a dict mapping operand numbers to outputs they alias
|
|
operand_layouts: a sequence of layouts (dimension orders) for each operand
|
|
result_layouts: a sequence of layouts (dimension orders) for each result
|
|
extra_attributes: additional IR attributes to apply to the custom_call.
|
|
"""
|
|
operands = list(operands)
|
|
|
|
if backend_config is None:
|
|
backend_config_attr = ir.StringAttr.get("")
|
|
elif isinstance(backend_config, (str, bytes)):
|
|
backend_config_attr = ir.StringAttr.get(backend_config)
|
|
elif isinstance(backend_config, dict):
|
|
# TODO(necula): it seems that the CustomCallOp constructor requires that
|
|
# backend_config_attr be a string attribute, even though in some cases we
|
|
# need it to be a DictAttr, e.g., for ApproxTopK on TPU.
|
|
# "Verification failed: 'stablehlo.custom_call' op attribute 'backend_config' failed to satisfy constraint: string attribute"
|
|
# To workaround this limitation we first set it to the empty string and we
|
|
# use an unregistered attribute mhlo.backend_config to hold the DictAttr.
|
|
# We must also use api_version=1 to ensure that mhlo.backend_config is
|
|
# handled properly.
|
|
backend_config_attr = ir.StringAttr.get("")
|
|
api_version = 1
|
|
else:
|
|
raise ValueError("custom_call backend_config unexpected type: " + str(backend_config))
|
|
attributes = dict(
|
|
call_target_name=ir.StringAttr.get(call_target_name),
|
|
has_side_effect=ir.BoolAttr.get(has_side_effect),
|
|
backend_config=backend_config_attr,
|
|
api_version=i32_attr(api_version),
|
|
called_computations=ir.ArrayAttr.get(
|
|
[ir.FlatSymbolRefAttr.get(name) for name in called_computations]
|
|
),
|
|
)
|
|
if operand_output_aliases is not None:
|
|
attributes["output_operand_aliases"] = ir.ArrayAttr.get([
|
|
hlo.OutputOperandAlias.get(
|
|
# if len(result_types) == 1 then the aliasing refers implicitly to
|
|
# the only output.
|
|
output_tuple_indices=[output_idx] if len(result_types) > 1 else [],
|
|
operand_index=input_idx,
|
|
operand_tuple_indices=[],
|
|
)
|
|
for input_idx, output_idx in (operand_output_aliases.items() or ())
|
|
])
|
|
|
|
if extra_attributes is not None:
|
|
attributes.update(extra_attributes)
|
|
|
|
if result_shapes is not None:
|
|
# We add the result_shapes at the end of the operands, and must pass
|
|
# the indices_of_output_operands attribute. This attribute is not yet
|
|
# accepted by the CustomCall constructor, so we use build_generic
|
|
attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
|
|
np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
|
|
dtype=np.int64))
|
|
if operand_layouts is not None:
|
|
assert len(operand_layouts) == len(operands), (operand_layouts, operands)
|
|
operand_layouts = list(operand_layouts) + [(0,)] * len(result_shapes)
|
|
operands = list(operands) + list(result_shapes)
|
|
|
|
if operand_layouts is not None:
|
|
attributes["operand_layouts"] = ir.ArrayAttr.get([
|
|
ir.DenseIntElementsAttr.get(
|
|
np.atleast_1d(np.asarray(l, dtype=np.int64)),
|
|
type=ir.IndexType.get()) for l in operand_layouts
|
|
])
|
|
if result_layouts is not None:
|
|
assert result_layouts is not None
|
|
assert len(result_layouts) == len(result_types), (
|
|
result_layouts, result_types)
|
|
attributes["result_layouts"] = ir.ArrayAttr.get([
|
|
ir.DenseIntElementsAttr.get(
|
|
np.atleast_1d(np.asarray(l, dtype=np.int64)),
|
|
type=ir.IndexType.get()) for l in result_layouts
|
|
])
|
|
|
|
op = hlo.CustomCallOp.build_generic(results=result_types, operands=operands,
|
|
attributes=attributes)
|
|
if isinstance(backend_config, dict):
|
|
backend_config_attr = ir.DictAttr.get(backend_config)
|
|
op.operation.attributes["mhlo.backend_config"] = backend_config_attr
|
|
return op
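
# Illustrative usage sketch: a custom call to a hypothetical registered target
# "my_kernel" with one operand and one result of the same type.
#   out = custom_call(
#       "my_kernel",
#       result_types=[aval_to_ir_type(aval)],
#       operands=[x],
#       backend_config="",          # opaque string passed to the kernel
#       api_version=2,
#   ).results[0]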
|
|
|
|
|
|
def reduce_window(
|
|
ctx: LoweringRuleContext,
|
|
*,
|
|
# Base name to be used for the reducer function
|
|
reducer_name: str,
|
|
# Compute the reducer body given the reducer.
|
|
reducer_body: Callable[[ir.Block], Sequence[ir.Value]],
|
|
operands: Sequence[ir.Value],
|
|
init_values: Sequence[ir.Value],
|
|
init_values_avals: Sequence[core.AbstractValue],
|
|
out_avals: Sequence[core.AbstractValue],
|
|
window_dimensions, window_strides, padding, base_dilation, window_dilation):
|
|
"""Builds a ReduceWindowOp, with support for dynamic shapes."""
|
|
|
|
scalar_types = [aval_to_ir_type(aval) for aval in init_values_avals]
|
|
if any(not core.is_constant_shape(s)
|
|
for s in [window_dimensions, window_dilation, window_strides, base_dilation, *padding]):
|
|
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
|
|
# spatial dimension.
|
|
int2d = aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
|
|
def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
|
|
pads = eval_dynamic_shape_as_tensor(ctx, pad_lo_hi) # i32[2]
|
|
return hlo.reshape(int2d, pads)
|
|
d_padding = hlo.concatenate(list(map(prep_one_pad, padding)), i64_attr(0))
|
|
# Build the reducer
|
|
reducer_type = ir.FunctionType.get(scalar_types + scalar_types,
|
|
scalar_types)
|
|
with ir.InsertionPoint.at_block_begin(ctx.module_context.module.body):
|
|
reducer = func_dialect.FuncOp(reducer_name, reducer_type)
|
|
ctx.module_context.symbol_table.insert(reducer)
|
|
entry_block = reducer.add_entry_block()
|
|
with ir.InsertionPoint(entry_block):
|
|
hlo.return_(reducer_body(entry_block))
|
|
|
|
rw = custom_call(
|
|
"stablehlo.dynamic_reduce_window",
|
|
result_types=list(map(aval_to_ir_type, out_avals)),
|
|
operands=[
|
|
*operands, *init_values,
|
|
eval_dynamic_shape_as_tensor(ctx, window_dimensions),
|
|
eval_dynamic_shape_as_tensor(ctx, window_strides),
|
|
eval_dynamic_shape_as_tensor(ctx, base_dilation),
|
|
eval_dynamic_shape_as_tensor(ctx, window_dilation),
|
|
d_padding],
|
|
called_computations=[reducer.name.value],
|
|
)
|
|
else: # Static shapes
|
|
rw = hlo.ReduceWindowOp(
|
|
list(map(aval_to_ir_type, out_avals)),
|
|
operands, init_values,
|
|
dense_int_array_v6(window_dimensions),
|
|
window_strides=dense_int_array_v6(window_strides),
|
|
base_dilations=dense_int_array_v6(base_dilation),
|
|
window_dilations=dense_int_array_v6(window_dilation),
|
|
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
|
|
shape=(len(padding), 2)))
|
|
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
|
|
with ir.InsertionPoint(reducer):
|
|
hlo.return_(reducer_body(reducer))
|
|
return rw.results
|
|
|
|
|
|
def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
|
|
"""Refines the polymorphic shapes inside a module.
|
|
|
|
Given a module with static input shapes, but using dynamic shapes due to
|
|
shape polymorphism, runs shape refinement to resolve all the dynamic shapes.
|
|
Then verifies that there are no more dynamic shapes in the module.
|
|
"""
|
|
try:
|
|
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
|
module_to_bytecode(module), enable_shape_assertions=True,
|
|
validate_static_shapes=True)
|
|
except Exception as e:
|
|
raise ValueError(
|
|
"Error refining shapes. " +
|
|
dump_module_message(module, "before_refine_polymorphic_shapes")) from e
|
|
|
|
context = make_ir_context()
|
|
with context:
|
|
return ir.Module.parse(refined_module_str)
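
# Illustrative usage sketch: after specializing the symbolic input shapes of a
# module to concrete ones, run refinement so no dynamic shapes remain.
#   refined = refine_polymorphic_shapes(module)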
|